From bad3d516f6744c264fb7ec3ece567ede8f9b8933 Mon Sep 17 00:00:00 2001
From: FZill <florian.zill@ufz.de>
Date: Wed, 7 Feb 2024 17:12:24 +0100
Subject: [PATCH] refactoring meshplotlib with last changes

---
 ogstools/meshplotlib/core.py | 52 +++++++++++++++++++-----------------
 1 file changed, 28 insertions(+), 24 deletions(-)

diff --git a/ogstools/meshplotlib/core.py b/ogstools/meshplotlib/core.py
index 682f24fdc..54f7755a1 100644
--- a/ogstools/meshplotlib/core.py
+++ b/ogstools/meshplotlib/core.py
@@ -32,14 +32,29 @@ def _q_zero_line(property: Property, levels: np.ndarray):
     )
 
 
-def get_data(mesh: pv.UnstructuredGrid, property: Property) -> np.ndarray:
+def has_masked_values(mesh: pv.UnstructuredGrid, property: Property) -> bool:
+    return (
+        not property.is_mask()
+        and property.mask in mesh.cell_data
+        and (len(mesh.cell_data[property.mask]) != 0)
+    )
+
+
+def get_data(
+    mesh: pv.UnstructuredGrid, property: Property, masked: bool = False
+) -> np.ndarray:
     """Get the data associated with a scalar or vector property from a mesh."""
-    if property.data_name in mesh.point_data:
-        return mesh.point_data[property.data_name]
-    if property.data_name in mesh.cell_data:
-        return mesh.cell_data[property.data_name]
-    msg = f"Property not found in mesh {mesh}."
-    raise IndexError(msg)
+    if (
+        property.data_name not in mesh.point_data
+        and property.data_name not in mesh.cell_data
+    ):
+        msg = f"Property not found in mesh {mesh}."
+        raise IndexError(msg)
+    if masked:
+        return mesh.ctp(True).threshold(value=[1, 1], scalars=property.mask)[
+            property.data_name
+        ]
+    return mesh[property.data_name]
 
 
 def get_level_boundaries(levels: np.ndarray):
@@ -196,11 +211,7 @@ def subplot(
 
     ax.axis("auto")
 
-    if (
-        not property.is_mask()
-        and property.mask in mesh.cell_data
-        and len(mesh.cell_data[property.mask])
-    ):
+    if has_masked_values(mesh, property):
         subplot(mesh, property.get_mask(), ax)
         mesh = mesh.ctp(True).threshold(value=[1, 1], scalars=property.mask)
 
@@ -341,17 +352,9 @@ def get_combined_levels(
     p_min, p_max = np.inf, -np.inf
     unique_vals = np.array([])
     for mesh in np.ravel(meshes):
-        if (
-            not property.is_mask()
-            and property.mask in mesh.cell_data
-            and len(mesh.cell_data[property.mask])
-        ):
-            _mesh = mesh.ctp(True).threshold(
-                value=[1, 1], scalars=property.mask
-            )
-            values = property.magnitude.strip_units(get_data(_mesh, property))
-        else:
-            values = property.magnitude.strip_units(get_data(mesh, property))
+        values = property.magnitude.strip_units(
+            get_data(mesh, property, masked=has_masked_values(mesh, property))
+        )
         if setup.log_scaled:  # TODO: can be improved
             values = np.log10(np.where(values > 1e-14, values, 1e-14))
         p_min = min(p_min, np.nanmin(values)) if setup.p_min is None else p_min
@@ -508,7 +511,8 @@ def plot_limit(
         data_shape = mesh[property].shape
         property = _resolve_property(property, data_shape)
     func = {"min": np.min, "max": np.max}[limit]
-    mesh = mesh_series.reduce_with(func, property.data_name)
+    vals = mesh_series.values(property.data_name)
+    func(vals, out=mesh[property.data_name], axis=0)
     limit_property = property.replace(
         output_name=limit + " " + property.output_name
     )
-- 
GitLab