From 60d37b5bc934978f05954a97954240abab5e1e42 Mon Sep 17 00:00:00 2001 From: FZill <florian.zill@ufz.de> Date: Fri, 23 Feb 2024 17:51:03 +0100 Subject: [PATCH] [meshplotlib] plot streamlines update --- ogstools/meshplotlib/plot_features.py | 66 +++++++++++++++++++-------- 1 file changed, 47 insertions(+), 19 deletions(-) diff --git a/ogstools/meshplotlib/plot_features.py b/ogstools/meshplotlib/plot_features.py index 8d1baeca8..41a8d6b9e 100644 --- a/ogstools/meshplotlib/plot_features.py +++ b/ogstools/meshplotlib/plot_features.py @@ -1,6 +1,6 @@ """Specialized plot features.""" -from typing import Callable +from typing import Callable, Optional import matplotlib.pyplot as plt import numpy as np @@ -79,43 +79,71 @@ def plot_element_edges(ax: plt.Axes, surf: pv.DataSet, projection: int) -> None: def plot_streamlines( - ax: plt.Axes, surf: pv.DataSet, property: Vector, projection: int + ax: plt.Axes, + mesh: pv.DataSet, + property: Vector, + projection: Optional[int] = None, + arrows: bool = False, ) -> None: """Plot vector streamlines on a matplotlib axis.""" if (n_pts := setup.num_streamline_interp_pts) is None: return + if arrows: + n_pts = 50 + if projection is None: + mean_normal = np.abs( + np.mean(mesh.extract_surface().cell_normals, axis=0) + ) + projection = int(np.argmax(mean_normal)) x_id, y_id = np.delete([0, 1, 2], projection) - bounds = [float(b) for b in surf.bounds] + bounds = [float(b) for b in mesh.bounds] x = np.linspace(bounds[2 * x_id], bounds[2 * x_id + 1], n_pts) y = np.linspace(bounds[2 * y_id], bounds[2 * y_id + 1], n_pts) - z = np.array([np.mean(surf.points[..., projection])]) + z = np.array([np.mean(mesh.points[..., projection])]) - _surf = surf - for key in _surf.point_data: + _mesh = mesh.copy() + for key in _mesh.point_data: if key not in [property.data_name, property.mask]: - del _surf.point_data[key] + del _mesh.point_data[key] grid = pv.RectilinearGrid( [x, y, z][x_id], [x, y, z][y_id], [x, y, z][projection] ) - grid = grid.sample(_surf, pass_cell_data=False) - if np.shape(grid.point_data[property.data_name])[-1] == 3: - grid.point_data[property.data_name] = np.delete( - grid.point_data[property.data_name], projection, 1 - ) - val = np.reshape( - property(grid.point_data[property.data_name]), - (n_pts, n_pts, 2), - ) + grid = grid.sample(_mesh, pass_cell_data=False) + values = property(grid.point_data[property.data_name]) + values[np.argwhere(grid["vtkValidPointMask"] == 0), :] = np.nan + if np.shape(values)[-1] == 3: + values = np.delete(values, projection, 1) + val = np.reshape(values, (n_pts, n_pts, 2)) if property.mask in grid.point_data: mask = np.reshape(grid.point_data[property.mask], (n_pts, n_pts)) val[mask == 0, :] = 0 - val_norm = np.linalg.norm(val, axis=-1) + val_norm = np.linalg.norm(np.nan_to_num(val), axis=-1) lw = 2.5 * val_norm / max(1e-16, np.max(val_norm)) lw *= setup.rcParams_scaled["lines.linewidth"] - x_g, y_g = setup.length(np.meshgrid(x, y)) - ax.streamplot(x_g, y_g, val[..., 0], val[..., 1], color="k", linewidth=lw) + if arrows: + ax.quiver( + x_g, + y_g, + val[..., 0], + val[..., 1], + scale=1 / 0.03, + headlength=0, + headaxislength=0, + headwidth=1, + pivot="mid", + ) + else: + ax.streamplot( + x_g, + y_g, + val[..., 0], + val[..., 1], + color="k", + linewidth=lw, + density=1.5, + ) def plot_on_top( -- GitLab