diff --git a/docs/examples/howto_meshplotlib/plot_xy_labels_with_shared_axes.py b/docs/examples/howto_meshplotlib/plot_xy_labels_with_shared_axes.py
index 2e74d7338a551a1ad701b8505fc35491a47d25f4..ded8afdfb74b569656f56e6a4f11efd81858c61c 100644
--- a/docs/examples/howto_meshplotlib/plot_xy_labels_with_shared_axes.py
+++ b/docs/examples/howto_meshplotlib/plot_xy_labels_with_shared_axes.py
@@ -14,7 +14,6 @@ under various conditions (shared and nor shared X and Y axes).
 # %%
 # Import Python packages, change some settings and load example data set
 import matplotlib.pyplot as plt
-import numpy as np
 
 from ogstools.meshlib import difference
 from ogstools.meshplotlib import (
diff --git a/ogstools/meshplotlib/__init__.py b/ogstools/meshplotlib/__init__.py
index e50d776f1f964a022e25d9ff3e8b97f8859d6431..74b7c05e1fe9335136b564962ef44591fbab2e0a 100644
--- a/ogstools/meshplotlib/__init__.py
+++ b/ogstools/meshplotlib/__init__.py
@@ -10,7 +10,6 @@ from .core import (
     update_font_sizes,
     label_spatial_axes,
     clear_labels,
-    clear_labels_ax,
 )
 from .plot_features import plot_contour, plot_on_top
 
@@ -24,5 +23,4 @@ __all__ = [
     "update_font_sizes",
     "label_spatial_axes",
     "clear_labels",
-    "clear_labels_ax",
 ]
diff --git a/ogstools/meshplotlib/core.py b/ogstools/meshplotlib/core.py
index 4962ef0bf5b68380e267c39f34fd1e8e5fa43724..ca139bbfab2d452876c5fcd5f6ad6ebf03b21575 100644
--- a/ogstools/meshplotlib/core.py
+++ b/ogstools/meshplotlib/core.py
@@ -6,7 +6,6 @@ from typing import Literal, Optional, Union
 
 import numpy as np
 import pyvista as pv
-from matplotlib import axes as matplax
 from matplotlib import cm as mcm
 from matplotlib import colormaps, rcParams
 from matplotlib import colors as mcolors
@@ -14,6 +13,7 @@ from matplotlib import figure as mfigure
 from matplotlib import pyplot as plt
 from matplotlib import ticker as mticker
 from matplotlib.patches import Rectangle as Rect
+from typeguard import typechecked
 
 from ogstools.meshlib import MeshSeries
 from ogstools.propertylib import Property, Vector
@@ -304,72 +304,31 @@ def subplot(
         secax.set_xlabel(f'{"xyz"[projection]} / {setup.length.output_unit}')
 
 
-def clear_labels_ax(ax: plt.axes) -> plt.axes:
-    ax.set_xlabel(None)
-    ax.set_ylabel(None)
-    return ax
-
-
-def clear_labels(ax: Union[plt.axes, np.array]) -> Union[plt.axes, np.array]:
-    if isinstance(ax, np.ndarray):
-        for id_n in range(ax.shape[0]):
-            for id_m in range(ax.shape[1]):
-                ax[id_n, id_m] = clear_labels_ax(ax[id_n, id_m])
-    if isinstance(ax, matplax.Axes):
-        # Wrap single axis in np.array
-        ax = clear_labels_ax(ax)
-    return ax
+def clear_labels(axes: Union[plt.Axes, np.ndarray]) -> None:
+    ax: plt.Axes
+    for ax in np.ravel(np.array(axes)):
+        ax.set_xlabel("")
+        ax.set_ylabel("")
 
 
+@typechecked
 def label_spatial_axes(
-    ax: Union[plt.axes, np.array],
-    ax_ids: np.array,
-) -> plt.axes:
+    axes: Union[plt.Axes, np.ndarray], x_label: str = "x", y_label: str = "y"
+) -> None:
     """
-    Add labels to X and Y axis
+    Add labels to x and y axis.
 
-    Automatically selects correct pair of directions and unit.
-    Respects sharex / sharey settings.
-
-    :param ax: Matplotlib Axis object
-    :param ax_ids: indices of axes label [0,1,2] for [x,y,z] for horizontal and vertical axis
+    If given an array of axes, only the outer axes will be labeled.
     """
-    if isinstance(ax, np.ndarray):
-        # Labels will be applied to shared axis
-        # Shared axis = value in projection is not None
-        if ax_ids[0] is not None:
-            x_label = (
-                setup.x_label
-                or f'{"xyz"[ax_ids[0]]} / {setup.length.output_unit}'
-            )
-            for ax_temp in ax[-1, :]:
-                ax_temp.set_xlabel(x_label)
-        if ax_ids[1] is not None:
-            y_label = (
-                setup.y_label
-                or f'{"xyz"[ax_ids[1]]} / {setup.length.output_unit}'
-            )
-            for ax_temp in ax[:, 0]:
-                ax_temp.set_ylabel(y_label)
-    elif isinstance(ax, matplax.Axes):
-        # Labels will only be applied to non shared axis:
-        # Non shared = value in projection is not None
-        if ax_ids[0] is not None:
-            x_label = (
-                setup.x_label
-                or f'{"xyz"[ax_ids[0]]} / {setup.length.output_unit}'
-            )
-            ax.set_xlabel(x_label)
-        if ax_ids[1] is not None:
-            y_label = (
-                setup.y_label
-                or f'{"xyz"[ax_ids[1]]} / {setup.length.output_unit}'
-            )
-            ax.set_ylabel(y_label)
+    if isinstance(axes, np.ndarray):
+        ax: plt.Axes
+        for ax in axes[-1, :]:
+            ax.set_xlabel(f"{x_label} / {setup.length.output_unit}")
+        for ax in axes[:, 0]:
+            ax.set_ylabel(f"{y_label} / {setup.length.output_unit}")
     else:
-        msg = "ax is neither Matplotlib axis nor Numpy array!"
-        raise TypeError(msg)
-    return ax
+        axes.set_xlabel(f"{x_label} / {setup.length.output_unit}")
+        axes.set_ylabel(f"{y_label} / {setup.length.output_unit}")
 
 
 def _get_rows_cols(
@@ -428,7 +387,7 @@ def get_combined_levels(
     """
     Calculate well spaced levels for the encompassing property range in meshes.
     """
-    mesh_property = get_preset(mesh_property, meshes[0])
+    mesh_property = get_preset(mesh_property, meshes.ravel()[0])
     p_min, p_max = np.inf, -np.inf
     unique_vals = np.array([])
     for mesh in np.ravel(meshes):
@@ -604,7 +563,7 @@ def plot(
     """
     rcParams.update(setup.rcParams_scaled)
     shape = _get_rows_cols(meshes)
-    _meshes = np.reshape(meshes, shape).flatten()
+    _meshes = np.reshape(meshes, shape).ravel()
     mesh_property = get_preset(mesh_property, _meshes[0])
     data_aspects = np.asarray([get_data_aspect(mesh) for mesh in _meshes])
     if setup.min_ax_aspect is None and setup.max_ax_aspect is None:
diff --git a/tests/test_meshplotlib.py b/tests/test_meshplotlib.py
index 2455e3dec72ad4d9eea0878f125c2bdd68834bfe..958dd59bf14c66be4ac7a3b62923e1564fe9f841 100644
--- a/tests/test_meshplotlib.py
+++ b/tests/test_meshplotlib.py
@@ -232,13 +232,13 @@ class MeshplotlibTest(unittest.TestCase):
     def test_spatial_label(self):
         """Test axes labeling"""
         fig, ax = plt.subplots(2, 2)
-        ax = label_spatial_axes(ax, np.array([0, 1]))
+        label_spatial_axes(ax, "x", "y")
 
     def test_spatial_label_clear(self):
         """Test axes labels clearing"""
         fig, ax = plt.subplots(2, 2)
-        ax = label_spatial_axes(ax, np.array([0, 1]))
-        ax = clear_labels(ax)
+        label_spatial_axes(ax, "x", "y")
+        clear_labels(ax)
 
     def test_limit_plots(self):
         """Test creation of limit plots."""