From 60d37b5bc934978f05954a97954240abab5e1e42 Mon Sep 17 00:00:00 2001
From: FZill <florian.zill@ufz.de>
Date: Fri, 23 Feb 2024 17:51:03 +0100
Subject: [PATCH] [meshplotlib] plot streamlines update

---
 ogstools/meshplotlib/plot_features.py | 66 +++++++++++++++++++--------
 1 file changed, 47 insertions(+), 19 deletions(-)

diff --git a/ogstools/meshplotlib/plot_features.py b/ogstools/meshplotlib/plot_features.py
index 8d1baeca8..41a8d6b9e 100644
--- a/ogstools/meshplotlib/plot_features.py
+++ b/ogstools/meshplotlib/plot_features.py
@@ -1,6 +1,6 @@
 """Specialized plot features."""
 
-from typing import Callable
+from typing import Callable, Optional
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -79,43 +79,71 @@ def plot_element_edges(ax: plt.Axes, surf: pv.DataSet, projection: int) -> None:
 
 
 def plot_streamlines(
-    ax: plt.Axes, surf: pv.DataSet, property: Vector, projection: int
+    ax: plt.Axes,
+    mesh: pv.DataSet,
+    property: Vector,
+    projection: Optional[int] = None,
+    arrows: bool = False,
 ) -> None:
     """Plot vector streamlines on a matplotlib axis."""
     if (n_pts := setup.num_streamline_interp_pts) is None:
         return
+    if arrows:
+        n_pts = 50
+    if projection is None:
+        mean_normal = np.abs(
+            np.mean(mesh.extract_surface().cell_normals, axis=0)
+        )
+        projection = int(np.argmax(mean_normal))
     x_id, y_id = np.delete([0, 1, 2], projection)
-    bounds = [float(b) for b in surf.bounds]
+    bounds = [float(b) for b in mesh.bounds]
     x = np.linspace(bounds[2 * x_id], bounds[2 * x_id + 1], n_pts)
     y = np.linspace(bounds[2 * y_id], bounds[2 * y_id + 1], n_pts)
-    z = np.array([np.mean(surf.points[..., projection])])
+    z = np.array([np.mean(mesh.points[..., projection])])
 
-    _surf = surf
-    for key in _surf.point_data:
+    _mesh = mesh.copy()
+    for key in _mesh.point_data:
         if key not in [property.data_name, property.mask]:
-            del _surf.point_data[key]
+            del _mesh.point_data[key]
     grid = pv.RectilinearGrid(
         [x, y, z][x_id], [x, y, z][y_id], [x, y, z][projection]
     )
-    grid = grid.sample(_surf, pass_cell_data=False)
-    if np.shape(grid.point_data[property.data_name])[-1] == 3:
-        grid.point_data[property.data_name] = np.delete(
-            grid.point_data[property.data_name], projection, 1
-        )
-    val = np.reshape(
-        property(grid.point_data[property.data_name]),
-        (n_pts, n_pts, 2),
-    )
+    grid = grid.sample(_mesh, pass_cell_data=False)
+    values = property(grid.point_data[property.data_name])
+    values[np.argwhere(grid["vtkValidPointMask"] == 0), :] = np.nan
+    if np.shape(values)[-1] == 3:
+        values = np.delete(values, projection, 1)
+    val = np.reshape(values, (n_pts, n_pts, 2))
 
     if property.mask in grid.point_data:
         mask = np.reshape(grid.point_data[property.mask], (n_pts, n_pts))
         val[mask == 0, :] = 0
-    val_norm = np.linalg.norm(val, axis=-1)
+    val_norm = np.linalg.norm(np.nan_to_num(val), axis=-1)
     lw = 2.5 * val_norm / max(1e-16, np.max(val_norm))
     lw *= setup.rcParams_scaled["lines.linewidth"]
-
     x_g, y_g = setup.length(np.meshgrid(x, y))
-    ax.streamplot(x_g, y_g, val[..., 0], val[..., 1], color="k", linewidth=lw)
+    if arrows:
+        ax.quiver(
+            x_g,
+            y_g,
+            val[..., 0],
+            val[..., 1],
+            scale=1 / 0.03,
+            headlength=0,
+            headaxislength=0,
+            headwidth=1,
+            pivot="mid",
+        )
+    else:
+        ax.streamplot(
+            x_g,
+            y_g,
+            val[..., 0],
+            val[..., 1],
+            color="k",
+            linewidth=lw,
+            density=1.5,
+        )
 
 
 def plot_on_top(
-- 
GitLab