diff --git a/docs/examples/howto_meshplotlib/plot_with_custom_fig_ax.py b/docs/examples/howto_meshplotlib/plot_with_custom_fig_ax.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd86df98f71825dcc3e1615be27de16081d8aae2
--- /dev/null
+++ b/docs/examples/howto_meshplotlib/plot_with_custom_fig_ax.py
@@ -0,0 +1,54 @@
+"""
+Plotting different process variables on already existing Matplotlib figures / axes
+==================================================================================
+
+.. sectionauthor:: Feliks Kiszkurno (Helmholtz Centre for Environmental Research GmbH - UFZ)
+
+For this example we load a 2D meshseries dataset from within the ``meshplotlib`` examples.
+This tutorial covers using meshplotlib to plot meshseries data using Matplotlib objects for Figure and / or Axis. This is useful if different plotting functions from Meshplotlib are to be used on different subplots within the same figure
+"""
+
+# %%
+# Import Python packages, change some settings and load example data set
+import matplotlib.pyplot as plt
+
+import ogstools.meshplotlib as mpl
+from ogstools.meshplotlib import examples, plot, plot_diff
+from ogstools.propertylib import presets
+
+plt.rcParams.update({"font.size": 32})
+
+mpl.setup.reset()
+mpl.setup.length.output_unit = "km"
+
+meshseries = examples.meshseries_THM_2D
+
+
+# %%
+# It is possible to plot various process parameter in different subplots of the same figure. But each mesh series and parameter pair need a separate call of plot function:
+
+mpl.setup.combined_colorbar = False
+fig, ax = plt.subplots(2, 1, figsize=(15, 15))
+plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0])
+plot(meshseries.read(1), presets.displacement, fig=fig, ax=ax[1])
+fig.suptitle("Compare temperature and displacement")
+fig.tight_layout()
+
+# %%
+# The same way difference between process variables at different time steps can be plotted. Color bars can be drawn automatically, if user provides both Figure and Axes objects:
+
+fig, ax = plt.subplots(3, 1, figsize=(20, 30))
+plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0])
+ax[0].set_title(r"$T(\mathrm{t}_{0})$")
+plot(meshseries.read(1), presets.temperature, fig=fig, ax=ax[1])
+ax[1].set_title(r"$T(\mathrm{t}_{end})$")
+plot_diff(
+    meshseries.read(0),
+    meshseries.read(1),
+    presets.temperature,
+    fig=fig,
+    ax=ax[2],
+)
+ax[2].set_title(r"$T(\mathrm{t}_{end})$-$T(\mathrm{t}_{0})$")
+fig.suptitle("Plot two time steps and their difference - with colorbars")
+fig.tight_layout()
diff --git a/ogstools/meshplotlib/__init__.py b/ogstools/meshplotlib/__init__.py
index ef020e43b6b86c9560df44526d94a40a5d6c4ac2..2168c807c341e29aed440991a4293ac43bbcc71f 100644
--- a/ogstools/meshplotlib/__init__.py
+++ b/ogstools/meshplotlib/__init__.py
@@ -3,7 +3,14 @@
 
 from .plot_setup import _setup as setup  # noqa: I001: noqa
 
-from .core import plot_diff, plot_limit, plot_probe, plot, subplot
+from .core import (
+    plot_diff,
+    plot_limit,
+    plot_probe,
+    plot,
+    subplot,
+    update_font_sizes,
+)
 from .plot_features import plot_contour, plot_on_top
 
 __all__ = [
@@ -15,4 +22,5 @@ __all__ = [
     "plot_on_top",
     "plot_probe",
     "subplot",
+    "update_font_sizes",
 ]
diff --git a/ogstools/meshplotlib/animation.py b/ogstools/meshplotlib/animation.py
index 295d97e42b756873e41842e65a034a6c94406e30..d210699339c16074efe1444aaa9f6c8c94d9c997 100644
--- a/ogstools/meshplotlib/animation.py
+++ b/ogstools/meshplotlib/animation.py
@@ -13,7 +13,7 @@ from ogstools.meshlib import MeshSeries
 from ogstools.propertylib import Property
 
 from . import setup
-from .core import _plot_on_figure, plot
+from .core import _draw_plot, plot
 
 
 def animate(
@@ -54,7 +54,7 @@ def animate(
             mesh = mesh_series.read_interp(i, True)
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
-            fig = _plot_on_figure(fig, mesh, property)
+            fig = _draw_plot(mesh, property, fig=fig)
 
     _func = partial(animate_func, fig=fig)
 
diff --git a/ogstools/meshplotlib/core.py b/ogstools/meshplotlib/core.py
index 0be970041c81d296795d741ac5e0eef0cdc92fbe..e64ea94c8d4178181d2b415af12e185341fd7c14 100644
--- a/ogstools/meshplotlib/core.py
+++ b/ogstools/meshplotlib/core.py
@@ -1,5 +1,6 @@
 """Meshplotlib core utilitites."""
 
+import warnings
 from copy import deepcopy
 from math import nextafter
 from typing import Literal, Optional, Union
@@ -141,8 +142,14 @@ def add_colorbars(
     cm = mcm.ScalarMappable(norm=norm, cmap=cmap)
 
     cb = fig.colorbar(
-        cm, norm=norm, ax=ax, ticks=ticks, drawedges=True, location="right",
-        spacing="uniform", pad=pad  # fmt: skip
+        cm,
+        norm=norm,
+        ax=ax,
+        ticks=ticks,
+        drawedges=True,
+        location="right",
+        spacing="uniform",
+        pad=pad,  # fmt: skip
     )
     # Formatting the colorbar label and ticks
 
@@ -300,7 +307,7 @@ def _get_rows_cols(
         np.ndarray,
         pv.UnstructuredGrid,
         pv.MultiBlock,
-    ]
+    ],
 ) -> tuple[int, ...]:
     if isinstance(meshes, np.ndarray):
         if meshes.ndim in [1, 2]:
@@ -375,23 +382,60 @@ def get_combined_levels(
     return compute_levels(p_min, p_max, setup.num_levels)
 
 
-def _plot_on_figure(
-    fig: mfigure.Figure,
+# TODO: Have a look at fig and ax logic and make it more readable
+
+
+def _draw_plot(
     meshes: Union[list[pv.UnstructuredGrid], np.ndarray, pv.UnstructuredGrid],
     mesh_property: Property,
+    fig: Optional[mfigure.Figure] = None,
+    ax: Optional[plt.Axes] = None,
 ) -> mfigure.Figure:
     """
     Plot the property field of meshes on existing figure.
 
     :param meshes: Singular mesh of 2D numpy array of meshes
     :param property: the property field to be visualized on all meshes
+    :param fig: Matplotlib Figure object to use for plotting (optional)
+    :param ax: Matplotlib Axis object to use for plotting (optional)
     """
     shape = _get_rows_cols(meshes)
     np_meshes = np.reshape(meshes, shape)
-    np_axs = np.reshape(fig.axes, shape)
+    if fig is not None and ax is not None:
+        # Fig is given, but plotting on specific ax is intended
+        # Only one mesh in meshes shall be accepted
+        if shape != (1, 1):
+            msg = "You have provided only one Axis object but multiple meshes. Provide only one mesh per Axis object, or provide Figure object instead."
+            raise ValueError(msg)
+        np_axs = np.reshape(np.array(ax), (1, 1))
+    elif fig is not None and ax is not None:
+        # Fig is given and Ax is given
+        # Multiple meshes should be accepted
+        if shape != (1, 1):
+            msg = "You have provided only one Axis object but multiple meshes. Provide only one mesh per Axis object, or provide Figure object instead."
+            raise ValueError(msg)
+        np_axs = np.reshape(ax, (1, 1))
+    elif fig is not None and ax is None:
+        # Only Fig is given
+        # Multiple meshes should be accepted
+        warnings.warn(
+            "This is not a good practice. Consider providing both fig and ax instead. This option may lead to unexpected behaviour and may be removed without warning in the future.",
+            Warning,
+            stacklevel=4,
+        )
+        np_axs = np.reshape(fig.axes, shape)
+    elif fig is None and ax is not None:
+        # Only ax is given
+        # Only one mesh should be accepted
+        if shape != (1, 1):
+            msg = "You have provided only one Axis object but multiple meshes. Provide only one mesh per Axis object, or provide Figure object instead."
+            raise ValueError(msg)
+        np_axs = np.reshape(np.array(ax), (1, 1))
+    else:
+        msg = "Neither Figure nor Axis object was provided."
+        raise TypeError(msg)
     if setup.combined_colorbar:
         combined_levels = get_combined_levels(np_meshes, mesh_property)
-
     for i in range(shape[0]):
         for j in range(shape[1]):
             _levels = (
@@ -400,25 +444,39 @@ def _plot_on_figure(
                 else get_combined_levels(np_meshes[i, j], mesh_property)
             )
             subplot(np_meshes[i, j], mesh_property, np_axs[i, j], _levels)
-
     np_axs[0, 0].set_title(setup.title_center, loc="center", y=1.02)
     np_axs[0, 0].set_title(setup.title_left, loc="left", y=1.02)
     np_axs[0, 0].set_title(setup.title_right, loc="right", y=1.02)
     # make extra space for the upper limit of the colorbar
     if setup.layout == "tight":
         plt.tight_layout(pad=1.4)
-
     if setup.combined_colorbar:
-        cb_axs = np.ravel(fig.axes).tolist()
-        add_colorbars(
-            fig, cb_axs, mesh_property, combined_levels, pad=0.05 / shape[1]
-        )
+        if fig is None:
+            warnings.warn(
+                "Cannot plot combined colorbar if Figure object is not provided!",
+                Warning,
+                stacklevel=4,
+            )
+        else:
+            cb_axs = np.ravel(fig.axes).tolist()
+            add_colorbars(
+                fig, cb_axs, mesh_property, combined_levels, pad=0.05 / shape[1]
+            )
     else:
-        for i in range(shape[0]):
-            for j in range(shape[1]):
-                _levels = get_combined_levels(np_meshes[i, j], mesh_property)
-                add_colorbars(fig, np_axs[i, j], mesh_property, _levels)
-
+        # TODO: restructure this logic
+        if fig is None:
+            warnings.warn(
+                "Figure object is required to plot individual colorbars for Axes objects.",
+                Warning,
+                stacklevel=4,
+            )
+        else:
+            for i in range(shape[0]):
+                for j in range(shape[1]):
+                    _levels = get_combined_levels(
+                        np_meshes[i, j], mesh_property
+                    )
+                    add_colorbars(fig, np_axs[i, j], mesh_property, _levels)
     return fig
 
 
@@ -433,11 +491,41 @@ def get_data_aspect(mesh: pv.DataSet) -> float:
     return abs(lims[x_id + 1] - lims[x_id]) / abs(lims[y_id + 1] - lims[y_id])
 
 
+def update_font_sizes(
+    fig: mfigure.Figure,
+    fontsize: int = 20,
+) -> mfigure.Figure:
+    """
+    Update font sizes of lebels and ticks in all subplots
+
+    :param fig: Matplotlib Figure object to use for plotting
+    :param int: New font size for the labels and ticks (optional)
+    """
+    x_label = f"X / {setup.length.output_unit}"
+    y_label = f"Y / {setup.length.output_unit}"
+    for subax in fig.axes:
+        subax.set_xlabel(x_label, fontsize=fontsize)
+        subax.set_ylabel(y_label, fontsize=fontsize)
+        subax_xlim = subax.get_xlim()
+        subax_ylim = subax.get_ylim()
+        subax.set_xticks(
+            subax.get_xticks(), subax.get_xticklabels(), fontsize=fontsize
+        )
+        subax.set_yticks(
+            subax.get_yticks(), subax.get_yticklabels(), fontsize=fontsize
+        )
+        subax.set_xlim(subax_xlim)
+        subax.set_ylim(subax_ylim)
+    return fig
+
+
 # TODO: add as arguments: cmap, limits
 # TODO: num_levels should be min_levels
 def plot(
     meshes: Union[list[pv.UnstructuredGrid], np.ndarray, pv.UnstructuredGrid],
     mesh_property: Union[Property, str],
+    fig: Optional[mfigure.Figure] = None,
+    ax: Optional[plt.Axes] = None,
 ) -> mfigure.Figure:
     """
     Plot the property field of meshes with default settings.
@@ -447,8 +535,9 @@ def plot(
 
     :param meshes:      Singular mesh of 2D numpy array of meshes
     :param property:    The property field to be visualized on all meshes
+    :param fig: Matplotlib Figure object to use for plotting (optional)
+    :param ax: Matplotlib Axis object to use for plotting (optional)
     """
-
     rcParams.update(setup.rcParams_scaled)
     shape = _get_rows_cols(meshes)
     _meshes = np.reshape(meshes, shape).flatten()
@@ -463,11 +552,22 @@ def plot(
             np.clip(data_aspects, setup.min_ax_aspect, setup.max_ax_aspect)
         )
     ax_aspects = fig_aspect / data_aspects
-    _fig = _fig_init(rows=shape[0], cols=shape[1], aspect=fig_aspect)
     n_axs = shape[0] * shape[1]
-    fig = _plot_on_figure(_fig, meshes, mesh_property)
-    for ax, aspect in zip(fig.axes[: n_axs + 1], ax_aspects):
-        ax.set_aspect(1.0 / aspect)
+    if ax is None and fig is None:
+        _fig = _fig_init(rows=shape[0], cols=shape[1], aspect=fig_aspect)
+        fig = _draw_plot(meshes, mesh_property, fig=_fig)
+        for ax, aspect in zip(fig.axes[: n_axs + 1], ax_aspects):
+            ax.set_aspect(1.0 / aspect)
+    elif ax is not None and fig is None:
+        ax.set_aspect(1.0 / ax_aspects[0])
+        _draw_plot(meshes, mesh_property, ax=ax)
+    elif ax is None and fig is not None:
+        fig = _draw_plot(meshes, mesh_property, fig=fig)
+        for ax, aspect in zip(fig.axes[: n_axs + 1], ax_aspects):
+            ax.set_aspect(1.0 / aspect)
+    elif ax is not None and fig is not None:
+        _draw_plot(meshes, mesh_property, fig=fig, ax=ax)
+        ax.set_aspect(1.0 / ax_aspects[0])
     return fig
 
 
@@ -475,6 +575,8 @@ def plot_diff(
     mesh1: pv.UnstructuredGrid,
     mesh2: pv.UnstructuredGrid,
     mesh_property: Union[Property, str],
+    fig: Optional[mfigure.Figure] = None,
+    ax: Optional[plt.Axes] = None,
 ) -> mfigure.Figure:
     if isinstance(mesh_property, str):
         data_shape = mesh1[mesh_property].shape
@@ -495,13 +597,15 @@ def plot_diff(
         bilinear_cmap=True,
         cmap=mesh_property.cmap if mesh_property.bilinear_cmap else "coolwarm",
     )
-    return plot(diff_mesh, diff_property)
+    return plot(diff_mesh, diff_property, fig=fig, ax=ax)
 
 
 def plot_limit(
     mesh_series: MeshSeries,
     mesh_property: Union[Property, str],
     limit: Literal["min", "max"],
+    fig: Optional[mfigure.Figure] = None,
+    ax: Optional[plt.Axes] = None,
 ) -> mfigure.Figure:
     """
     Plot the property limits through all timesteps of a MeshSeries.
@@ -509,6 +613,8 @@ def plot_limit(
     :param mesh_series: MeshSeries object containing the data to be plotted
     :param property:    The property field to be evaluated
     :param limit:       Type of limit to be computed
+    :param fig: Matplotlib Figure object to use for plotting (optional)
+    :param ax: Matplotlib Axis object to use for plotting (optional)
 
     :returns:   A matplotlib Figure
     """
@@ -522,7 +628,7 @@ def plot_limit(
     limit_property = mesh_property.replace(
         output_name=limit + " " + mesh_property.output_name
     )
-    return plot(mesh, limit_property)
+    return plot(mesh, limit_property, fig=fig, ax=ax)
 
 
 def plot_probe(
diff --git a/tests/test_meshplotlib.py b/tests/test_meshplotlib.py
index 8df8486d581abf2a5fca1d49134637137e843d35..ddc3bf56f95fbfd6f7ac7a4717852310b613a546 100644
--- a/tests/test_meshplotlib.py
+++ b/tests/test_meshplotlib.py
@@ -4,6 +4,7 @@ import unittest
 from functools import partial
 from tempfile import mkstemp
 
+import matplotlib.pyplot as plt
 import numpy as np
 from pyvista import examples as pv_examples
 
@@ -14,6 +15,7 @@ from ogstools.meshplotlib import (
     plot_limit,
     plot_probe,
     setup,
+    update_font_sizes,
 )
 from ogstools.meshplotlib.animation import animate, save_animation
 from ogstools.meshplotlib.core import get_ticklabels
@@ -133,6 +135,60 @@ class MeshplotlibTest(unittest.TestCase):
         meshseries = examples.meshseries_CT_2D
         plot_diff(meshseries.read(0), meshseries.read(1), "Si")
 
+    def test_user_defined_ax(self):
+        """Test creating plot with subfigures and user provided ax"""
+        meshseries = examples.meshseries_THM_2D
+        fig, ax = plt.subplots(3, 1, figsize=(40, 30))
+        plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0])
+        ax[0].set_title(r"$T(\mathrm{t}_{0})$")
+        plot(meshseries.read(1), presets.temperature, fig=fig, ax=ax[1])
+        ax[1].set_title(r"$T(\mathrm{t}_{end})$")
+        plot_diff(
+            meshseries.read(0),
+            meshseries.read(1),
+            presets.temperature,
+            fig=fig,
+            ax=ax[2],
+        )
+        ax[2].set_title(r"$T(\mathrm{t}_{end})$-$T(\mathrm{t}_{0})$")
+        # fig.suptitle("Test user defined ax")
+        fig.tight_layout()
+
+    def test_user_defined_ax_diff_vals(self):
+        """Test creating plot with subfigures and user provided ax with different values plotted"""
+        meshseries = examples.meshseries_THM_2D
+        setup.combined_colorbar = False
+        fig, ax = plt.subplots(2, 1, figsize=(40, 20))
+        plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0])
+        plot(meshseries.read(1), presets.displacement, fig=fig, ax=ax[1])
+        fig.suptitle("Test user defined ax")
+        fig.tight_layout()
+
+    def test_user_defined_fig(self):
+        """Test creating plot with subfigures and user provided fig"""
+        meshseries = examples.meshseries_THM_2D
+        setup.combined_colorbar = False
+        fig, ax = plt.subplots(2, 1, figsize=(40, 20))
+        plot(
+            [meshseries.read(0), meshseries.read(1)],
+            presets.temperature,
+            fig=fig,
+        )
+        fig.suptitle("Test user defined fig")
+
+    def test_update_font_sizes(self):
+        """Test creating plot with subfigures and user provided fig"""
+        meshseries = examples.meshseries_THM_2D
+        setup.combined_colorbar = False
+        fig, ax = plt.subplots(2, 1, figsize=(40, 20))
+        plot(
+            [meshseries.read(0), meshseries.read(1)],
+            presets.temperature,
+            fig=fig,
+        )
+        fig = update_font_sizes(fig, fontsize=25)
+        fig.suptitle("Test user defined fig")
+
     def test_limit_plots(self):
         """Test creation of limit plots."""
         meshseries = examples.meshseries_CT_2D