Skip to content
Snippets Groups Projects
Commit 6407133e authored by Florian Zill's avatar Florian Zill
Browse files

added plot_probe

parent fd0573fc
No related branches found
No related tags found
No related merge requests found
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from .plot_setup import _setup as setup # noqa: I001: noqa from .plot_setup import _setup as setup # noqa: I001: noqa
from .core import plot_diff, plot_limit, plot, subplot from .core import plot_diff, plot_limit, plot_probe, plot, subplot
from .plot_features import plot_contour, plot_on_top from .plot_features import plot_contour, plot_on_top
__all__ = [ __all__ = [
...@@ -13,5 +13,6 @@ __all__ = [ ...@@ -13,5 +13,6 @@ __all__ = [
"plot_diff", "plot_diff",
"plot_limit", "plot_limit",
"plot_on_top", "plot_on_top",
"plot_probe",
"subplot", "subplot",
] ]
...@@ -18,10 +18,12 @@ from matplotlib.patches import Rectangle as Rect ...@@ -18,10 +18,12 @@ from matplotlib.patches import Rectangle as Rect
from ogstools.meshlib import MeshSeries from ogstools.meshlib import MeshSeries
from ogstools.propertylib import Property, Vector from ogstools.propertylib import Property, Vector
from ogstools.propertylib.presets import _resolve_property from ogstools.propertylib.presets import _resolve_property
from ogstools.propertylib.unit_registry import u_reg
from . import plot_features as pf from . import plot_features as pf
from . import setup from . import setup
from .levels import get_levels from .levels import get_levels
from .utils import get_style_cycler, justified_labels
# TODO: define default data_name for regions in setup # TODO: define default data_name for regions in setup
...@@ -517,3 +519,74 @@ def plot_limit( ...@@ -517,3 +519,74 @@ def plot_limit(
output_name=limit + " " + property.output_name output_name=limit + " " + property.output_name
) )
return plot(mesh, limit_property) return plot(mesh, limit_property)
def plot_probe(
mesh_series: MeshSeries,
points: np.ndarray,
mesh_property: Union[Property, str],
labels: Optional[list[str]] = None,
time_unit: Optional[str] = None,
interp_method: Optional[Literal["nearest", "linear", "probefilter"]] = None,
interp_backend_pvd: Optional[Literal["vtk", "scipy"]] = None,
colors: Optional[list] = None,
linestyles: Optional[list] = None,
ax: Optional[plt.Axes] = None,
**kwargs,
) -> mfigure.Figure:
"""
Plot the transient property on the observation points in the MeshSeries.
:param mesh_series: MeshSeries object containing the data to be plotted.
:param points: The points to sample at.
:param mesh_property: The property to be sampled.
:param labels: The labels for each observation point.
:param time_unit: Output unit of the timevalues.
:param interp_method: Choose the interpolation method, defaults to
`linear` for xdmf MeshSeries and `probefilter`
for pvd MeshSeries.
:param interp_backend: Interpolation backend for PVD MeshSeries.
:param kwargs: Keyword arguments passed to the plot function.
:returns: A matplotlib Figure
"""
points = np.asarray(points)
if len(points.shape) == 1:
points = points[np.newaxis]
if isinstance(mesh_property, str):
data_shape = mesh_series.read(0)[mesh_property].shape
mesh_property = _resolve_property(mesh_property, data_shape)
values = mesh_property.magnitude.strip_units(
mesh_series.probe(
points, mesh_property.data_name, interp_method, interp_backend_pvd
)
)
if values.shape[0] == 1:
values = values.flatten()
Q_ = u_reg.Quantity
time_unit_conversion = Q_(Q_(mesh_series.time_unit), time_unit).magnitude
times = time_unit_conversion * mesh_series.timevalues
if ax is None:
fig, ax = plt.subplots()
else:
fig = None
ax.set_prop_cycle(get_style_cycler(len(points), colors, linestyles))
if labels is None:
j_labels = justified_labels(points)
labels = [f"{i}: {label}" for i, label in enumerate(j_labels)]
ax.plot(times, values, label=labels, **kwargs)
ax.legend(facecolor="white", framealpha=1, prop={"family": "monospace"})
time_label = f"time / {time_unit}" if time_unit else "time"
unit_str = (
f" / {mesh_property.get_output_unit()}"
if mesh_property.get_output_unit()
else ""
)
y_label = mesh_property.output_name.replace("_", " ") + unit_str
ax.set_xlabel(time_label)
ax.set_ylabel(y_label)
ax.grid(which="major", color="lightgrey", linestyle="-")
ax.grid(which="minor", color="0.95", linestyle="--")
ax.minorticks_on()
return fig
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment