diff --git a/ogstools/meshlib/mesh_series.py b/ogstools/meshlib/mesh_series.py index c72b80b026f05a3b6047ef99348f368c33eb2eaf..0022b345f8fae8c8b6ecae9ca2a5ce0705136c45 100644 --- a/ogstools/meshlib/mesh_series.py +++ b/ogstools/meshlib/mesh_series.py @@ -1,16 +1,18 @@ """A class to handle Meshseries data.""" from pathlib import Path -from typing import Optional, Union +from typing import Literal, Optional, Union import meshio import numpy as np import pyvista as pv +import vtuIO from meshio.xdmf.time_series import ( ReadError, cell_data_from_raw, xdmf_to_numpy_type, ) +from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator from tqdm.auto import tqdm @@ -163,6 +165,7 @@ class MeshSeries: """ if isinstance(filepath, Path): filepath = str(filepath) + self.filepath = filepath self.time_unit = time_unit self._data: dict[int, pv.UnstructuredGrid] = {} self._data_type = filepath.split(".")[-1] @@ -290,3 +293,75 @@ class MeshSeries: if self._data_type == "pvd": return [self.read(t)[data_name] for t in tqdm(self.timesteps)] return mesh[data_name] + + def _probe_pvd( + self, + points: np.ndarray, + data_name: str, + interp_method: Optional[Literal["nearest", "probefilter"]] = None, + interp_backend: Optional[Literal["vtk", "scipy"]] = None, + ): + obs_pts_dict = {f"pt{j}": point for j, point in enumerate(points)} + dim = self.read(0).get_cell(0).dimension + pvd_path = self.filepath + pvdio = vtuIO.PVDIO( + pvd_path, dim=dim, interpolation_backend=interp_backend + ) + values_dict = pvdio.read_time_series( + data_name, obs_pts_dict, interpolation_method=interp_method + ) + return np.asarray(list(values_dict.values())) + + def _probe_xdmf( + self, + points: np.ndarray, + data_name: str, + interp_method: Optional[Literal["nearest", "linear"]] = None, + ): + values = self.hdf5["meshes"][self.hdf5_bulk_name][data_name][:] + geom = self.hdf5["meshes"][self.hdf5_bulk_name]["geometry"][0] + + # remove flat dimensions for interpolation + for index, axis in enumerate(geom.T): + if np.all(np.isclose(axis, axis[0])): + geom = np.delete(geom, index, 1) + points = np.delete(points, index, 1) + + if interp_method is None: + interp_method = "linear" + interp = { + "nearest": NearestNDInterpolator(geom, values.T), + "linear": LinearNDInterpolator(geom, values.T, np.nan), + }[interp_method] + + return interp(points).T + + def probe( + self, + points: np.ndarray, + data_name: str, + interp_method: Optional[ + Literal["nearest", "linear", "probefilter"] + ] = None, + interp_backend_pvd: Optional[Literal["vtk", "scipy"]] = None, + ) -> np.ndarray: + """ + Probe the MeshSeries at observation points. + + :param points: The points to sample at. + :param data_name: Name of the data to sample. + :param interp_method: Choose the interpolation method, defaults to + `linear` for xdmf MeshSeries and `probefilter` + for pvd MeshSeries. + :param interp_backend: Interpolation backend for PVD MeshSeries. + + :returns: `numpy` array of interpolated data at observation points. + """ + if self._data_type == "xdmf": + assert interp_method != "probefilter" + return self._probe_xdmf(points, data_name, interp_method) + assert self._data_type == "pvd" + assert interp_method != "linear" + return self._probe_pvd( + points, data_name, interp_method, interp_backend_pvd + )