diff --git a/ogstools/meshplotlib/core.py b/ogstools/meshplotlib/core.py index 682f24fdcafb867a6db86cb6472fcd7d54ee09ee..54f7755a1c9567b70dba0f98ca3c9382111db765 100644 --- a/ogstools/meshplotlib/core.py +++ b/ogstools/meshplotlib/core.py @@ -32,14 +32,29 @@ def _q_zero_line(property: Property, levels: np.ndarray): ) -def get_data(mesh: pv.UnstructuredGrid, property: Property) -> np.ndarray: +def has_masked_values(mesh: pv.UnstructuredGrid, property: Property) -> bool: + return ( + not property.is_mask() + and property.mask in mesh.cell_data + and (len(mesh.cell_data[property.mask]) != 0) + ) + + +def get_data( + mesh: pv.UnstructuredGrid, property: Property, masked: bool = False +) -> np.ndarray: """Get the data associated with a scalar or vector property from a mesh.""" - if property.data_name in mesh.point_data: - return mesh.point_data[property.data_name] - if property.data_name in mesh.cell_data: - return mesh.cell_data[property.data_name] - msg = f"Property not found in mesh {mesh}." - raise IndexError(msg) + if ( + property.data_name not in mesh.point_data + and property.data_name not in mesh.cell_data + ): + msg = f"Property not found in mesh {mesh}." + raise IndexError(msg) + if masked: + return mesh.ctp(True).threshold(value=[1, 1], scalars=property.mask)[ + property.data_name + ] + return mesh[property.data_name] def get_level_boundaries(levels: np.ndarray): @@ -196,11 +211,7 @@ def subplot( ax.axis("auto") - if ( - not property.is_mask() - and property.mask in mesh.cell_data - and len(mesh.cell_data[property.mask]) - ): + if has_masked_values(mesh, property): subplot(mesh, property.get_mask(), ax) mesh = mesh.ctp(True).threshold(value=[1, 1], scalars=property.mask) @@ -341,17 +352,9 @@ def get_combined_levels( p_min, p_max = np.inf, -np.inf unique_vals = np.array([]) for mesh in np.ravel(meshes): - if ( - not property.is_mask() - and property.mask in mesh.cell_data - and len(mesh.cell_data[property.mask]) - ): - _mesh = mesh.ctp(True).threshold( - value=[1, 1], scalars=property.mask - ) - values = property.magnitude.strip_units(get_data(_mesh, property)) - else: - values = property.magnitude.strip_units(get_data(mesh, property)) + values = property.magnitude.strip_units( + get_data(mesh, property, masked=has_masked_values(mesh, property)) + ) if setup.log_scaled: # TODO: can be improved values = np.log10(np.where(values > 1e-14, values, 1e-14)) p_min = min(p_min, np.nanmin(values)) if setup.p_min is None else p_min @@ -508,7 +511,8 @@ def plot_limit( data_shape = mesh[property].shape property = _resolve_property(property, data_shape) func = {"min": np.min, "max": np.max}[limit] - mesh = mesh_series.reduce_with(func, property.data_name) + vals = mesh_series.values(property.data_name) + func(vals, out=mesh[property.data_name], axis=0) limit_property = property.replace( output_name=limit + " " + property.output_name )