Skip to content
Snippets Groups Projects
Commit 60d37b5b authored by Florian Zill's avatar Florian Zill
Browse files

[meshplotlib] plot streamlines update

parent 6fc92a3c
No related branches found
No related tags found
No related merge requests found
"""Specialized plot features."""
from typing import Callable
from typing import Callable, Optional
import matplotlib.pyplot as plt
import numpy as np
......@@ -79,43 +79,71 @@ def plot_element_edges(ax: plt.Axes, surf: pv.DataSet, projection: int) -> None:
def plot_streamlines(
ax: plt.Axes, surf: pv.DataSet, property: Vector, projection: int
ax: plt.Axes,
mesh: pv.DataSet,
property: Vector,
projection: Optional[int] = None,
arrows: bool = False,
) -> None:
"""Plot vector streamlines on a matplotlib axis."""
if (n_pts := setup.num_streamline_interp_pts) is None:
return
if arrows:
n_pts = 50
if projection is None:
mean_normal = np.abs(
np.mean(mesh.extract_surface().cell_normals, axis=0)
)
projection = int(np.argmax(mean_normal))
x_id, y_id = np.delete([0, 1, 2], projection)
bounds = [float(b) for b in surf.bounds]
bounds = [float(b) for b in mesh.bounds]
x = np.linspace(bounds[2 * x_id], bounds[2 * x_id + 1], n_pts)
y = np.linspace(bounds[2 * y_id], bounds[2 * y_id + 1], n_pts)
z = np.array([np.mean(surf.points[..., projection])])
z = np.array([np.mean(mesh.points[..., projection])])
_surf = surf
for key in _surf.point_data:
_mesh = mesh.copy()
for key in _mesh.point_data:
if key not in [property.data_name, property.mask]:
del _surf.point_data[key]
del _mesh.point_data[key]
grid = pv.RectilinearGrid(
[x, y, z][x_id], [x, y, z][y_id], [x, y, z][projection]
)
grid = grid.sample(_surf, pass_cell_data=False)
if np.shape(grid.point_data[property.data_name])[-1] == 3:
grid.point_data[property.data_name] = np.delete(
grid.point_data[property.data_name], projection, 1
)
val = np.reshape(
property(grid.point_data[property.data_name]),
(n_pts, n_pts, 2),
)
grid = grid.sample(_mesh, pass_cell_data=False)
values = property(grid.point_data[property.data_name])
values[np.argwhere(grid["vtkValidPointMask"] == 0), :] = np.nan
if np.shape(values)[-1] == 3:
values = np.delete(values, projection, 1)
val = np.reshape(values, (n_pts, n_pts, 2))
if property.mask in grid.point_data:
mask = np.reshape(grid.point_data[property.mask], (n_pts, n_pts))
val[mask == 0, :] = 0
val_norm = np.linalg.norm(val, axis=-1)
val_norm = np.linalg.norm(np.nan_to_num(val), axis=-1)
lw = 2.5 * val_norm / max(1e-16, np.max(val_norm))
lw *= setup.rcParams_scaled["lines.linewidth"]
x_g, y_g = setup.length(np.meshgrid(x, y))
ax.streamplot(x_g, y_g, val[..., 0], val[..., 1], color="k", linewidth=lw)
if arrows:
ax.quiver(
x_g,
y_g,
val[..., 0],
val[..., 1],
scale=1 / 0.03,
headlength=0,
headaxislength=0,
headwidth=1,
pivot="mid",
)
else:
ax.streamplot(
x_g,
y_g,
val[..., 0],
val[..., 1],
color="k",
linewidth=lw,
density=1.5,
)
def plot_on_top(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment