diff --git a/docs/examples/howto_meshplotlib/plot_with_custom_fig_ax.py b/docs/examples/howto_meshplotlib/plot_with_custom_fig_ax.py new file mode 100644 index 0000000000000000000000000000000000000000..cd86df98f71825dcc3e1615be27de16081d8aae2 --- /dev/null +++ b/docs/examples/howto_meshplotlib/plot_with_custom_fig_ax.py @@ -0,0 +1,54 @@ +""" +Plotting different process variables on already existing Matplotlib figures / axes +================================================================================== + +.. sectionauthor:: Feliks Kiszkurno (Helmholtz Centre for Environmental Research GmbH - UFZ) + +For this example we load a 2D meshseries dataset from within the ``meshplotlib`` examples. +This tutorial covers using meshplotlib to plot meshseries data using Matplotlib objects for Figure and / or Axis. This is useful if different plotting functions from Meshplotlib are to be used on different subplots within the same figure +""" + +# %% +# Import Python packages, change some settings and load example data set +import matplotlib.pyplot as plt + +import ogstools.meshplotlib as mpl +from ogstools.meshplotlib import examples, plot, plot_diff +from ogstools.propertylib import presets + +plt.rcParams.update({"font.size": 32}) + +mpl.setup.reset() +mpl.setup.length.output_unit = "km" + +meshseries = examples.meshseries_THM_2D + + +# %% +# It is possible to plot various process parameter in different subplots of the same figure. But each mesh series and parameter pair need a separate call of plot function: + +mpl.setup.combined_colorbar = False +fig, ax = plt.subplots(2, 1, figsize=(15, 15)) +plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0]) +plot(meshseries.read(1), presets.displacement, fig=fig, ax=ax[1]) +fig.suptitle("Compare temperature and displacement") +fig.tight_layout() + +# %% +# The same way difference between process variables at different time steps can be plotted. Color bars can be drawn automatically, if user provides both Figure and Axes objects: + +fig, ax = plt.subplots(3, 1, figsize=(20, 30)) +plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0]) +ax[0].set_title(r"$T(\mathrm{t}_{0})$") +plot(meshseries.read(1), presets.temperature, fig=fig, ax=ax[1]) +ax[1].set_title(r"$T(\mathrm{t}_{end})$") +plot_diff( + meshseries.read(0), + meshseries.read(1), + presets.temperature, + fig=fig, + ax=ax[2], +) +ax[2].set_title(r"$T(\mathrm{t}_{end})$-$T(\mathrm{t}_{0})$") +fig.suptitle("Plot two time steps and their difference - with colorbars") +fig.tight_layout() diff --git a/ogstools/meshplotlib/__init__.py b/ogstools/meshplotlib/__init__.py index ef020e43b6b86c9560df44526d94a40a5d6c4ac2..2168c807c341e29aed440991a4293ac43bbcc71f 100644 --- a/ogstools/meshplotlib/__init__.py +++ b/ogstools/meshplotlib/__init__.py @@ -3,7 +3,14 @@ from .plot_setup import _setup as setup # noqa: I001: noqa -from .core import plot_diff, plot_limit, plot_probe, plot, subplot +from .core import ( + plot_diff, + plot_limit, + plot_probe, + plot, + subplot, + update_font_sizes, +) from .plot_features import plot_contour, plot_on_top __all__ = [ @@ -15,4 +22,5 @@ __all__ = [ "plot_on_top", "plot_probe", "subplot", + "update_font_sizes", ] diff --git a/ogstools/meshplotlib/animation.py b/ogstools/meshplotlib/animation.py index 295d97e42b756873e41842e65a034a6c94406e30..d210699339c16074efe1444aaa9f6c8c94d9c997 100644 --- a/ogstools/meshplotlib/animation.py +++ b/ogstools/meshplotlib/animation.py @@ -13,7 +13,7 @@ from ogstools.meshlib import MeshSeries from ogstools.propertylib import Property from . import setup -from .core import _plot_on_figure, plot +from .core import _draw_plot, plot def animate( @@ -54,7 +54,7 @@ def animate( mesh = mesh_series.read_interp(i, True) with warnings.catch_warnings(): warnings.simplefilter("ignore") - fig = _plot_on_figure(fig, mesh, property) + fig = _draw_plot(mesh, property, fig=fig) _func = partial(animate_func, fig=fig) diff --git a/ogstools/meshplotlib/core.py b/ogstools/meshplotlib/core.py index 0be970041c81d296795d741ac5e0eef0cdc92fbe..e64ea94c8d4178181d2b415af12e185341fd7c14 100644 --- a/ogstools/meshplotlib/core.py +++ b/ogstools/meshplotlib/core.py @@ -1,5 +1,6 @@ """Meshplotlib core utilitites.""" +import warnings from copy import deepcopy from math import nextafter from typing import Literal, Optional, Union @@ -141,8 +142,14 @@ def add_colorbars( cm = mcm.ScalarMappable(norm=norm, cmap=cmap) cb = fig.colorbar( - cm, norm=norm, ax=ax, ticks=ticks, drawedges=True, location="right", - spacing="uniform", pad=pad # fmt: skip + cm, + norm=norm, + ax=ax, + ticks=ticks, + drawedges=True, + location="right", + spacing="uniform", + pad=pad, # fmt: skip ) # Formatting the colorbar label and ticks @@ -300,7 +307,7 @@ def _get_rows_cols( np.ndarray, pv.UnstructuredGrid, pv.MultiBlock, - ] + ], ) -> tuple[int, ...]: if isinstance(meshes, np.ndarray): if meshes.ndim in [1, 2]: @@ -375,23 +382,60 @@ def get_combined_levels( return compute_levels(p_min, p_max, setup.num_levels) -def _plot_on_figure( - fig: mfigure.Figure, +# TODO: Have a look at fig and ax logic and make it more readable + + +def _draw_plot( meshes: Union[list[pv.UnstructuredGrid], np.ndarray, pv.UnstructuredGrid], mesh_property: Property, + fig: Optional[mfigure.Figure] = None, + ax: Optional[plt.Axes] = None, ) -> mfigure.Figure: """ Plot the property field of meshes on existing figure. :param meshes: Singular mesh of 2D numpy array of meshes :param property: the property field to be visualized on all meshes + :param fig: Matplotlib Figure object to use for plotting (optional) + :param ax: Matplotlib Axis object to use for plotting (optional) """ shape = _get_rows_cols(meshes) np_meshes = np.reshape(meshes, shape) - np_axs = np.reshape(fig.axes, shape) + if fig is not None and ax is not None: + # Fig is given, but plotting on specific ax is intended + # Only one mesh in meshes shall be accepted + if shape != (1, 1): + msg = "You have provided only one Axis object but multiple meshes. Provide only one mesh per Axis object, or provide Figure object instead." + raise ValueError(msg) + np_axs = np.reshape(np.array(ax), (1, 1)) + elif fig is not None and ax is not None: + # Fig is given and Ax is given + # Multiple meshes should be accepted + if shape != (1, 1): + msg = "You have provided only one Axis object but multiple meshes. Provide only one mesh per Axis object, or provide Figure object instead." + raise ValueError(msg) + np_axs = np.reshape(ax, (1, 1)) + elif fig is not None and ax is None: + # Only Fig is given + # Multiple meshes should be accepted + warnings.warn( + "This is not a good practice. Consider providing both fig and ax instead. This option may lead to unexpected behaviour and may be removed without warning in the future.", + Warning, + stacklevel=4, + ) + np_axs = np.reshape(fig.axes, shape) + elif fig is None and ax is not None: + # Only ax is given + # Only one mesh should be accepted + if shape != (1, 1): + msg = "You have provided only one Axis object but multiple meshes. Provide only one mesh per Axis object, or provide Figure object instead." + raise ValueError(msg) + np_axs = np.reshape(np.array(ax), (1, 1)) + else: + msg = "Neither Figure nor Axis object was provided." + raise TypeError(msg) if setup.combined_colorbar: combined_levels = get_combined_levels(np_meshes, mesh_property) - for i in range(shape[0]): for j in range(shape[1]): _levels = ( @@ -400,25 +444,39 @@ def _plot_on_figure( else get_combined_levels(np_meshes[i, j], mesh_property) ) subplot(np_meshes[i, j], mesh_property, np_axs[i, j], _levels) - np_axs[0, 0].set_title(setup.title_center, loc="center", y=1.02) np_axs[0, 0].set_title(setup.title_left, loc="left", y=1.02) np_axs[0, 0].set_title(setup.title_right, loc="right", y=1.02) # make extra space for the upper limit of the colorbar if setup.layout == "tight": plt.tight_layout(pad=1.4) - if setup.combined_colorbar: - cb_axs = np.ravel(fig.axes).tolist() - add_colorbars( - fig, cb_axs, mesh_property, combined_levels, pad=0.05 / shape[1] - ) + if fig is None: + warnings.warn( + "Cannot plot combined colorbar if Figure object is not provided!", + Warning, + stacklevel=4, + ) + else: + cb_axs = np.ravel(fig.axes).tolist() + add_colorbars( + fig, cb_axs, mesh_property, combined_levels, pad=0.05 / shape[1] + ) else: - for i in range(shape[0]): - for j in range(shape[1]): - _levels = get_combined_levels(np_meshes[i, j], mesh_property) - add_colorbars(fig, np_axs[i, j], mesh_property, _levels) - + # TODO: restructure this logic + if fig is None: + warnings.warn( + "Figure object is required to plot individual colorbars for Axes objects.", + Warning, + stacklevel=4, + ) + else: + for i in range(shape[0]): + for j in range(shape[1]): + _levels = get_combined_levels( + np_meshes[i, j], mesh_property + ) + add_colorbars(fig, np_axs[i, j], mesh_property, _levels) return fig @@ -433,11 +491,41 @@ def get_data_aspect(mesh: pv.DataSet) -> float: return abs(lims[x_id + 1] - lims[x_id]) / abs(lims[y_id + 1] - lims[y_id]) +def update_font_sizes( + fig: mfigure.Figure, + fontsize: int = 20, +) -> mfigure.Figure: + """ + Update font sizes of lebels and ticks in all subplots + + :param fig: Matplotlib Figure object to use for plotting + :param int: New font size for the labels and ticks (optional) + """ + x_label = f"X / {setup.length.output_unit}" + y_label = f"Y / {setup.length.output_unit}" + for subax in fig.axes: + subax.set_xlabel(x_label, fontsize=fontsize) + subax.set_ylabel(y_label, fontsize=fontsize) + subax_xlim = subax.get_xlim() + subax_ylim = subax.get_ylim() + subax.set_xticks( + subax.get_xticks(), subax.get_xticklabels(), fontsize=fontsize + ) + subax.set_yticks( + subax.get_yticks(), subax.get_yticklabels(), fontsize=fontsize + ) + subax.set_xlim(subax_xlim) + subax.set_ylim(subax_ylim) + return fig + + # TODO: add as arguments: cmap, limits # TODO: num_levels should be min_levels def plot( meshes: Union[list[pv.UnstructuredGrid], np.ndarray, pv.UnstructuredGrid], mesh_property: Union[Property, str], + fig: Optional[mfigure.Figure] = None, + ax: Optional[plt.Axes] = None, ) -> mfigure.Figure: """ Plot the property field of meshes with default settings. @@ -447,8 +535,9 @@ def plot( :param meshes: Singular mesh of 2D numpy array of meshes :param property: The property field to be visualized on all meshes + :param fig: Matplotlib Figure object to use for plotting (optional) + :param ax: Matplotlib Axis object to use for plotting (optional) """ - rcParams.update(setup.rcParams_scaled) shape = _get_rows_cols(meshes) _meshes = np.reshape(meshes, shape).flatten() @@ -463,11 +552,22 @@ def plot( np.clip(data_aspects, setup.min_ax_aspect, setup.max_ax_aspect) ) ax_aspects = fig_aspect / data_aspects - _fig = _fig_init(rows=shape[0], cols=shape[1], aspect=fig_aspect) n_axs = shape[0] * shape[1] - fig = _plot_on_figure(_fig, meshes, mesh_property) - for ax, aspect in zip(fig.axes[: n_axs + 1], ax_aspects): - ax.set_aspect(1.0 / aspect) + if ax is None and fig is None: + _fig = _fig_init(rows=shape[0], cols=shape[1], aspect=fig_aspect) + fig = _draw_plot(meshes, mesh_property, fig=_fig) + for ax, aspect in zip(fig.axes[: n_axs + 1], ax_aspects): + ax.set_aspect(1.0 / aspect) + elif ax is not None and fig is None: + ax.set_aspect(1.0 / ax_aspects[0]) + _draw_plot(meshes, mesh_property, ax=ax) + elif ax is None and fig is not None: + fig = _draw_plot(meshes, mesh_property, fig=fig) + for ax, aspect in zip(fig.axes[: n_axs + 1], ax_aspects): + ax.set_aspect(1.0 / aspect) + elif ax is not None and fig is not None: + _draw_plot(meshes, mesh_property, fig=fig, ax=ax) + ax.set_aspect(1.0 / ax_aspects[0]) return fig @@ -475,6 +575,8 @@ def plot_diff( mesh1: pv.UnstructuredGrid, mesh2: pv.UnstructuredGrid, mesh_property: Union[Property, str], + fig: Optional[mfigure.Figure] = None, + ax: Optional[plt.Axes] = None, ) -> mfigure.Figure: if isinstance(mesh_property, str): data_shape = mesh1[mesh_property].shape @@ -495,13 +597,15 @@ def plot_diff( bilinear_cmap=True, cmap=mesh_property.cmap if mesh_property.bilinear_cmap else "coolwarm", ) - return plot(diff_mesh, diff_property) + return plot(diff_mesh, diff_property, fig=fig, ax=ax) def plot_limit( mesh_series: MeshSeries, mesh_property: Union[Property, str], limit: Literal["min", "max"], + fig: Optional[mfigure.Figure] = None, + ax: Optional[plt.Axes] = None, ) -> mfigure.Figure: """ Plot the property limits through all timesteps of a MeshSeries. @@ -509,6 +613,8 @@ def plot_limit( :param mesh_series: MeshSeries object containing the data to be plotted :param property: The property field to be evaluated :param limit: Type of limit to be computed + :param fig: Matplotlib Figure object to use for plotting (optional) + :param ax: Matplotlib Axis object to use for plotting (optional) :returns: A matplotlib Figure """ @@ -522,7 +628,7 @@ def plot_limit( limit_property = mesh_property.replace( output_name=limit + " " + mesh_property.output_name ) - return plot(mesh, limit_property) + return plot(mesh, limit_property, fig=fig, ax=ax) def plot_probe( diff --git a/tests/test_meshplotlib.py b/tests/test_meshplotlib.py index 8df8486d581abf2a5fca1d49134637137e843d35..ddc3bf56f95fbfd6f7ac7a4717852310b613a546 100644 --- a/tests/test_meshplotlib.py +++ b/tests/test_meshplotlib.py @@ -4,6 +4,7 @@ import unittest from functools import partial from tempfile import mkstemp +import matplotlib.pyplot as plt import numpy as np from pyvista import examples as pv_examples @@ -14,6 +15,7 @@ from ogstools.meshplotlib import ( plot_limit, plot_probe, setup, + update_font_sizes, ) from ogstools.meshplotlib.animation import animate, save_animation from ogstools.meshplotlib.core import get_ticklabels @@ -133,6 +135,60 @@ class MeshplotlibTest(unittest.TestCase): meshseries = examples.meshseries_CT_2D plot_diff(meshseries.read(0), meshseries.read(1), "Si") + def test_user_defined_ax(self): + """Test creating plot with subfigures and user provided ax""" + meshseries = examples.meshseries_THM_2D + fig, ax = plt.subplots(3, 1, figsize=(40, 30)) + plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0]) + ax[0].set_title(r"$T(\mathrm{t}_{0})$") + plot(meshseries.read(1), presets.temperature, fig=fig, ax=ax[1]) + ax[1].set_title(r"$T(\mathrm{t}_{end})$") + plot_diff( + meshseries.read(0), + meshseries.read(1), + presets.temperature, + fig=fig, + ax=ax[2], + ) + ax[2].set_title(r"$T(\mathrm{t}_{end})$-$T(\mathrm{t}_{0})$") + # fig.suptitle("Test user defined ax") + fig.tight_layout() + + def test_user_defined_ax_diff_vals(self): + """Test creating plot with subfigures and user provided ax with different values plotted""" + meshseries = examples.meshseries_THM_2D + setup.combined_colorbar = False + fig, ax = plt.subplots(2, 1, figsize=(40, 20)) + plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0]) + plot(meshseries.read(1), presets.displacement, fig=fig, ax=ax[1]) + fig.suptitle("Test user defined ax") + fig.tight_layout() + + def test_user_defined_fig(self): + """Test creating plot with subfigures and user provided fig""" + meshseries = examples.meshseries_THM_2D + setup.combined_colorbar = False + fig, ax = plt.subplots(2, 1, figsize=(40, 20)) + plot( + [meshseries.read(0), meshseries.read(1)], + presets.temperature, + fig=fig, + ) + fig.suptitle("Test user defined fig") + + def test_update_font_sizes(self): + """Test creating plot with subfigures and user provided fig""" + meshseries = examples.meshseries_THM_2D + setup.combined_colorbar = False + fig, ax = plt.subplots(2, 1, figsize=(40, 20)) + plot( + [meshseries.read(0), meshseries.read(1)], + presets.temperature, + fig=fig, + ) + fig = update_font_sizes(fig, fontsize=25) + fig.suptitle("Test user defined fig") + def test_limit_plots(self): """Test creation of limit plots.""" meshseries = examples.meshseries_CT_2D