diff --git a/docs/examples/howto_meshplotlib/plot_observation_points.py b/docs/examples/howto_meshplotlib/plot_observation_points.py new file mode 100644 index 0000000000000000000000000000000000000000..70b6331cd2875d0deb1cc1173b66de2bdd714021 --- /dev/null +++ b/docs/examples/howto_meshplotlib/plot_observation_points.py @@ -0,0 +1,78 @@ +""" +How to plot data at observation points +====================================== + +.. sectionauthor:: Florian Zill (Helmholtz Centre for Environmental Research GmbH - UFZ) + +In this example we plot the data values on observation points over all +timesteps. Since requested observation points don't necessarily coincide with +actual nodes of the mesh different interpolation options are available. See +:py:mod:`ogstools.meshlib.mesh_series.MeshSeries.probe` for more details. +Here we use a component transport example from the ogs benchmark gallery +(https://www.opengeosys.org/docs/benchmarks/hydro-component/elder/). +""" + +# %% + +# sphinx_gallery_start_ignore + +# sphinx_gallery_thumbnail_number = 2 +# fmt:off + +# sphinx_gallery_end_ignore + +import matplotlib.pyplot as plt +import numpy as np + +from ogstools import meshplotlib +from ogstools.meshplotlib.examples import meshseries_CT_2D as mesh_series +from ogstools.meshplotlib.utils import justified_labels +from ogstools.propertylib import Scalar + +meshplotlib.setup.reset() +si = Scalar( + data_name="Si", data_unit="", output_unit="%", output_name="Saturation" +) +# alternatively: +# from ogstools.meshlib import MeshSeries +# mesh_series = MeshSeries("filepath/filename_pvd_or_xdmf") + +# %% [markdown] +# Let's define 4 observation points and plot them on the mesh. + +# %% +points = np.asarray( + [[x, 0, 60] for x in [0, 40, 80, 120]] + + [[x, 0, 40] for x in [0, 40, 80, 120]] +) +fig = meshplotlib.plot(mesh_series.read(0), si) +fig.axes[0].scatter(points[:, 0], points[:, 2], s=50, fc="none", ec="r", lw=3) +for i, point in enumerate(points): + fig.axes[0].annotate(str(i), (point[0], point[2] - 5), va="top") +plt.rcdefaults() + +# %% [markdown] +# And now probe the points and the values over time: + +# %% +labels = [f"{i}: {label}" for i, label in enumerate(justified_labels(points))] +fig = meshplotlib.plot_probe( + mesh_series=mesh_series, points=points[:4], mesh_property=si, + time_unit="a", labels=labels[:4] +) +# %% [markdown] +# You can also pass create your own matplotlib figure and pass the axes object. +# Additionally, you can pass any keyword arguments which are known by +# matplotlibs plot function to further customize the curves. +# In this case `marker` and `linewidth` are not part of the API of `plot_probe` +# but get processed correctly anyway. + +# %% +fig, axs = plt.subplots(nrows=2, figsize=[10, 5]) +meshplotlib.plot_probe( + mesh_series, points[:4], si, time_unit="a", ax=axs[0], colors=["k"], + labels=labels[:4], marker=".") +meshplotlib.plot_probe( + mesh_series, points[4:], si, time_unit="a", ax=axs[1], linestyles=["-"], + labels=labels[4:], linewidth=1, +) diff --git a/ogstools/meshlib/examples/__init__.py b/ogstools/meshlib/examples/__init__.py index 658c837bf1b505bbb9ab9f31536509ab9bd924aa..fee23029b35dd02300a89f203da3bf34a0ba51f3 100644 --- a/ogstools/meshlib/examples/__init__.py +++ b/ogstools/meshlib/examples/__init__.py @@ -3,4 +3,5 @@ from ogstools.definitions import ROOT_DIR # Will probably be replaced with some dynamically generated example examples = ROOT_DIR / "_examples" pvd_file = examples / "2D.pvd" +vtu_file = examples / "2D_ts_0_t_0e+00.vtu" xdmf_file = examples / "2D_single_fracture_HT_2D_single_fracture.xdmf" diff --git a/ogstools/meshlib/mesh_series.py b/ogstools/meshlib/mesh_series.py index 552e13f34fc3c8604bde450bf700b2e9d7f4b37e..d43f396541e919545b9263f3cf3650a2d2c620bb 100644 --- a/ogstools/meshlib/mesh_series.py +++ b/ogstools/meshlib/mesh_series.py @@ -1,144 +1,16 @@ """A class to handle Meshseries data.""" from pathlib import Path -from typing import Union +from typing import Literal, Optional, Union import meshio import numpy as np import pyvista as pv -from meshio.xdmf.time_series import ( - ReadError, - cell_data_from_raw, - xdmf_to_numpy_type, -) +import vtuIO +from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator from tqdm.auto import tqdm - -class TimeSeriesReader(meshio.xdmf.TimeSeriesReader): - def __init__(self, filename): - super().__init__(filename) - - def read_data(self, k: int): - point_data = {} - cell_data_raw = {} - other_data = {} - - t = None - - for c in list(self.collection[k]): - if c.tag == "Time": - t = float(c.attrib["Value"]) - elif c.tag == "Attribute": - name = c.get("Name") - - if len(list(c)) != 1: - raise ReadError() - data_item = list(c)[0] - data = self._read_data_item(data_item) - - if c.get("Center") == "Node": - point_data[name] = data - elif c.get("Center") == "Cell": - cell_data_raw[name] = data - elif c.get("Center") == "Other": - other_data[name] = data - else: - raise ReadError() - - else: - # skip the xi:included mesh - continue - - if self.cells is None: - raise ReadError() - cell_data = cell_data_from_raw(self.cells, cell_data_raw) - if t is None: - raise ReadError() - - return t, point_data, cell_data, other_data - - def _read_data_item(self, data_item): - dims = [int(d) for d in data_item.get("Dimensions").split()] - - # Actually, `NumberType` is XDMF2 and `DataType` XDMF3, but many files out there - # use both keys interchangeably. - if data_item.get("DataType"): - if data_item.get("NumberType"): - raise ReadError() - data_type = data_item.get("DataType") - elif data_item.get("NumberType"): - if data_item.get("DataType"): - raise ReadError() - data_type = data_item.get("NumberType") - else: - # Default, see - # <https://xdmf.org/index.php/XDMF_Model_and_Format#XML_Element_.28Xdmf_ClassName.29_and_Default_XML_Attributes> - data_type = "Float" - - try: - precision = data_item.attrib["Precision"] - except KeyError: - precision = "4" - - data_format = data_item.attrib["Format"] - - if data_format == "XML": - return np.fromstring( - data_item.text, - dtype=xdmf_to_numpy_type[(data_type, precision)], - sep=" ", - ).reshape(dims) - if data_format == "Binary": - return np.fromfile( - data_item.text.strip(), - dtype=xdmf_to_numpy_type[(data_type, precision)], - ).reshape(dims) - - if data_format != "HDF": - msg = f"Unknown XDMF Format '{data_format}'." - raise ReadError(msg) - - file_info = data_item.text.strip() - file_h5path__selections = file_info.split("|") - file_h5path = file_h5path__selections[0] - selections = ( - file_h5path__selections[1] - if len(file_h5path__selections) > 1 - else None - ) - filename, h5path = file_h5path.split(":") - if selections: - # offsets, slices, current_data_extends, global_data_extends by dimension - m = [ - list(map(int, att.split(" "))) for att in selections.split(":") - ] - t = np.transpose(m) - selection = tuple( - slice(start, start + extend, step) - for start, step, extend, _ in t - ) - else: - selection = () - - # The HDF5 file path is given with respect to the XDMF (XML) file. - dirpath = self.filename.resolve().parent - full_hdf5_path = dirpath / filename - - if full_hdf5_path in self.hdf5_files: - f = self.hdf5_files[full_hdf5_path] - else: - import h5py - - f = h5py.File(full_hdf5_path, "r") - self.hdf5_files[full_hdf5_path] = f - - if h5path[0] != "/": - raise ReadError() - - for key in h5path[1:].split("/"): - f = f[key] - # `[()]` gives a np.ndarray - return f[selection].squeeze() +from .xdmf_reader import XDMFReader class MeshSeries: @@ -148,15 +20,27 @@ class MeshSeries: Will be replaced by own module in ogstools with similar interface. """ - def __init__(self, filepath: Union[str, Path]) -> None: + def __init__( + self, filepath: Union[str, Path], time_unit: Optional[str] = "s" + ) -> None: + """ + Initialize a MeshSeries object + + :param filepath: Path to the PVD or XDMF file. + :param time_unit: Data unit of the timevalues. + + :returns: A MeshSeries object + """ if isinstance(filepath, Path): filepath = str(filepath) + self.filepath = filepath + self.time_unit = time_unit self._data: dict[int, pv.UnstructuredGrid] = {} self._data_type = filepath.split(".")[-1] if self._data_type == "pvd": self._pvd_reader = pv.PVDReader(filepath) elif self._data_type == "xdmf": - self._xdmf_reader = TimeSeriesReader(filepath) + self._xdmf_reader = XDMFReader(filepath) self._read_xdmf(0) # necessary to initialize hdf5_files meshes = self.hdf5["meshes"] self.hdf5_bulk_name = list(meshes.keys())[ @@ -215,23 +99,23 @@ class MeshSeries: return range(len(self.timevalues)) @property - def timevalues(self) -> list[float]: + def timevalues(self) -> np.ndarray: """Return the timevalues of the timeseries data.""" if self._data_type == "vtu": - return [0] + return np.zeros(1) if self._data_type == "pvd": - return self._pvd_reader.time_values + return np.asarray(self._pvd_reader.time_values) # elif self._data_type == "xdmf": time_values = [] for collection_i in self._xdmf_reader.collection: for element in collection_i: if element.tag == "Time": time_values += [float(element.attrib["Value"])] - return time_values + return np.asarray(time_values) def closest_timestep(self, timevalue: float) -> int: """Return the corresponding timestep from a timevalue.""" - return int(np.argmin(np.abs(np.array(self.timevalues) - timevalue))) + return int(np.argmin(np.abs(self.timevalues - timevalue))) def closest_timevalue(self, timevalue: float) -> float: """Return the closest timevalue to a timevalue.""" @@ -245,7 +129,7 @@ class MeshSeries: self, timevalue: float, lazy_eval: bool = True ) -> pv.UnstructuredGrid: """Return the temporal interpolated mesh for a given timevalue.""" - t_vals = np.array(self.timevalues) + t_vals = self.timevalues ts1 = int(t_vals.searchsorted(timevalue, "right") - 1) ts2 = min(ts1 + 1, len(t_vals) - 1) if np.isclose(timevalue, t_vals[ts1]): @@ -275,5 +159,80 @@ class MeshSeries: if self._data_type == "xdmf": return self.hdf5["meshes"][self.hdf5_bulk_name][data_name] if self._data_type == "pvd": - return [self.read(t)[data_name] for t in tqdm(self.timesteps)] + return np.asarray( + [self.read(t)[data_name] for t in tqdm(self.timesteps)] + ) return mesh[data_name] + + def _probe_pvd( + self, + points: np.ndarray, + data_name: str, + interp_method: Optional[Literal["nearest", "probefilter"]] = None, + interp_backend: Optional[Literal["vtk", "scipy"]] = None, + ): + obs_pts_dict = {f"pt{j}": point for j, point in enumerate(points)} + dim = self.read(0).get_cell(0).dimension + pvd_path = self.filepath + pvdio = vtuIO.PVDIO( + pvd_path, dim=dim, interpolation_backend=interp_backend + ) + values_dict = pvdio.read_time_series( + data_name, obs_pts_dict, interpolation_method=interp_method + ) + return np.asarray(list(values_dict.values())) + + def _probe_xdmf( + self, + points: np.ndarray, + data_name: str, + interp_method: Optional[Literal["nearest", "linear"]] = None, + ): + values = self.hdf5["meshes"][self.hdf5_bulk_name][data_name][:] + geom = self.hdf5["meshes"][self.hdf5_bulk_name]["geometry"][0] + values = np.swapaxes(values, 0, 1) + + # remove flat dimensions for interpolation + for index, axis in enumerate(geom.T): + if np.all(np.isclose(axis, axis[0])): + geom = np.delete(geom, index, 1) + points = np.delete(points, index, 1) + + if interp_method is None: + interp_method = "linear" + interp = { + "nearest": NearestNDInterpolator(geom, values), + "linear": LinearNDInterpolator(geom, values, np.nan), + }[interp_method] + + return np.swapaxes(interp(points), 0, 1) + + def probe( + self, + points: np.ndarray, + data_name: str, + interp_method: Optional[ + Literal["nearest", "linear", "probefilter"] + ] = None, + interp_backend_pvd: Optional[Literal["vtk", "scipy"]] = None, + ) -> np.ndarray: + """ + Probe the MeshSeries at observation points. + + :param points: The points to sample at. + :param data_name: Name of the data to sample. + :param interp_method: Choose the interpolation method, defaults to + `linear` for xdmf MeshSeries and `probefilter` + for pvd MeshSeries. + :param interp_backend: Interpolation backend for PVD MeshSeries. + + :returns: `numpy` array of interpolated data at observation points. + """ + if self._data_type == "xdmf": + assert interp_method != "probefilter" + return self._probe_xdmf(points, data_name, interp_method) + assert self._data_type == "pvd" + assert interp_method != "linear" + return self._probe_pvd( + points, data_name, interp_method, interp_backend_pvd + ) diff --git a/ogstools/meshlib/xdmf_reader.py b/ogstools/meshlib/xdmf_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..6f1b49b5797a73bc50b140484ae44eac20611adb --- /dev/null +++ b/ogstools/meshlib/xdmf_reader.py @@ -0,0 +1,151 @@ +""" +This file provides an override to meshios XDMF Reader since it misses a feature +to handle hyperslabs (there are two ways to handle hyperslab: +the common documented `here <https://www.xdmf.org/index.php/XDMF_Model_and_Format#HyperSlab>`_ +and the way paraview supports it (documentation missing). + +Example:: + + 2D_single_fracture_HT.h5:/meshes/2D_single_fracture/temperature|0 0:1 1:1 190:97 190 + + +to be read like:: + + | start : stride : count : end + +""" + +import meshio +import numpy as np +from meshio.xdmf.time_series import ( + ReadError, + cell_data_from_raw, + xdmf_to_numpy_type, +) + + +class XDMFReader(meshio.xdmf.TimeSeriesReader): + def __init__(self, filename): + super().__init__(filename) + + def read_data(self, k: int): + point_data = {} + cell_data_raw = {} + other_data = {} + + t = None + + for c in list(self.collection[k]): + if c.tag == "Time": + t = float(c.attrib["Value"]) + elif c.tag == "Attribute": + name = c.get("Name") + + if len(list(c)) != 1: + raise ReadError() + data_item = list(c)[0] + data = self._read_data_item(data_item) + + if c.get("Center") == "Node": + point_data[name] = data + elif c.get("Center") == "Cell": + cell_data_raw[name] = data + elif c.get("Center") == "Other": + other_data[name] = data + else: + raise ReadError() + + else: + # skip the xi:included mesh + continue + + if self.cells is None: + raise ReadError() + cell_data = cell_data_from_raw(self.cells, cell_data_raw) + if t is None: + raise ReadError() + + return t, point_data, cell_data, other_data + + def _read_data_item(self, data_item): + dims = [int(d) for d in data_item.get("Dimensions").split()] + + # Actually, `NumberType` is XDMF2 and `DataType` XDMF3, but many files out there + # use both keys interchangeably. + if data_item.get("DataType"): + if data_item.get("NumberType"): + raise ReadError() + data_type = data_item.get("DataType") + elif data_item.get("NumberType"): + if data_item.get("DataType"): + raise ReadError() + data_type = data_item.get("NumberType") + else: + # Default, see + # <https://xdmf.org/index.php/XDMF_Model_and_Format#XML_Element_.28Xdmf_ClassName.29_and_Default_XML_Attributes> + data_type = "Float" + + try: + precision = data_item.attrib["Precision"] + except KeyError: + precision = "4" + + data_format = data_item.attrib["Format"] + + if data_format == "XML": + return np.fromstring( + data_item.text, + dtype=xdmf_to_numpy_type[(data_type, precision)], + sep=" ", + ).reshape(dims) + if data_format == "Binary": + return np.fromfile( + data_item.text.strip(), + dtype=xdmf_to_numpy_type[(data_type, precision)], + ).reshape(dims) + + if data_format != "HDF": + msg = f"Unknown XDMF Format '{data_format}'." + raise ReadError(msg) + + file_info = data_item.text.strip() + file_h5path__selections = file_info.split("|") + file_h5path = file_h5path__selections[0] + selections = ( + file_h5path__selections[1] + if len(file_h5path__selections) > 1 + else None + ) + filename, h5path = file_h5path.split(":") + if selections: + # offsets, slices, current_data_extends, global_data_extends by dimension + m = [ + list(map(int, att.split(" "))) for att in selections.split(":") + ] + t = np.transpose(m) + selection = tuple( + slice(start, start + extend, step) + for start, step, extend, _ in t + ) + else: + selection = () + + # The HDF5 file path is given with respect to the XDMF (XML) file. + dirpath = self.filename.resolve().parent + full_hdf5_path = dirpath / filename + + if full_hdf5_path in self.hdf5_files: + f = self.hdf5_files[full_hdf5_path] + else: + import h5py + + f = h5py.File(full_hdf5_path, "r") + self.hdf5_files[full_hdf5_path] = f + + if h5path[0] != "/": + raise ReadError() + + for key in h5path[1:].split("/"): + f = f[key] + # `[()]` gives a np.ndarray + return f[selection].squeeze() diff --git a/ogstools/meshplotlib/__init__.py b/ogstools/meshplotlib/__init__.py index 3702477db2f628fab93fe6f853e1748e22bb6355..ef020e43b6b86c9560df44526d94a40a5d6c4ac2 100644 --- a/ogstools/meshplotlib/__init__.py +++ b/ogstools/meshplotlib/__init__.py @@ -3,7 +3,7 @@ from .plot_setup import _setup as setup # noqa: I001: noqa -from .core import plot_diff, plot_limit, plot, subplot +from .core import plot_diff, plot_limit, plot_probe, plot, subplot from .plot_features import plot_contour, plot_on_top __all__ = [ @@ -13,5 +13,6 @@ __all__ = [ "plot_diff", "plot_limit", "plot_on_top", + "plot_probe", "subplot", ] diff --git a/ogstools/meshplotlib/core.py b/ogstools/meshplotlib/core.py index 54f7755a1c9567b70dba0f98ca3c9382111db765..134d156b34e7ffc6a4f5b34a3b1111ffd0f267f2 100644 --- a/ogstools/meshplotlib/core.py +++ b/ogstools/meshplotlib/core.py @@ -18,10 +18,12 @@ from matplotlib.patches import Rectangle as Rect from ogstools.meshlib import MeshSeries from ogstools.propertylib import Property, Vector from ogstools.propertylib.presets import _resolve_property +from ogstools.propertylib.unit_registry import u_reg from . import plot_features as pf from . import setup from .levels import get_levels +from .utils import get_style_cycler # TODO: define default data_name for regions in setup @@ -517,3 +519,74 @@ def plot_limit( output_name=limit + " " + property.output_name ) return plot(mesh, limit_property) + + +def plot_probe( + mesh_series: MeshSeries, + points: np.ndarray, + mesh_property: Union[Property, str], + labels: Optional[list[str]] = None, + time_unit: Optional[str] = "s", + interp_method: Optional[Literal["nearest", "linear", "probefilter"]] = None, + interp_backend_pvd: Optional[Literal["vtk", "scipy"]] = None, + colors: Optional[list] = None, + linestyles: Optional[list] = None, + ax: Optional[plt.Axes] = None, + **kwargs, +) -> mfigure.Figure: + """ + Plot the transient property on the observation points in the MeshSeries. + + :param mesh_series: MeshSeries object containing the data to be plotted. + :param points: The points to sample at. + :param mesh_property: The property to be sampled. + :param labels: The labels for each observation point. + :param time_unit: Output unit of the timevalues. + :param interp_method: Choose the interpolation method, defaults to + `linear` for xdmf MeshSeries and `probefilter` + for pvd MeshSeries. + :param interp_backend: Interpolation backend for PVD MeshSeries. + :param kwargs: Keyword arguments passed to matplotlib's plot + function. + + :returns: A matplotlib Figure + """ + points = np.asarray(points) + if len(points.shape) == 1: + points = points[np.newaxis] + if isinstance(mesh_property, str): + data_shape = mesh_series.read(0)[mesh_property].shape + mesh_property = _resolve_property(mesh_property, data_shape) + values = mesh_property.magnitude.strip_units( + mesh_series.probe( + points, mesh_property.data_name, interp_method, interp_backend_pvd + ) + ) + if values.shape[0] == 1: + values = values.flatten() + Q_ = u_reg.Quantity + time_unit_conversion = Q_(Q_(mesh_series.time_unit), time_unit).magnitude + times = time_unit_conversion * mesh_series.timevalues + if ax is None: + fig, ax = plt.subplots() + else: + fig = None + ax.set_prop_cycle(get_style_cycler(len(points), colors, linestyles)) + ax.plot(times, values, label=labels, **kwargs) + if labels is not None: + ax.legend(facecolor="white", framealpha=1, prop={"family": "monospace"}) + time_label = f"time / {time_unit}" if time_unit else "time" + ax.set_axisbelow(True) + ax.grid(which="major", color="lightgrey", linestyle="-") + ax.grid(which="minor", color="0.95", linestyle="--") + unit_str = ( + f" / {mesh_property.get_output_unit()}" + if mesh_property.get_output_unit() + else "" + ) + y_label = mesh_property.output_name.replace("_", " ") + unit_str + ax.set_xlabel(time_label) + ax.set_ylabel(y_label) + ax.label_outer() + ax.minorticks_on() + return fig diff --git a/ogstools/meshplotlib/examples/__init__.py b/ogstools/meshplotlib/examples/__init__.py index 48de8b4874bc71aabaa5d5d7f8bb42d74ca8aee2..bf876b023308f7b8fe4cec04586017b58eb09c99 100644 --- a/ogstools/meshplotlib/examples/__init__.py +++ b/ogstools/meshplotlib/examples/__init__.py @@ -3,8 +3,9 @@ from ogstools.meshlib import MeshSeries # Will probably be replaced with some dynamically generated example examples = ROOT_DIR / "_examples" -meshseries_THM_2D = MeshSeries(str(examples / "2D.pvd")) -meshseries_CT_2D = MeshSeries(str(examples / "elder.xdmf")) +meshseries_THM_2D = MeshSeries(str(examples / "2D.pvd"), time_unit="s") +meshseries_CT_2D = MeshSeries(str(examples / "elder.xdmf"), time_unit="s") meshseries_XDMF = MeshSeries( - str(examples / "2D_single_fracture_HT_2D_single_fracture.xdmf") + str(examples / "2D_single_fracture_HT_2D_single_fracture.xdmf"), + time_unit="s", ) diff --git a/ogstools/meshplotlib/utils.py b/ogstools/meshplotlib/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e31ce760c94b50f10df3f2ac01913d0ebc62718d --- /dev/null +++ b/ogstools/meshplotlib/utils.py @@ -0,0 +1,39 @@ +from typing import Optional + +import matplotlib.pyplot as plt +import numpy as np + + +def justified_labels(points: np.ndarray) -> list[str]: + "Formats an array of points to a list of aligned str." + + def fmt(val: float): + return f"{val:.2f}".rstrip("0").rstrip(".") + + col_lens = np.max( + [[len(fmt(coord)) for coord in point] for point in points], axis=0 + ) + dim = points.shape[1] + return [ + ",".join(fmt(point[i]).rjust(col_lens[i]) for i in range(dim)) + for point in points + ] + + +def get_style_cycler( + min_number_of_styles: int, + colors: Optional[Optional[list]] = None, + linestyles: Optional[list] = None, +) -> plt.cycler: + if colors is None: + colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] + if linestyles is None: + linestyles = ["-", "--", ":", "-."] + styles_len = min(len(colors), len(linestyles)) + c_cycler = plt.cycler(color=colors) + ls_cycler = plt.cycler(linestyle=linestyles) + if min_number_of_styles <= styles_len: + style_cycler = c_cycler[:styles_len] + ls_cycler[:styles_len] + else: + style_cycler = ls_cycler * c_cycler + return style_cycler diff --git a/tests/test_meshlib.py b/tests/test_meshlib.py new file mode 100644 index 0000000000000000000000000000000000000000..ffc9e7f2944ac97ecfadb12623b09e2c622123ee --- /dev/null +++ b/tests/test_meshlib.py @@ -0,0 +1,47 @@ +"""Unit tests for meshlib.""" + +import unittest + +import numpy as np + +from ogstools.meshlib import MeshSeries, examples + + +class UtilsTest(unittest.TestCase): + """Test case for ogstools utilities.""" + + def test_all_types(self): + pvd = MeshSeries(examples.pvd_file) + vtu = MeshSeries(examples.vtu_file) + xdmf = MeshSeries(examples.xdmf_file) + self.assertRaises(TypeError, MeshSeries, __file__) + + for mesh_series in [pvd, xdmf, vtu]: + self.assertTrue( + mesh_series.read(0) == mesh_series.read_closest(1e-6) + ) + self.assertTrue(not np.any(np.isnan(mesh_series.timesteps))) + self.assertTrue( + not np.any(np.isnan(mesh_series.values("temperature"))) + ) + self.assertTrue( + mesh_series.timevalues[mesh_series.closest_timestep(1.0)] + == mesh_series.closest_timevalue(1.0) + ) + mesh_series.clear() + + def test_probe_pvd(self): + "Test point probing on pvd." + mesh_series = MeshSeries(examples.pvd_file) + points = mesh_series.read(0).cell_centers().points + for method in ["nearest", "probefilter"]: + values = mesh_series.probe(points, "temperature", method) + self.assertTrue(not np.any(np.isnan(values))) + + def test_probe_xdmf(self): + "Test point probing on xdmf." + mesh_series = MeshSeries(examples.xdmf_file) + points = mesh_series.read(0).cell_centers().points + for method in ["nearest", "linear", None]: + values = mesh_series.probe(points, "temperature", method) + self.assertTrue(not np.any(np.isnan(values))) diff --git a/tests/test_meshplotlib.py b/tests/test_meshplotlib.py index 6113d32282e643101961abfcb29c03195e25fdb2..e74feedaaf4cbbadc882df0a933d64e335385808 100644 --- a/tests/test_meshplotlib.py +++ b/tests/test_meshplotlib.py @@ -7,10 +7,18 @@ from tempfile import mkstemp import numpy as np from pyvista import examples as pv_examples -from ogstools.meshplotlib import examples, plot, plot_diff, plot_limit, setup +from ogstools.meshplotlib import ( + examples, + plot, + plot_diff, + plot_limit, + plot_probe, + setup, +) from ogstools.meshplotlib.animation import animate, save_animation from ogstools.meshplotlib.levels import get_levels from ogstools.meshplotlib.plot_features import plot_on_top +from ogstools.meshplotlib.utils import justified_labels from ogstools.propertylib import Scalar, presets equality = partial(np.testing.assert_allclose, rtol=1e-7, verbose=True) @@ -37,6 +45,19 @@ class MeshplotlibTest(unittest.TestCase): equality(get_levels(0.0, 0.0, 10), [0.0, 1e-6]) equality(get_levels(1e9, 1e9, 10), [1e9, 1e9 + 1e-6]) + def test_justified_labels(self): + points = np.asarray( + [ + [x, y, z] + for x in np.linspace(-1, 0, 3) + for y in np.linspace(-10, 10, 5) + for z in np.linspace(1e-6, 1e6, 7) + ] + ) + labels = justified_labels(points) + str_lens = np.asarray([len(label) for label in labels]) + self.assertTrue(np.all(str_lens == str_lens[0])) + def test_missing_data(self): """Test missing data in mesh.""" mesh = pv_examples.load_uniform() @@ -69,6 +90,19 @@ class MeshplotlibTest(unittest.TestCase): plot_limit(meshseries, "Si", "min") plot_limit(meshseries, "Si", "max") + def test_plot_probe(self): + """Test creation of probe plots.""" + mesh_series = examples.meshseries_THM_2D + points = mesh_series.read(0).center + plot_probe(mesh_series, points, presets.temperature) + points = mesh_series.read(0).points[[0, -1]] + plot_probe(mesh_series, points, presets.temperature) + mesh_series = examples.meshseries_XDMF + points = mesh_series.read(0).center + plot_probe(mesh_series, points, presets.temperature) + mesh_property = presets.velocity.replace(data_name="darcy_velocity") + plot_probe(mesh_series, points, mesh_property) + def test_animation(self): """Test creation of animation.""" meshseries = examples.meshseries_THM_2D