diff --git a/ogstools/meshlib/data_processing.py b/ogstools/meshlib/data_processing.py index 20771c2b8b1597022dcbc8bdc98f183eba87dfb7..b547566f2bce87aabebdd604caa3daa20f888a06 100644 --- a/ogstools/meshlib/data_processing.py +++ b/ogstools/meshlib/data_processing.py @@ -18,9 +18,7 @@ def difference( :param mesh_property: The property to of interest. :returns: A new mesh representing the difference between mesh1 and mesh2. """ - if isinstance(mesh_property, str): - data_shape = mesh1[mesh_property].shape - mesh_property = presets.get_preset(mesh_property, data_shape) + mesh_property = presets.get_preset(mesh_property, mesh1) diff_mesh = mesh1.copy(deep=True) diff_mesh[mesh_property.data_name] -= mesh2[mesh_property.data_name] return diff_mesh diff --git a/ogstools/meshplotlib/core.py b/ogstools/meshplotlib/core.py index 2d5af77aef8bef9336ec295b9e4a35c0dda8c1b9..b9f659847860756e02208ed8dcb6c7dd501ef01a 100644 --- a/ogstools/meshplotlib/core.py +++ b/ogstools/meshplotlib/core.py @@ -224,9 +224,7 @@ def subplot( Custom levels and a colormap string can be provided. """ - if isinstance(mesh_property, str): - data_shape = mesh[mesh_property].shape - mesh_property = get_preset(mesh_property, data_shape) + mesh_property = get_preset(mesh_property, mesh) if mesh.get_cell(0).dimension == 3: msg = "meshplotlib is for 2D meshes only, but found 3D elements." raise ValueError(msg) @@ -261,18 +259,15 @@ def subplot( levels = compute_levels(p_min, p_max, num_levels) cmap, norm = get_cmap_norm(levels, mesh_property) - if ( - mesh_property.data_name in mesh.cell_data - and mesh_property.data_name not in mesh.point_data - ): + if mesh_property.data_name in mesh.point_data: + ax.tricontourf(x, y, tri, values, levels=levels, cmap=cmap, norm=norm) + if _q_zero_line(mesh_property, levels): + ax.tricontour(x, y, tri, values, levels=[0], colors="w") + else: ax.tripcolor(x, y, tri, facecolors=values, cmap=cmap, norm=norm) if mesh_property.is_mask(): ax.tripcolor(x, y, tri, facecolors=values, mask=(values == 1), cmap=cmap, norm=norm, hatch="/") # fmt: skip - else: - ax.tricontourf(x, y, tri, values, levels=levels, cmap=cmap, norm=norm) - if _q_zero_line(mesh_property, levels): - ax.tricontour(x, y, tri, values, levels=[0], colors="w") surf = mesh.extract_surface() @@ -435,9 +430,7 @@ def get_combined_levels( """ Calculate well spaced levels for the encompassing property range in meshes. """ - if isinstance(mesh_property, str): - data_shape = meshes[0][mesh_property].shape - mesh_property = get_preset(mesh_property, data_shape) + mesh_property = get_preset(mesh_property, meshes[0]) p_min, p_max = np.inf, -np.inf unique_vals = np.array([]) for mesh in np.ravel(meshes): @@ -509,7 +502,7 @@ def _draw_plot( _levels = ( combined_levels if setup.combined_colorbar - else get_combined_levels(np_meshes[i, j], mesh_property) + else get_combined_levels([np_meshes[i, j]], mesh_property) ) subplot(np_meshes[i, j], mesh_property, np_axs[i, j], _levels) @@ -547,7 +540,7 @@ def _draw_plot( for i in range(shape[0]): for j in range(shape[1]): _levels = get_combined_levels( - np_meshes[i, j], mesh_property + [np_meshes[i, j]], mesh_property ) add_colorbars(fig, np_axs[i, j], mesh_property, _levels) return fig @@ -614,9 +607,7 @@ def plot( rcParams.update(setup.rcParams_scaled) shape = _get_rows_cols(meshes) _meshes = np.reshape(meshes, shape).flatten() - if isinstance(mesh_property, str): - data_shape = _meshes[0][mesh_property].shape - mesh_property = get_preset(mesh_property, data_shape) + mesh_property = get_preset(mesh_property, _meshes[0]) data_aspects = np.asarray([get_data_aspect(mesh) for mesh in _meshes]) if setup.min_ax_aspect is None and setup.max_ax_aspect is None: fig_aspect = np.mean(data_aspects) @@ -663,9 +654,7 @@ def plot_limit( :returns: A matplotlib Figure """ mesh = mesh_series.read(0) - if isinstance(mesh_property, str): - data_shape = mesh[mesh_property].shape - mesh_property = get_preset(mesh_property, data_shape) + mesh_property = get_preset(mesh_property, mesh) func = {"min": np.min, "max": np.max}[limit] vals = mesh_series.values(mesh_property.data_name) func(vals, out=mesh[mesh_property.data_name], axis=0) @@ -710,9 +699,7 @@ def plot_probe( points = np.asarray(points) if len(points.shape) == 1: points = points[np.newaxis] - if isinstance(mesh_property, str): - data_shape = mesh_series.read(0)[mesh_property].shape - mesh_property = get_preset(mesh_property, data_shape) + mesh_property = get_preset(mesh_property, mesh_series.read(0)) values = mesh_property.magnitude.transform( mesh_series.probe( points, mesh_property.data_name, interp_method, interp_backend_pvd @@ -726,11 +713,9 @@ def plot_probe( x_values = time_unit_conversion * mesh_series.timevalues x_label = f"time / {time_unit}" if time_unit else "time" else: - if isinstance(mesh_property_abscissa, str): - data_shape = mesh_series.read(0)[mesh_property_abscissa].shape - mesh_property_abscissa = get_preset( - mesh_property_abscissa, data_shape - ) + mesh_property_abscissa = get_preset( + mesh_property_abscissa, mesh_series.read(0) + ) x_values = mesh_property_abscissa.magnitude.transform( mesh_series.probe( points, diff --git a/ogstools/propertylib/presets.py b/ogstools/propertylib/presets.py index d91b4c1ac7f44320743af0a6127b7b954c42b1b6..7d07843b4d497e4be4fc746f0dce93c498354398 100644 --- a/ogstools/propertylib/presets.py +++ b/ogstools/propertylib/presets.py @@ -5,9 +5,10 @@ """ from functools import partial -from typing import Optional +from typing import Union import pandas as pd +import pyvista as pv from . import mesh_dependent, tensor_math from .custom_colormaps import integrity_cmap, temperature_cmap @@ -140,24 +141,38 @@ nodal_forces = Vector(data_name="NodalForces", mask=M_MASK) all_properties = [v for v in locals().values() if isinstance(v, Property)] -def get_preset(property_name: str, shape: Optional[tuple] = None) -> Property: +def get_preset( + mesh_property: Union[Property, str], mesh: pv.DataSet +) -> Property: """ - Returns a Property preset or create one with correct type. + Returns a Property preset or creates one with correct type. Searches for presets by data_name and output_name and returns if found. - Otherwise create Scalar, Vector or Matrix Property depending on shape. + Otherwise create Scalar, Vector, or Matrix Property depending on the shape + of data in mesh. + + :param mesh_property: The property to retrieve or its name if a string. + :param mesh: The mesh dataset containing the property data. + :returns: A corresponding Property preset or a new Property of correct type. """ + if isinstance(mesh_property, Property): + return mesh_property + for prop in all_properties: - if prop.output_name == property_name: + if prop.output_name == mesh_property: return prop for prop in all_properties: - if prop.data_name == property_name: + if prop.data_name == mesh_property: return prop - if shape is None or len(shape) == 1: - return Scalar(property_name) - if shape[1] in [2, 3]: - return Vector(property_name) - return Matrix(property_name) + if mesh_property not in set().union(mesh.point_data, mesh.cell_data): + msg = f"Property {mesh_property} not found in mesh." + raise KeyError(msg) + data_shape = mesh[mesh_property].shape + if len(data_shape) == 1: + return Scalar(mesh_property) + if data_shape[1] in [2, 3]: + return Vector(mesh_property) + return Matrix(mesh_property) def get_dataframe() -> pd.DataFrame: diff --git a/ogstools/propertylib/property.py b/ogstools/propertylib/property.py index ff983d56396334c71dc13ae8dbf8e249919294b1..2cc89232f066cc47c7906aaf1b1c3e58e6d7a8f0 100644 --- a/ogstools/propertylib/property.py +++ b/ogstools/propertylib/property.py @@ -173,14 +173,11 @@ class Property: def _get_data( self, mesh: pv.UnstructuredGrid, masked: bool = True - ) -> pv.UnstructuredGrid: + ) -> np.ndarray: """Get the data associated with a scalar or vector property from a mesh.""" - if ( - self.data_name not in mesh.point_data - and self.data_name not in mesh.cell_data - ): - msg = f"Property {self.data_name} not found in mesh." - raise IndexError(msg) + if self.data_name not in set().union(mesh.point_data, mesh.cell_data): + msg = f"Data name {self.data_name} not found in mesh." + raise KeyError(msg) if masked and self.mask_used(mesh): return mesh.ctp(True).threshold(value=[1, 1], scalars=self.mask)[ self.data_name diff --git a/ogstools/studies/templates/convergence_study.py b/ogstools/studies/templates/convergence_study.py index 8bfef1c5d6b863cc1734ef0b265eccbd1972317d..30ffe046c2d27ed88c20a8d0f6a706c82043a84b 100644 --- a/ogstools/studies/templates/convergence_study.py +++ b/ogstools/studies/templates/convergence_study.py @@ -41,12 +41,7 @@ mesh_series = [meshlib.MeshSeries(mesh_path) for mesh_path in mesh_paths] timestep_sizes = [np.mean(np.diff(ms.timevalues)) for ms in mesh_series] meshes = [ms.read_closest(timevalue) for ms in mesh_series] topology: pv.DataSet = meshes[-3] -data_shape = ( - meshes[0][property_name].shape - if property_name in meshes[0].point_data - else None -) -mesh_property = propertylib.presets.get_preset(property_name, data_shape) +mesh_property = propertylib.presets.get_preset(property_name, meshes[0]) richardson = studies.convergence.richardson_extrapolation( meshes, mesh_property, topology, refinement_ratio ) diff --git a/tests/test_meshplotlib.py b/tests/test_meshplotlib.py index 345557cb7f043b5434fe84d698bfca2ed70926bc..bde3133f30d73190058d88313ce7907bcd5a996c 100644 --- a/tests/test_meshplotlib.py +++ b/tests/test_meshplotlib.py @@ -114,7 +114,7 @@ class MeshplotlibTest(unittest.TestCase): def test_missing_data(self): """Test missing data in mesh.""" mesh = pv_examples.load_uniform() - self.assertRaises(IndexError, plot, mesh, Scalar("missing_data")) + self.assertRaises(KeyError, plot, mesh, Scalar("missing_data")) def test_plot_2D(self): """Test creation of 2D plots.""" diff --git a/tests/test_propertylib.py b/tests/test_propertylib.py index c7a5b41c49ae30d74f96576dbd3442157269cb64..9d3bdc72c64fabc2b828b3cbcafad79f4d973f45 100644 --- a/tests/test_propertylib.py +++ b/tests/test_propertylib.py @@ -8,10 +8,8 @@ from pint.facets.plain import PlainQuantity from ogstools.meshplotlib.examples import mesh_mechanics from ogstools.propertylib import presets as pp -from ogstools.propertylib.matrix import Matrix from ogstools.propertylib.mesh_dependent import depth -from ogstools.propertylib.property import Scalar, u_reg -from ogstools.propertylib.vector import Vector +from ogstools.propertylib.property import u_reg Qty = u_reg.Quantity @@ -21,7 +19,7 @@ class PhysicalPropertyTest(unittest.TestCase): EPS = 1e-7 - def equality(self, p: Scalar, vals: np.ndarray, res: PlainQuantity): + def equality(self, p: pp.Scalar, vals: np.ndarray, res: PlainQuantity): """ Assert the equality of property calculations. @@ -185,11 +183,17 @@ class PhysicalPropertyTest(unittest.TestCase): def test_get_preset(self): """Test find property function.""" - self.assertEqual(pp.get_preset("pressure"), pp.pressure) - self.assertEqual(pp.get_preset("pore_pressure"), pp.pressure) - self.assertEqual(pp.get_preset("test"), Scalar("test")) - self.assertEqual(pp.get_preset("test", shape=(100, 3)), Vector("test")) - self.assertEqual(pp.get_preset("test", shape=(100, 6)), Matrix("test")) + mesh = mesh_mechanics + mesh.point_data["scalar"] = mesh["temperature"] + mesh.point_data["vector"] = mesh["displacement"] + mesh.point_data["matrix"] = mesh["sigma"] + self.assertEqual(pp.get_preset("temperature", mesh), pp.temperature) + self.assertEqual(pp.get_preset("displacement", mesh), pp.displacement) + self.assertEqual(pp.get_preset("sigma", mesh), pp.stress) + self.assertEqual(pp.get_preset("scalar", mesh), pp.Scalar("scalar")) + self.assertEqual(pp.get_preset("vector", mesh), pp.Vector("vector")) + self.assertEqual(pp.get_preset("matrix", mesh), pp.Matrix("matrix")) + self.assertRaises(KeyError, pp.get_preset, "test", mesh) def test_copy_ctor(self): """Test replace constructor."""