From 837cb517e2872c05fe2da222a6fbac51839a04d3 Mon Sep 17 00:00:00 2001
From: Feliks Kiszkurno <feliks-kuba.kiszkurno@ufz.de>
Date: Thu, 14 Mar 2024 13:28:16 +0000
Subject: [PATCH] Semi-automatic labeling spatial shared axes

---
 .../plot_xy_labels_with_shared_axes.py        |  88 +++++++++++++
 ogstools/meshplotlib/__init__.py              |   6 +
 ogstools/meshplotlib/core.py                  | 120 ++++++++++++++----
 tests/test_meshplotlib.py                     |  59 +++++++++
 4 files changed, 250 insertions(+), 23 deletions(-)
 create mode 100644 docs/examples/howto_meshplotlib/plot_xy_labels_with_shared_axes.py

diff --git a/docs/examples/howto_meshplotlib/plot_xy_labels_with_shared_axes.py b/docs/examples/howto_meshplotlib/plot_xy_labels_with_shared_axes.py
new file mode 100644
index 000000000..dad4adef7
--- /dev/null
+++ b/docs/examples/howto_meshplotlib/plot_xy_labels_with_shared_axes.py
@@ -0,0 +1,88 @@
+"""
+Labeling directional shared axes
+=================================
+
+.. sectionauthor:: Feliks Kiszkurno (Helmholtz Centre for Environmental Research GmbH - UFZ)
+
+.. warning:: This example discusses functionality that may experience breaking changes in the near future!
+
+For this example we load a 2D meshseries dataset from within the ``meshplotlib`` examples.
+This tutorial covers automatic labeling the directional axes (X and Y) under various conditions (shared and nor shared X and Y axes).
+"""
+
+# %%
+# Import Python packages, change some settings and load example data set
+import matplotlib.pyplot as plt
+import numpy as np
+
+import ogstools.meshplotlib as mpl
+from ogstools.meshplotlib import (
+    clear_labels,
+    examples,
+    label_spatial_axes,
+    plot,
+    plot_diff,
+)
+from ogstools.propertylib import presets
+
+plt.rcParams.update({"font.size": 32})
+
+mpl.setup.reset()
+mpl.setup.length.output_unit = "km"
+
+meshseries = examples.meshseries_THM_2D
+
+# %%
+# First, by default (without shared axes) both X and Y axes will be labeled automatically. The default is that both axes are shared and this will be respected.
+
+mpl.setup.combined_colorbar = False
+fig = plot([meshseries.read(0), meshseries.read(1)], presets.temperature)
+
+
+# %%
+# On user provided figure and axis, this behaviour is different. To allow for more complex combinations of plot functions, meshseries and process variables, the axis belonging to specific subplot has to be passed. In this case the default is to plot labels on each axis regardless of whether it is share or not.
+
+mpl.setup.combined_colorbar = False
+fig, ax = plt.subplots(2, 2, figsize=(40, 20), sharex=True, sharey=True)
+plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0][0])
+plot(meshseries.read(1), presets.temperature, fig=fig, ax=ax[1][0])
+plot_diff(
+    meshseries.read(0),
+    meshseries.read(1),
+    presets.temperature,
+    fig=fig,
+    ax=ax[0][1],
+)
+plot_diff(
+    meshseries.read(1),
+    meshseries.read(0),
+    presets.temperature,
+    fig=fig,
+    ax=ax[1][1],
+)
+fig.tight_layout()
+
+# %%
+# If user wishes to have labels respecting shared axes, they need to be first removed and applied again. Meshplotlib provides two function that make it easy: clear_labels and label_spatial_axes. They have to be called after the last plot related function call.
+
+mpl.setup.combined_colorbar = False
+fig, ax = plt.subplots(2, 2, figsize=(40, 20), sharex=True, sharey=True)
+plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0][0])
+plot(meshseries.read(1), presets.temperature, fig=fig, ax=ax[1][0])
+plot_diff(
+    meshseries.read(0),
+    meshseries.read(1),
+    presets.temperature,
+    fig=fig,
+    ax=ax[0][1],
+)
+plot_diff(
+    meshseries.read(1),
+    meshseries.read(0),
+    presets.temperature,
+    fig=fig,
+    ax=ax[1][1],
+)
+ax = clear_labels(ax)
+ax = label_spatial_axes(ax, np.array([0, 1]))
+fig.tight_layout()
diff --git a/ogstools/meshplotlib/__init__.py b/ogstools/meshplotlib/__init__.py
index 2168c807c..29164f1b5 100644
--- a/ogstools/meshplotlib/__init__.py
+++ b/ogstools/meshplotlib/__init__.py
@@ -10,6 +10,9 @@ from .core import (
     plot,
     subplot,
     update_font_sizes,
+    label_spatial_axes,
+    clear_labels,
+    clear_labels_ax,
 )
 from .plot_features import plot_contour, plot_on_top
 
@@ -23,4 +26,7 @@ __all__ = [
     "plot_probe",
     "subplot",
     "update_font_sizes",
+    "label_spatial_axes",
+    "clear_labels",
+    "clear_labels_ax",
 ]
diff --git a/ogstools/meshplotlib/core.py b/ogstools/meshplotlib/core.py
index e64ea94c8..47911216d 100644
--- a/ogstools/meshplotlib/core.py
+++ b/ogstools/meshplotlib/core.py
@@ -7,6 +7,7 @@ from typing import Literal, Optional, Union
 
 import numpy as np
 import pyvista as pv
+from matplotlib import axes as matplax
 from matplotlib import cm as mcm
 from matplotlib import colormaps, rcParams
 from matplotlib import colors as mcolors
@@ -195,6 +196,21 @@ def add_colorbars(
         )
 
 
+def get_projection(
+    mesh: pv.UnstructuredGrid,
+) -> tuple[int, int]:
+    """
+    Identify which projection is used: XY, XZ or YZ.
+
+    :param mesh: singular mesh
+
+    """
+    mean_normal = np.abs(np.mean(mesh.extract_surface().cell_normals, axis=0))
+    projection = int(np.argmax(mean_normal))
+    x_id, y_id = np.delete([0, 1, 2], projection)
+    return x_id, y_id
+
+
 def subplot(
     mesh: pv.UnstructuredGrid,
     mesh_property: Union[Property, str],
@@ -227,9 +243,9 @@ def subplot(
     surf_tri = mesh.triangulate().extract_surface()
 
     # get projection
+    x_id, y_id = get_projection(mesh)
     mean_normal = np.abs(np.mean(mesh.extract_surface().cell_normals, axis=0))
     projection = int(np.argmax(mean_normal))
-    x_id, y_id = np.delete([0, 1, 2], projection)
 
     # faces contains a padding indicating number of points per face which gets
     # removed with this reshaping and slicing to get the array of tri's
@@ -295,10 +311,73 @@ def subplot(
         secax.set_xticklabels(sec_labels)
         secax.set_xlabel(f'{"xyz"[projection]} / {setup.length.output_unit}')
 
-    x_label = setup.x_label or f'{"xyz"[x_id]} / {setup.length.output_unit}'
-    y_label = setup.y_label or f'{"xyz"[y_id]} / {setup.length.output_unit}'
-    ax.set_xlabel(x_label)
-    ax.set_ylabel(y_label)
+
+def clear_labels_ax(ax: plt.axes) -> plt.axes:
+    ax.set_xlabel(None)
+    ax.set_ylabel(None)
+    return ax
+
+
+def clear_labels(ax: Union[plt.axes, np.array]) -> Union[plt.axes, np.array]:
+    if isinstance(ax, np.ndarray):
+        for id_n in range(ax.shape[0]):
+            for id_m in range(ax.shape[1]):
+                ax[id_n, id_m] = clear_labels_ax(ax[id_n, id_m])
+    if isinstance(ax, matplax.Axes):
+        # Wrap single axis in np.array
+        ax = clear_labels_ax(ax)
+    return ax
+
+
+def label_spatial_axes(
+    ax: Union[plt.axes, np.array],
+    ax_ids: np.array,
+) -> plt.axes:
+    """
+    Add labels to X and Y axis
+
+    Automatically selects correct pair of directions and unit.
+    Respects sharex / sharey settings.
+
+    :param ax: Matplotlib Axis object
+    :param ax_ids: indices of axes label [0,1,2] for [x,y,z] for horizontal and vertical axis
+    """
+    if isinstance(ax, np.ndarray):
+        # Labels will be applied to shared axis
+        # Shared axis = value in projection is not None
+        if ax_ids[0] is not None:
+            x_label = (
+                setup.x_label
+                or f'{"xyz"[ax_ids[0]]} / {setup.length.output_unit}'
+            )
+            for ax_temp in ax[-1, :]:
+                ax_temp.set_xlabel(x_label)
+        if ax_ids[1] is not None:
+            y_label = (
+                setup.y_label
+                or f'{"xyz"[ax_ids[1]]} / {setup.length.output_unit}'
+            )
+            for ax_temp in ax[:, 0]:
+                ax_temp.set_ylabel(y_label)
+    elif isinstance(ax, matplax.Axes):
+        # Labels will only be applied to non shared axis:
+        # Non shared = value in projection is not None
+        if ax_ids[0] is not None:
+            x_label = (
+                setup.x_label
+                or f'{"xyz"[ax_ids[0]]} / {setup.length.output_unit}'
+            )
+            ax.set_xlabel(x_label)
+        if ax_ids[1] is not None:
+            y_label = (
+                setup.y_label
+                or f'{"xyz"[ax_ids[1]]} / {setup.length.output_unit}'
+            )
+            ax.set_ylabel(y_label)
+    else:
+        msg = "ax is neither Matplotlib axis nor Numpy array!"
+        raise TypeError(msg)
+    return ax
 
 
 def _get_rows_cols(
@@ -324,7 +403,9 @@ def _get_rows_cols(
 # TODO: fixed_figure_size -> ax aspect automatic
 
 
-def _fig_init(rows: int, cols: int, aspect: float = 1.0) -> mfigure.Figure:
+def _fig_init(
+    rows: int, cols: int, aspect: float = 1.0
+) -> tuple[mfigure.Figure, plt.Axes]:
     nx_cb = 1 if setup.combined_colorbar else cols
     default_size = 8
     cb_width = 3
@@ -336,7 +417,7 @@ def _fig_init(rows: int, cols: int, aspect: float = 1.0) -> mfigure.Figure:
             default_size * rows + x_label_height,
         ]
     )
-    fig, _ = plt.subplots(
+    fig, ax = plt.subplots(
         rows,
         cols,
         dpi=setup.dpi * setup.fig_scale,
@@ -346,7 +427,7 @@ def _fig_init(rows: int, cols: int, aspect: float = 1.0) -> mfigure.Figure:
         sharey=True,
     )
     fig.patch.set_alpha(1)
-    return fig
+    return fig, ax
 
 
 def get_combined_levels(
@@ -402,19 +483,7 @@ def _draw_plot(
     shape = _get_rows_cols(meshes)
     np_meshes = np.reshape(meshes, shape)
     if fig is not None and ax is not None:
-        # Fig is given, but plotting on specific ax is intended
-        # Only one mesh in meshes shall be accepted
-        if shape != (1, 1):
-            msg = "You have provided only one Axis object but multiple meshes. Provide only one mesh per Axis object, or provide Figure object instead."
-            raise ValueError(msg)
-        np_axs = np.reshape(np.array(ax), (1, 1))
-    elif fig is not None and ax is not None:
-        # Fig is given and Ax is given
-        # Multiple meshes should be accepted
-        if shape != (1, 1):
-            msg = "You have provided only one Axis object but multiple meshes. Provide only one mesh per Axis object, or provide Figure object instead."
-            raise ValueError(msg)
-        np_axs = np.reshape(ax, (1, 1))
+        np_axs = np.reshape(np.array(ax), shape)
     elif fig is not None and ax is None:
         # Only Fig is given
         # Multiple meshes should be accepted
@@ -444,6 +513,11 @@ def _draw_plot(
                 else get_combined_levels(np_meshes[i, j], mesh_property)
             )
             subplot(np_meshes[i, j], mesh_property, np_axs[i, j], _levels)
+
+    x_id, y_id = get_projection(
+        np_meshes[0, 0]
+    )  # One mesh is sufficient, it should be the same for all of them
+    np_axs = label_spatial_axes(np_axs, np.array([x_id, y_id]))
     np_axs[0, 0].set_title(setup.title_center, loc="center", y=1.02)
     np_axs[0, 0].set_title(setup.title_left, loc="left", y=1.02)
     np_axs[0, 0].set_title(setup.title_right, loc="right", y=1.02)
@@ -554,8 +628,8 @@ def plot(
     ax_aspects = fig_aspect / data_aspects
     n_axs = shape[0] * shape[1]
     if ax is None and fig is None:
-        _fig = _fig_init(rows=shape[0], cols=shape[1], aspect=fig_aspect)
-        fig = _draw_plot(meshes, mesh_property, fig=_fig)
+        _fig, _ax = _fig_init(rows=shape[0], cols=shape[1], aspect=fig_aspect)
+        fig = _draw_plot(meshes, mesh_property, fig=_fig, ax=_ax)
         for ax, aspect in zip(fig.axes[: n_axs + 1], ax_aspects):
             ax.set_aspect(1.0 / aspect)
     elif ax is not None and fig is None:
diff --git a/tests/test_meshplotlib.py b/tests/test_meshplotlib.py
index ddc3bf56f..1e8e8ad67 100644
--- a/tests/test_meshplotlib.py
+++ b/tests/test_meshplotlib.py
@@ -9,7 +9,9 @@ import numpy as np
 from pyvista import examples as pv_examples
 
 from ogstools.meshplotlib import (
+    clear_labels,
     examples,
+    label_spatial_axes,
     plot,
     plot_diff,
     plot_limit,
@@ -189,6 +191,63 @@ class MeshplotlibTest(unittest.TestCase):
         fig = update_font_sizes(fig, fontsize=25)
         fig.suptitle("Test user defined fig")
 
+    def test_sharexy(self):
+        """Test if labels are skipped if axis are shared"""
+        meshseries = examples.meshseries_THM_2D
+        fig, ax = plt.subplots(2, 2, sharex=True, sharey=True)
+        ax = ax.flatten()
+        plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0])
+        plot(meshseries.read(1), presets.temperature, fig=fig, ax=ax[1])
+        plot_diff(
+            meshseries.read(0),
+            meshseries.read(1),
+            presets.temperature,
+            fig=fig,
+            ax=ax[2],
+        )
+        plot_diff(
+            meshseries.read(1),
+            meshseries.read(0),
+            presets.temperature,
+            fig=fig,
+            ax=ax[3],
+        )
+        fig.tight_layout()
+
+    def test_label_sharedxy(self):
+        """Test labeling shared x and y axes"""
+        meshseries = examples.meshseries_THM_2D
+        fig, ax = plt.subplots(2, 2, sharex=True, sharey=True)
+        plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0][0])
+        plot(meshseries.read(1), presets.temperature, fig=fig, ax=ax[1][0])
+        plot_diff(
+            meshseries.read(0),
+            meshseries.read(1),
+            presets.temperature,
+            fig=fig,
+            ax=ax[0][1],
+        )
+        plot_diff(
+            meshseries.read(1),
+            meshseries.read(0),
+            presets.temperature,
+            fig=fig,
+            ax=ax[1][1],
+        )
+        ax = label_spatial_axes(ax, np.array([0, 1]))
+        fig.tight_layout()
+
+    def test_spatial_label(self):
+        """Test if labels are added to x and y axes"""
+        fig, ax = plt.subplots(2, 2)
+        ax = label_spatial_axes(ax, np.array([0, 1]))
+
+    def test_spatial_label_clear(self):
+        """Test if labels are added to x and y axes"""
+        fig, ax = plt.subplots(2, 2)
+        ax = label_spatial_axes(ax, np.array([0, 1]))
+        ax = clear_labels(ax)
+
     def test_limit_plots(self):
         """Test creation of limit plots."""
         meshseries = examples.meshseries_CT_2D
-- 
GitLab