Skip to content
Snippets Groups Projects
Commit 66eed66d authored by Florian Zill's avatar Florian Zill
Browse files

Merge branch 'meshplotlib_ax_handling' into 'main'

Enable use of external fig and ax in plot functions and plotting different variables within one figure

See merge request ogs/tools/ogstools!122
parents 2fe5cb58 82fbdf4a
No related branches found
No related tags found
No related merge requests found
"""
Plotting different process variables on already existing Matplotlib figures / axes
==================================================================================
.. sectionauthor:: Feliks Kiszkurno (Helmholtz Centre for Environmental Research GmbH - UFZ)
For this example we load a 2D meshseries dataset from within the ``meshplotlib`` examples.
This tutorial covers using meshplotlib to plot meshseries data using Matplotlib objects for Figure and / or Axis. This is useful if different plotting functions from Meshplotlib are to be used on different subplots within the same figure
"""
# %%
# Import Python packages, change some settings and load example data set
import matplotlib.pyplot as plt
import ogstools.meshplotlib as mpl
from ogstools.meshplotlib import examples, plot, plot_diff
from ogstools.propertylib import presets
plt.rcParams.update({"font.size": 32})
mpl.setup.reset()
mpl.setup.length.output_unit = "km"
meshseries = examples.meshseries_THM_2D
# %%
# It is possible to plot various process parameter in different subplots of the same figure. But each mesh series and parameter pair need a separate call of plot function:
mpl.setup.combined_colorbar = False
fig, ax = plt.subplots(2, 1, figsize=(15, 15))
plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0])
plot(meshseries.read(1), presets.displacement, fig=fig, ax=ax[1])
fig.suptitle("Compare temperature and displacement")
fig.tight_layout()
# %%
# The same way difference between process variables at different time steps can be plotted. Color bars can be drawn automatically, if user provides both Figure and Axes objects:
fig, ax = plt.subplots(3, 1, figsize=(20, 30))
plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0])
ax[0].set_title(r"$T(\mathrm{t}_{0})$")
plot(meshseries.read(1), presets.temperature, fig=fig, ax=ax[1])
ax[1].set_title(r"$T(\mathrm{t}_{end})$")
plot_diff(
meshseries.read(0),
meshseries.read(1),
presets.temperature,
fig=fig,
ax=ax[2],
)
ax[2].set_title(r"$T(\mathrm{t}_{end})$-$T(\mathrm{t}_{0})$")
fig.suptitle("Plot two time steps and their difference - with colorbars")
fig.tight_layout()
......@@ -3,7 +3,14 @@
from .plot_setup import _setup as setup # noqa: I001: noqa
from .core import plot_diff, plot_limit, plot_probe, plot, subplot
from .core import (
plot_diff,
plot_limit,
plot_probe,
plot,
subplot,
update_font_sizes,
)
from .plot_features import plot_contour, plot_on_top
__all__ = [
......@@ -15,4 +22,5 @@ __all__ = [
"plot_on_top",
"plot_probe",
"subplot",
"update_font_sizes",
]
......@@ -13,7 +13,7 @@ from ogstools.meshlib import MeshSeries
from ogstools.propertylib import Property
from . import setup
from .core import _plot_on_figure, plot
from .core import _draw_plot, plot
def animate(
......@@ -54,7 +54,7 @@ def animate(
mesh = mesh_series.read_interp(i, True)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
fig = _plot_on_figure(fig, mesh, property)
fig = _draw_plot(mesh, property, fig=fig)
_func = partial(animate_func, fig=fig)
......
"""Meshplotlib core utilitites."""
import warnings
from copy import deepcopy
from math import nextafter
from typing import Literal, Optional, Union
......@@ -141,8 +142,14 @@ def add_colorbars(
cm = mcm.ScalarMappable(norm=norm, cmap=cmap)
cb = fig.colorbar(
cm, norm=norm, ax=ax, ticks=ticks, drawedges=True, location="right",
spacing="uniform", pad=pad # fmt: skip
cm,
norm=norm,
ax=ax,
ticks=ticks,
drawedges=True,
location="right",
spacing="uniform",
pad=pad, # fmt: skip
)
# Formatting the colorbar label and ticks
......@@ -300,7 +307,7 @@ def _get_rows_cols(
np.ndarray,
pv.UnstructuredGrid,
pv.MultiBlock,
]
],
) -> tuple[int, ...]:
if isinstance(meshes, np.ndarray):
if meshes.ndim in [1, 2]:
......@@ -375,23 +382,60 @@ def get_combined_levels(
return compute_levels(p_min, p_max, setup.num_levels)
def _plot_on_figure(
fig: mfigure.Figure,
# TODO: Have a look at fig and ax logic and make it more readable
def _draw_plot(
meshes: Union[list[pv.UnstructuredGrid], np.ndarray, pv.UnstructuredGrid],
mesh_property: Property,
fig: Optional[mfigure.Figure] = None,
ax: Optional[plt.Axes] = None,
) -> mfigure.Figure:
"""
Plot the property field of meshes on existing figure.
:param meshes: Singular mesh of 2D numpy array of meshes
:param property: the property field to be visualized on all meshes
:param fig: Matplotlib Figure object to use for plotting (optional)
:param ax: Matplotlib Axis object to use for plotting (optional)
"""
shape = _get_rows_cols(meshes)
np_meshes = np.reshape(meshes, shape)
np_axs = np.reshape(fig.axes, shape)
if fig is not None and ax is not None:
# Fig is given, but plotting on specific ax is intended
# Only one mesh in meshes shall be accepted
if shape != (1, 1):
msg = "You have provided only one Axis object but multiple meshes. Provide only one mesh per Axis object, or provide Figure object instead."
raise ValueError(msg)
np_axs = np.reshape(np.array(ax), (1, 1))
elif fig is not None and ax is not None:
# Fig is given and Ax is given
# Multiple meshes should be accepted
if shape != (1, 1):
msg = "You have provided only one Axis object but multiple meshes. Provide only one mesh per Axis object, or provide Figure object instead."
raise ValueError(msg)
np_axs = np.reshape(ax, (1, 1))
elif fig is not None and ax is None:
# Only Fig is given
# Multiple meshes should be accepted
warnings.warn(
"This is not a good practice. Consider providing both fig and ax instead. This option may lead to unexpected behaviour and may be removed without warning in the future.",
Warning,
stacklevel=4,
)
np_axs = np.reshape(fig.axes, shape)
elif fig is None and ax is not None:
# Only ax is given
# Only one mesh should be accepted
if shape != (1, 1):
msg = "You have provided only one Axis object but multiple meshes. Provide only one mesh per Axis object, or provide Figure object instead."
raise ValueError(msg)
np_axs = np.reshape(np.array(ax), (1, 1))
else:
msg = "Neither Figure nor Axis object was provided."
raise TypeError(msg)
if setup.combined_colorbar:
combined_levels = get_combined_levels(np_meshes, mesh_property)
for i in range(shape[0]):
for j in range(shape[1]):
_levels = (
......@@ -400,25 +444,39 @@ def _plot_on_figure(
else get_combined_levels(np_meshes[i, j], mesh_property)
)
subplot(np_meshes[i, j], mesh_property, np_axs[i, j], _levels)
np_axs[0, 0].set_title(setup.title_center, loc="center", y=1.02)
np_axs[0, 0].set_title(setup.title_left, loc="left", y=1.02)
np_axs[0, 0].set_title(setup.title_right, loc="right", y=1.02)
# make extra space for the upper limit of the colorbar
if setup.layout == "tight":
plt.tight_layout(pad=1.4)
if setup.combined_colorbar:
cb_axs = np.ravel(fig.axes).tolist()
add_colorbars(
fig, cb_axs, mesh_property, combined_levels, pad=0.05 / shape[1]
)
if fig is None:
warnings.warn(
"Cannot plot combined colorbar if Figure object is not provided!",
Warning,
stacklevel=4,
)
else:
cb_axs = np.ravel(fig.axes).tolist()
add_colorbars(
fig, cb_axs, mesh_property, combined_levels, pad=0.05 / shape[1]
)
else:
for i in range(shape[0]):
for j in range(shape[1]):
_levels = get_combined_levels(np_meshes[i, j], mesh_property)
add_colorbars(fig, np_axs[i, j], mesh_property, _levels)
# TODO: restructure this logic
if fig is None:
warnings.warn(
"Figure object is required to plot individual colorbars for Axes objects.",
Warning,
stacklevel=4,
)
else:
for i in range(shape[0]):
for j in range(shape[1]):
_levels = get_combined_levels(
np_meshes[i, j], mesh_property
)
add_colorbars(fig, np_axs[i, j], mesh_property, _levels)
return fig
......@@ -433,11 +491,41 @@ def get_data_aspect(mesh: pv.DataSet) -> float:
return abs(lims[x_id + 1] - lims[x_id]) / abs(lims[y_id + 1] - lims[y_id])
def update_font_sizes(
fig: mfigure.Figure,
fontsize: int = 20,
) -> mfigure.Figure:
"""
Update font sizes of lebels and ticks in all subplots
:param fig: Matplotlib Figure object to use for plotting
:param int: New font size for the labels and ticks (optional)
"""
x_label = f"X / {setup.length.output_unit}"
y_label = f"Y / {setup.length.output_unit}"
for subax in fig.axes:
subax.set_xlabel(x_label, fontsize=fontsize)
subax.set_ylabel(y_label, fontsize=fontsize)
subax_xlim = subax.get_xlim()
subax_ylim = subax.get_ylim()
subax.set_xticks(
subax.get_xticks(), subax.get_xticklabels(), fontsize=fontsize
)
subax.set_yticks(
subax.get_yticks(), subax.get_yticklabels(), fontsize=fontsize
)
subax.set_xlim(subax_xlim)
subax.set_ylim(subax_ylim)
return fig
# TODO: add as arguments: cmap, limits
# TODO: num_levels should be min_levels
def plot(
meshes: Union[list[pv.UnstructuredGrid], np.ndarray, pv.UnstructuredGrid],
mesh_property: Union[Property, str],
fig: Optional[mfigure.Figure] = None,
ax: Optional[plt.Axes] = None,
) -> mfigure.Figure:
"""
Plot the property field of meshes with default settings.
......@@ -447,8 +535,9 @@ def plot(
:param meshes: Singular mesh of 2D numpy array of meshes
:param property: The property field to be visualized on all meshes
:param fig: Matplotlib Figure object to use for plotting (optional)
:param ax: Matplotlib Axis object to use for plotting (optional)
"""
rcParams.update(setup.rcParams_scaled)
shape = _get_rows_cols(meshes)
_meshes = np.reshape(meshes, shape).flatten()
......@@ -463,11 +552,22 @@ def plot(
np.clip(data_aspects, setup.min_ax_aspect, setup.max_ax_aspect)
)
ax_aspects = fig_aspect / data_aspects
_fig = _fig_init(rows=shape[0], cols=shape[1], aspect=fig_aspect)
n_axs = shape[0] * shape[1]
fig = _plot_on_figure(_fig, meshes, mesh_property)
for ax, aspect in zip(fig.axes[: n_axs + 1], ax_aspects):
ax.set_aspect(1.0 / aspect)
if ax is None and fig is None:
_fig = _fig_init(rows=shape[0], cols=shape[1], aspect=fig_aspect)
fig = _draw_plot(meshes, mesh_property, fig=_fig)
for ax, aspect in zip(fig.axes[: n_axs + 1], ax_aspects):
ax.set_aspect(1.0 / aspect)
elif ax is not None and fig is None:
ax.set_aspect(1.0 / ax_aspects[0])
_draw_plot(meshes, mesh_property, ax=ax)
elif ax is None and fig is not None:
fig = _draw_plot(meshes, mesh_property, fig=fig)
for ax, aspect in zip(fig.axes[: n_axs + 1], ax_aspects):
ax.set_aspect(1.0 / aspect)
elif ax is not None and fig is not None:
_draw_plot(meshes, mesh_property, fig=fig, ax=ax)
ax.set_aspect(1.0 / ax_aspects[0])
return fig
......@@ -475,6 +575,8 @@ def plot_diff(
mesh1: pv.UnstructuredGrid,
mesh2: pv.UnstructuredGrid,
mesh_property: Union[Property, str],
fig: Optional[mfigure.Figure] = None,
ax: Optional[plt.Axes] = None,
) -> mfigure.Figure:
if isinstance(mesh_property, str):
data_shape = mesh1[mesh_property].shape
......@@ -495,13 +597,15 @@ def plot_diff(
bilinear_cmap=True,
cmap=mesh_property.cmap if mesh_property.bilinear_cmap else "coolwarm",
)
return plot(diff_mesh, diff_property)
return plot(diff_mesh, diff_property, fig=fig, ax=ax)
def plot_limit(
mesh_series: MeshSeries,
mesh_property: Union[Property, str],
limit: Literal["min", "max"],
fig: Optional[mfigure.Figure] = None,
ax: Optional[plt.Axes] = None,
) -> mfigure.Figure:
"""
Plot the property limits through all timesteps of a MeshSeries.
......@@ -509,6 +613,8 @@ def plot_limit(
:param mesh_series: MeshSeries object containing the data to be plotted
:param property: The property field to be evaluated
:param limit: Type of limit to be computed
:param fig: Matplotlib Figure object to use for plotting (optional)
:param ax: Matplotlib Axis object to use for plotting (optional)
:returns: A matplotlib Figure
"""
......@@ -522,7 +628,7 @@ def plot_limit(
limit_property = mesh_property.replace(
output_name=limit + " " + mesh_property.output_name
)
return plot(mesh, limit_property)
return plot(mesh, limit_property, fig=fig, ax=ax)
def plot_probe(
......
......@@ -4,6 +4,7 @@ import unittest
from functools import partial
from tempfile import mkstemp
import matplotlib.pyplot as plt
import numpy as np
from pyvista import examples as pv_examples
......@@ -14,6 +15,7 @@ from ogstools.meshplotlib import (
plot_limit,
plot_probe,
setup,
update_font_sizes,
)
from ogstools.meshplotlib.animation import animate, save_animation
from ogstools.meshplotlib.core import get_ticklabels
......@@ -133,6 +135,60 @@ class MeshplotlibTest(unittest.TestCase):
meshseries = examples.meshseries_CT_2D
plot_diff(meshseries.read(0), meshseries.read(1), "Si")
def test_user_defined_ax(self):
"""Test creating plot with subfigures and user provided ax"""
meshseries = examples.meshseries_THM_2D
fig, ax = plt.subplots(3, 1, figsize=(40, 30))
plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0])
ax[0].set_title(r"$T(\mathrm{t}_{0})$")
plot(meshseries.read(1), presets.temperature, fig=fig, ax=ax[1])
ax[1].set_title(r"$T(\mathrm{t}_{end})$")
plot_diff(
meshseries.read(0),
meshseries.read(1),
presets.temperature,
fig=fig,
ax=ax[2],
)
ax[2].set_title(r"$T(\mathrm{t}_{end})$-$T(\mathrm{t}_{0})$")
# fig.suptitle("Test user defined ax")
fig.tight_layout()
def test_user_defined_ax_diff_vals(self):
"""Test creating plot with subfigures and user provided ax with different values plotted"""
meshseries = examples.meshseries_THM_2D
setup.combined_colorbar = False
fig, ax = plt.subplots(2, 1, figsize=(40, 20))
plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0])
plot(meshseries.read(1), presets.displacement, fig=fig, ax=ax[1])
fig.suptitle("Test user defined ax")
fig.tight_layout()
def test_user_defined_fig(self):
"""Test creating plot with subfigures and user provided fig"""
meshseries = examples.meshseries_THM_2D
setup.combined_colorbar = False
fig, ax = plt.subplots(2, 1, figsize=(40, 20))
plot(
[meshseries.read(0), meshseries.read(1)],
presets.temperature,
fig=fig,
)
fig.suptitle("Test user defined fig")
def test_update_font_sizes(self):
"""Test creating plot with subfigures and user provided fig"""
meshseries = examples.meshseries_THM_2D
setup.combined_colorbar = False
fig, ax = plt.subplots(2, 1, figsize=(40, 20))
plot(
[meshseries.read(0), meshseries.read(1)],
presets.temperature,
fig=fig,
)
fig = update_font_sizes(fig, fontsize=25)
fig.suptitle("Test user defined fig")
def test_limit_plots(self):
"""Test creation of limit plots."""
meshseries = examples.meshseries_CT_2D
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment