diff --git a/docs/examples/howto_meshplotlib/plot_meshplotlib_2d.py b/docs/examples/howto_meshplotlib/plot_meshplotlib_2d.py index 9be7ab03bbee74806a3a276ccb9f91e8a9ad28f3..006e86e3c2e3f449b88db95cf15ff74b6b0be742 100644 --- a/docs/examples/howto_meshplotlib/plot_meshplotlib_2d.py +++ b/docs/examples/howto_meshplotlib/plot_meshplotlib_2d.py @@ -38,6 +38,14 @@ fig = mpl.plot(mesh, presets.pressure.get_mask()) fig = mpl.plot(mesh, presets.velocity) +# %% +# Let's plot it again, this time log-scaled. +mpl.setup.log_scaled = True +mpl.setup.p_min = -8 +fig = mpl.plot(mesh, presets.velocity) +mpl.setup.p_min = None +mpl.setup.log_scaled = False + # %% # We can also plot components of vector property: diff --git a/ogstools/meshplotlib/core.py b/ogstools/meshplotlib/core.py index 6ee8c65266ff756eb9af4d3a179b4c6de2ffa68e..4e7284265b7f7036748c9075c173714bcab983d0 100644 --- a/ogstools/meshplotlib/core.py +++ b/ogstools/meshplotlib/core.py @@ -1,6 +1,5 @@ """Meshplotlib core utilitites.""" -import types from copy import deepcopy from math import nextafter from typing import Literal, Optional, Union @@ -13,7 +12,6 @@ from matplotlib import colors as mcolors from matplotlib import figure as mfigure from matplotlib import pyplot as plt from matplotlib import ticker as mticker -from matplotlib import transforms as mtransforms from matplotlib.patches import Rectangle as Rect from ogstools.meshlib import MeshSeries @@ -23,7 +21,7 @@ from ogstools.propertylib.unit_registry import u_reg from . import plot_features as pf from . import setup -from .levels import get_levels +from .levels import get_levels, get_median_exponent from .utils import get_style_cycler # TODO: define default data_name for regions in setup @@ -51,7 +49,7 @@ def get_data( property.data_name not in mesh.point_data and property.data_name not in mesh.cell_data ): - msg = f"Property not found in mesh {mesh}." + msg = f"Property {property.data_name} not found in mesh." raise IndexError(msg) if masked: return mesh.ctp(True).threshold(value=[1, 1], scalars=property.mask)[ @@ -103,25 +101,6 @@ def get_cmap_norm( return cmap, norm -# to fix scientific offset position -# https://github.com/matplotlib/matplotlib/issues/4476#issuecomment-105627334 -def fix_scientific_offset_position(axis, func): - axis._update_offset_text_position = types.MethodType(func, axis) - - -def y_update_offset_text_position(self, bboxes, bboxes2): # noqa: ARG001 - x, y = self.offsetText.get_position() - # y in axes coords, x in display coords - self.offsetText.set_transform( - mtransforms.blended_transform_factory( - self.axes.transAxes, mtransforms.IdentityTransform() - ) - ) - top = self.axes.bbox.ymax - y = top + 2 * self.OFFSETTEXTPAD * self.figure.dpi / 72.0 - self.offsetText.set_position((x, y)) - - def add_colorbars( fig: mfigure.Figure, ax: Union[plt.Axes, list[plt.Axes]], @@ -131,60 +110,63 @@ def add_colorbars( labelsize: Optional[float] = None, ) -> None: """Add a colorbar to the matplotlib figure.""" - cmap, norm = get_cmap_norm(levels, property) - cm = mcm.ScalarMappable(norm=norm, cmap=cmap) - categoric = property.categoric or (len(levels) == 2) - if categoric: + ticks = levels + if property.categoric or (len(levels) == 2): bounds = get_level_boundaries(levels) ticks = bounds[:-1] + 0.5 * np.diff(bounds) - else: - ticks = levels + + cmap, norm = get_cmap_norm(levels, property) + cm = mcm.ScalarMappable(norm=norm, cmap=cmap) + cb = fig.colorbar( cm, norm=norm, ax=ax, ticks=ticks, drawedges=True, location="right", spacing="uniform", pad=pad, format="%.3g" # fmt: skip ) - if setup.invert_colorbar: - cb.ax.invert_yaxis() - if property.is_mask(): - cb.ax.add_patch(Rect((0, 0.5), 1, -1, lw=0, fc="none", hatch="/")) - if not categoric and setup.log_scaled: - levels = 10**levels - - unit_str = ( - f" / {property.get_output_unit()}" if property.get_output_unit() else "" - ) + # Formatting the colorbar label + cb_label = property.output_name.replace("_", " ") + " / " + POWER_LIMIT = 3 + if ( + abs(median_exponent := get_median_exponent(ticks)) > POWER_LIMIT + ) and not (np.isclose(ticks[0], ticks[-1]) or setup.log_scaled): + ticks = ticks * 10**-median_exponent + cb_label += f"10$^{{{median_exponent}}}$ " + if unit := property.get_output_unit(): + cb_label += f"{unit}" + if cb_label[-3:] == " / ": + cb_label = cb_label[:-3] + if setup.log_scaled: + cb_label = f"log$_{{10}}$( {cb_label} )" labelsize = ( setup.rcParams_scaled["font.size"] if labelsize is None else labelsize ) - cb.set_label( - property.output_name.replace("_", " ") + unit_str, size=labelsize - ) - cb.ax.tick_params(labelsize=labelsize, direction="out") - mf = mticker.ScalarFormatter(useMathText=True, useOffset=True) - mf.set_scientific(True) - mf.set_powerlimits([-3, 3]) - fix_scientific_offset_position(cb.ax.yaxis, y_update_offset_text_position) - cb.ax.yaxis.set_offset_position("left") - cb.ax.yaxis.set_major_formatter(mf) - - if _q_zero_line(property, levels): - cb.ax.axhline( - y=0, color="w", lw=2 * setup.rcParams_scaled["lines.linewidth"] - ) - if setup.log_scaled: - cb.ax.set_yticklabels(10**ticks) + cb.set_label(cb_label, size=labelsize) + + # formatting the colorbar ticks + cb.ax.tick_params(labelsize=labelsize, direction="out") + cb.ax.yaxis.set_major_formatter(mticker.FormatStrFormatter("%.3g")) + tick_labels = [ + f"{round(tick, POWER_LIMIT):.{POWER_LIMIT}g}" for tick in ticks + ] if property.data_name == "MaterialIDs" and setup.material_names is not None: - region_names = [] - for mat_id in levels: - if mat_id in setup.material_names: - region_names += [setup.material_names[mat_id]] - else: - region_names += [mat_id] - cb.ax.set_yticklabels(region_names) + tick_labels = [ + setup.material_names.get(mat_id, mat_id) for mat_id in levels + ] cb.ax.set_ylabel("") elif property.categoric: - cb.ax.set_yticklabels(levels.astype(int)) + tick_labels = [str(level) for level in levels.astype(int)] + cb.ax.set_yticklabels(tick_labels) + + # miscellaneous + + if property.is_mask(): + cb.ax.add_patch(Rect((0, 0.5), 1, -1, lw=0, fc="none", hatch="/")) + if setup.invert_colorbar: + cb.ax.invert_yaxis() + if _q_zero_line(property, ticks): + cb.ax.axhline( + y=0, color="w", lw=2 * setup.rcParams_scaled["lines.linewidth"] + ) def subplot( diff --git a/ogstools/meshplotlib/levels.py b/ogstools/meshplotlib/levels.py index 261e53503cd012cfa9a01f5411c02c107e17d6a6..3f0e7c13189c2fd78c1b80130002105fe996060e 100644 --- a/ogstools/meshplotlib/levels.py +++ b/ogstools/meshplotlib/levels.py @@ -1,5 +1,7 @@ """Utilities to create nicely spaced levels.""" +from math import nextafter + import numpy as np @@ -21,29 +23,14 @@ def nice_range(lower: float, upper: float, n_ticks: float) -> np.ndarray: The length of the arrays will be close to n_ticks. """ - nice_range = nice_num(upper - lower) - tick_spacing = nice_num(nice_range / (n_ticks - 1.0)) + base = nice_num(upper - lower) + tick_spacing = nice_num(base / (n_ticks - 1.0)) nice_lower = np.ceil(lower / tick_spacing) * tick_spacing nice_upper = np.ceil(upper / tick_spacing) * tick_spacing res = np.arange(nice_lower, nice_upper, tick_spacing) return res[(res > lower) & (res < upper)] -def adaptive_rounding(vals: np.ndarray, precision: int) -> np.ndarray: - """ - Return the given values rounded to significant digits. - - The significant digits are based of the median decimal exponent and the - given precision. - """ - if vals.size == 0: - return vals - log = np.log10(np.abs(vals), out=np.zeros_like(vals), where=(vals != 0.0)) - exponents = np.floor(log).astype(int) - median_exp = int(np.median(exponents)) - return np.stack([np.round(v, precision - median_exp) for v in vals]) - - def get_levels(lower: float, upper: float, n_ticks: int) -> np.ndarray: """ Return an array in the interval [lower, upper] with terminating decimals. @@ -51,10 +38,14 @@ def get_levels(lower: float, upper: float, n_ticks: int) -> np.ndarray: The length of the arrays will be close to n_ticks. At the boundaries the tickspacing may differ from the remaining array. """ - n_digits = 6 - if lower * (1 - 10**-n_digits) <= upper <= lower * (1 + 10**-n_digits): - return lower + np.array([0.0, 10**-n_digits]) + if lower == upper: + return np.asarray([lower, nextafter(lower, np.inf)]) levels = nice_range(lower, upper, n_ticks) - return np.unique( - np.append(np.append(lower, adaptive_rounding(levels, n_digits)), upper) - ) + return np.unique(np.append(np.append(lower, levels), upper)) + + +def get_median_exponent(vals: np.ndarray) -> int: + "Get the median exponent from an array of numbers." + log = np.log10(np.abs(vals), out=np.zeros_like(vals), where=(vals != 0.0)) + exponents = np.floor(log).astype(int) + return int(np.median(exponents)) diff --git a/tests/test_meshplotlib.py b/tests/test_meshplotlib.py index e74feedaaf4cbbadc882df0a933d64e335385808..bab263036aacb4c11f8e9dc7ede58e7d193c701b 100644 --- a/tests/test_meshplotlib.py +++ b/tests/test_meshplotlib.py @@ -21,7 +21,9 @@ from ogstools.meshplotlib.plot_features import plot_on_top from ogstools.meshplotlib.utils import justified_labels from ogstools.propertylib import Scalar, presets -equality = partial(np.testing.assert_allclose, rtol=1e-7, verbose=True) +equality = partial( + np.testing.assert_allclose, rtol=1e-7, atol=1e-100, verbose=True +) class MeshplotlibTest(unittest.TestCase): @@ -42,8 +44,8 @@ class MeshplotlibTest(unittest.TestCase): equality(get_levels(1e-3, 1.2, 5), [1e-3, *np.arange(0.2, 1.4, 0.2)]) equality(get_levels(1e5, 9e6, 20), [1e5, *np.arange(5e5, 9.5e6, 5e5)]) equality(get_levels(1, 40, 20), [1, *range(2, 42, 2)]) - equality(get_levels(0.0, 0.0, 10), [0.0, 1e-6]) - equality(get_levels(1e9, 1e9, 10), [1e9, 1e9 + 1e-6]) + equality(get_levels(0.0, 0.0, 10), [0.0, 0.0]) + equality(get_levels(1e9, 1e9, 10), [1e9, 1e9]) def test_justified_labels(self): points = np.asarray(