diff --git a/docs/examples/howto_meshplotlib/plot_xy_labels_with_shared_axes.py b/docs/examples/howto_meshplotlib/plot_xy_labels_with_shared_axes.py index 2e74d7338a551a1ad701b8505fc35491a47d25f4..ded8afdfb74b569656f56e6a4f11efd81858c61c 100644 --- a/docs/examples/howto_meshplotlib/plot_xy_labels_with_shared_axes.py +++ b/docs/examples/howto_meshplotlib/plot_xy_labels_with_shared_axes.py @@ -14,7 +14,6 @@ under various conditions (shared and nor shared X and Y axes). # %% # Import Python packages, change some settings and load example data set import matplotlib.pyplot as plt -import numpy as np from ogstools.meshlib import difference from ogstools.meshplotlib import ( diff --git a/ogstools/meshplotlib/__init__.py b/ogstools/meshplotlib/__init__.py index e50d776f1f964a022e25d9ff3e8b97f8859d6431..74b7c05e1fe9335136b564962ef44591fbab2e0a 100644 --- a/ogstools/meshplotlib/__init__.py +++ b/ogstools/meshplotlib/__init__.py @@ -10,7 +10,6 @@ from .core import ( update_font_sizes, label_spatial_axes, clear_labels, - clear_labels_ax, ) from .plot_features import plot_contour, plot_on_top @@ -24,5 +23,4 @@ __all__ = [ "update_font_sizes", "label_spatial_axes", "clear_labels", - "clear_labels_ax", ] diff --git a/ogstools/meshplotlib/core.py b/ogstools/meshplotlib/core.py index 4962ef0bf5b68380e267c39f34fd1e8e5fa43724..ca139bbfab2d452876c5fcd5f6ad6ebf03b21575 100644 --- a/ogstools/meshplotlib/core.py +++ b/ogstools/meshplotlib/core.py @@ -6,7 +6,6 @@ from typing import Literal, Optional, Union import numpy as np import pyvista as pv -from matplotlib import axes as matplax from matplotlib import cm as mcm from matplotlib import colormaps, rcParams from matplotlib import colors as mcolors @@ -14,6 +13,7 @@ from matplotlib import figure as mfigure from matplotlib import pyplot as plt from matplotlib import ticker as mticker from matplotlib.patches import Rectangle as Rect +from typeguard import typechecked from ogstools.meshlib import MeshSeries from ogstools.propertylib import Property, Vector @@ -304,72 +304,31 @@ def subplot( secax.set_xlabel(f'{"xyz"[projection]} / {setup.length.output_unit}') -def clear_labels_ax(ax: plt.axes) -> plt.axes: - ax.set_xlabel(None) - ax.set_ylabel(None) - return ax - - -def clear_labels(ax: Union[plt.axes, np.array]) -> Union[plt.axes, np.array]: - if isinstance(ax, np.ndarray): - for id_n in range(ax.shape[0]): - for id_m in range(ax.shape[1]): - ax[id_n, id_m] = clear_labels_ax(ax[id_n, id_m]) - if isinstance(ax, matplax.Axes): - # Wrap single axis in np.array - ax = clear_labels_ax(ax) - return ax +def clear_labels(axes: Union[plt.Axes, np.ndarray]) -> None: + ax: plt.Axes + for ax in np.ravel(np.array(axes)): + ax.set_xlabel("") + ax.set_ylabel("") +@typechecked def label_spatial_axes( - ax: Union[plt.axes, np.array], - ax_ids: np.array, -) -> plt.axes: + axes: Union[plt.Axes, np.ndarray], x_label: str = "x", y_label: str = "y" +) -> None: """ - Add labels to X and Y axis + Add labels to x and y axis. - Automatically selects correct pair of directions and unit. - Respects sharex / sharey settings. - - :param ax: Matplotlib Axis object - :param ax_ids: indices of axes label [0,1,2] for [x,y,z] for horizontal and vertical axis + If given an array of axes, only the outer axes will be labeled. """ - if isinstance(ax, np.ndarray): - # Labels will be applied to shared axis - # Shared axis = value in projection is not None - if ax_ids[0] is not None: - x_label = ( - setup.x_label - or f'{"xyz"[ax_ids[0]]} / {setup.length.output_unit}' - ) - for ax_temp in ax[-1, :]: - ax_temp.set_xlabel(x_label) - if ax_ids[1] is not None: - y_label = ( - setup.y_label - or f'{"xyz"[ax_ids[1]]} / {setup.length.output_unit}' - ) - for ax_temp in ax[:, 0]: - ax_temp.set_ylabel(y_label) - elif isinstance(ax, matplax.Axes): - # Labels will only be applied to non shared axis: - # Non shared = value in projection is not None - if ax_ids[0] is not None: - x_label = ( - setup.x_label - or f'{"xyz"[ax_ids[0]]} / {setup.length.output_unit}' - ) - ax.set_xlabel(x_label) - if ax_ids[1] is not None: - y_label = ( - setup.y_label - or f'{"xyz"[ax_ids[1]]} / {setup.length.output_unit}' - ) - ax.set_ylabel(y_label) + if isinstance(axes, np.ndarray): + ax: plt.Axes + for ax in axes[-1, :]: + ax.set_xlabel(f"{x_label} / {setup.length.output_unit}") + for ax in axes[:, 0]: + ax.set_ylabel(f"{y_label} / {setup.length.output_unit}") else: - msg = "ax is neither Matplotlib axis nor Numpy array!" - raise TypeError(msg) - return ax + axes.set_xlabel(f"{x_label} / {setup.length.output_unit}") + axes.set_ylabel(f"{y_label} / {setup.length.output_unit}") def _get_rows_cols( @@ -428,7 +387,7 @@ def get_combined_levels( """ Calculate well spaced levels for the encompassing property range in meshes. """ - mesh_property = get_preset(mesh_property, meshes[0]) + mesh_property = get_preset(mesh_property, meshes.ravel()[0]) p_min, p_max = np.inf, -np.inf unique_vals = np.array([]) for mesh in np.ravel(meshes): @@ -604,7 +563,7 @@ def plot( """ rcParams.update(setup.rcParams_scaled) shape = _get_rows_cols(meshes) - _meshes = np.reshape(meshes, shape).flatten() + _meshes = np.reshape(meshes, shape).ravel() mesh_property = get_preset(mesh_property, _meshes[0]) data_aspects = np.asarray([get_data_aspect(mesh) for mesh in _meshes]) if setup.min_ax_aspect is None and setup.max_ax_aspect is None: diff --git a/tests/test_meshplotlib.py b/tests/test_meshplotlib.py index 2455e3dec72ad4d9eea0878f125c2bdd68834bfe..958dd59bf14c66be4ac7a3b62923e1564fe9f841 100644 --- a/tests/test_meshplotlib.py +++ b/tests/test_meshplotlib.py @@ -232,13 +232,13 @@ class MeshplotlibTest(unittest.TestCase): def test_spatial_label(self): """Test axes labeling""" fig, ax = plt.subplots(2, 2) - ax = label_spatial_axes(ax, np.array([0, 1])) + label_spatial_axes(ax, "x", "y") def test_spatial_label_clear(self): """Test axes labels clearing""" fig, ax = plt.subplots(2, 2) - ax = label_spatial_axes(ax, np.array([0, 1])) - ax = clear_labels(ax) + label_spatial_axes(ax, "x", "y") + clear_labels(ax) def test_limit_plots(self): """Test creation of limit plots."""