From 6407133eeb1888ee922030d1ad0e257db45e42c8 Mon Sep 17 00:00:00 2001
From: FZill <florian.zill@ufz.de>
Date: Thu, 8 Feb 2024 10:00:36 +0100
Subject: [PATCH] added plot_probe

---
 ogstools/meshplotlib/__init__.py |  3 +-
 ogstools/meshplotlib/core.py     | 73 ++++++++++++++++++++++++++++++++
 2 files changed, 75 insertions(+), 1 deletion(-)

diff --git a/ogstools/meshplotlib/__init__.py b/ogstools/meshplotlib/__init__.py
index 3702477db..ef020e43b 100644
--- a/ogstools/meshplotlib/__init__.py
+++ b/ogstools/meshplotlib/__init__.py
@@ -3,7 +3,7 @@
 
 from .plot_setup import _setup as setup  # noqa: I001: noqa
 
-from .core import plot_diff, plot_limit, plot, subplot
+from .core import plot_diff, plot_limit, plot_probe, plot, subplot
 from .plot_features import plot_contour, plot_on_top
 
 __all__ = [
@@ -13,5 +13,6 @@ __all__ = [
     "plot_diff",
     "plot_limit",
     "plot_on_top",
+    "plot_probe",
     "subplot",
 ]
diff --git a/ogstools/meshplotlib/core.py b/ogstools/meshplotlib/core.py
index 54f7755a1..bf7bb19c6 100644
--- a/ogstools/meshplotlib/core.py
+++ b/ogstools/meshplotlib/core.py
@@ -18,10 +18,12 @@ from matplotlib.patches import Rectangle as Rect
 from ogstools.meshlib import MeshSeries
 from ogstools.propertylib import Property, Vector
 from ogstools.propertylib.presets import _resolve_property
+from ogstools.propertylib.unit_registry import u_reg
 
 from . import plot_features as pf
 from . import setup
 from .levels import get_levels
+from .utils import get_style_cycler, justified_labels
 
 # TODO: define default data_name for regions in setup
 
@@ -517,3 +519,74 @@ def plot_limit(
         output_name=limit + " " + property.output_name
     )
     return plot(mesh, limit_property)
+
+
+def plot_probe(
+    mesh_series: MeshSeries,
+    points: np.ndarray,
+    mesh_property: Union[Property, str],
+    labels: Optional[list[str]] = None,
+    time_unit: Optional[str] = None,
+    interp_method: Optional[Literal["nearest", "linear", "probefilter"]] = None,
+    interp_backend_pvd: Optional[Literal["vtk", "scipy"]] = None,
+    colors: Optional[list] = None,
+    linestyles: Optional[list] = None,
+    ax: Optional[plt.Axes] = None,
+    **kwargs,
+) -> mfigure.Figure:
+    """
+    Plot the transient property on the observation points in the MeshSeries.
+
+        :param mesh_series: MeshSeries object containing the data to be plotted.
+        :param points:          The points to sample at.
+        :param mesh_property:   The property to be sampled.
+        :param labels:          The labels for each observation point.
+        :param time_unit:       Output unit of the timevalues.
+        :param interp_method:   Choose the interpolation method, defaults to
+                                `linear` for xdmf MeshSeries and `probefilter`
+                                for pvd MeshSeries.
+        :param interp_backend:  Interpolation backend for PVD MeshSeries.
+        :param kwargs:          Keyword arguments passed to the plot function.
+
+        :returns:   A matplotlib Figure
+    """
+    points = np.asarray(points)
+    if len(points.shape) == 1:
+        points = points[np.newaxis]
+    if isinstance(mesh_property, str):
+        data_shape = mesh_series.read(0)[mesh_property].shape
+        mesh_property = _resolve_property(mesh_property, data_shape)
+    values = mesh_property.magnitude.strip_units(
+        mesh_series.probe(
+            points, mesh_property.data_name, interp_method, interp_backend_pvd
+        )
+    )
+    if values.shape[0] == 1:
+        values = values.flatten()
+    Q_ = u_reg.Quantity
+    time_unit_conversion = Q_(Q_(mesh_series.time_unit), time_unit).magnitude
+    times = time_unit_conversion * mesh_series.timevalues
+    if ax is None:
+        fig, ax = plt.subplots()
+    else:
+        fig = None
+    ax.set_prop_cycle(get_style_cycler(len(points), colors, linestyles))
+
+    if labels is None:
+        j_labels = justified_labels(points)
+        labels = [f"{i}: {label}" for i, label in enumerate(j_labels)]
+    ax.plot(times, values, label=labels, **kwargs)
+    ax.legend(facecolor="white", framealpha=1, prop={"family": "monospace"})
+    time_label = f"time / {time_unit}" if time_unit else "time"
+    unit_str = (
+        f" / {mesh_property.get_output_unit()}"
+        if mesh_property.get_output_unit()
+        else ""
+    )
+    y_label = mesh_property.output_name.replace("_", " ") + unit_str
+    ax.set_xlabel(time_label)
+    ax.set_ylabel(y_label)
+    ax.grid(which="major", color="lightgrey", linestyle="-")
+    ax.grid(which="minor", color="0.95", linestyle="--")
+    ax.minorticks_on()
+    return fig
-- 
GitLab