Skip to content
Snippets Groups Projects
Commit 3f8431ec authored by Florian Zill's avatar Florian Zill
Browse files

[meshplotlib] refactored levels and colorbar

parent b8bfc342
No related branches found
No related tags found
1 merge request!119Stress Analysis and Refactoring
...@@ -38,6 +38,14 @@ fig = mpl.plot(mesh, presets.pressure.get_mask()) ...@@ -38,6 +38,14 @@ fig = mpl.plot(mesh, presets.pressure.get_mask())
fig = mpl.plot(mesh, presets.velocity) fig = mpl.plot(mesh, presets.velocity)
# %%
# Let's plot it again, this time log-scaled.
mpl.setup.log_scaled = True
mpl.setup.p_min = -8
fig = mpl.plot(mesh, presets.velocity)
mpl.setup.p_min = None
mpl.setup.log_scaled = False
# %% # %%
# We can also plot components of vector property: # We can also plot components of vector property:
......
"""Meshplotlib core utilitites.""" """Meshplotlib core utilitites."""
import types
from copy import deepcopy from copy import deepcopy
from math import nextafter from math import nextafter
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
...@@ -13,7 +12,6 @@ from matplotlib import colors as mcolors ...@@ -13,7 +12,6 @@ from matplotlib import colors as mcolors
from matplotlib import figure as mfigure from matplotlib import figure as mfigure
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from matplotlib import ticker as mticker from matplotlib import ticker as mticker
from matplotlib import transforms as mtransforms
from matplotlib.patches import Rectangle as Rect from matplotlib.patches import Rectangle as Rect
from ogstools.meshlib import MeshSeries from ogstools.meshlib import MeshSeries
...@@ -23,7 +21,7 @@ from ogstools.propertylib.unit_registry import u_reg ...@@ -23,7 +21,7 @@ from ogstools.propertylib.unit_registry import u_reg
from . import plot_features as pf from . import plot_features as pf
from . import setup from . import setup
from .levels import get_levels from .levels import get_levels, get_median_exponent
from .utils import get_style_cycler from .utils import get_style_cycler
# TODO: define default data_name for regions in setup # TODO: define default data_name for regions in setup
...@@ -51,7 +49,7 @@ def get_data( ...@@ -51,7 +49,7 @@ def get_data(
property.data_name not in mesh.point_data property.data_name not in mesh.point_data
and property.data_name not in mesh.cell_data and property.data_name not in mesh.cell_data
): ):
msg = f"Property not found in mesh {mesh}." msg = f"Property {property.data_name} not found in mesh."
raise IndexError(msg) raise IndexError(msg)
if masked: if masked:
return mesh.ctp(True).threshold(value=[1, 1], scalars=property.mask)[ return mesh.ctp(True).threshold(value=[1, 1], scalars=property.mask)[
...@@ -103,25 +101,6 @@ def get_cmap_norm( ...@@ -103,25 +101,6 @@ def get_cmap_norm(
return cmap, norm return cmap, norm
# to fix scientific offset position
# https://github.com/matplotlib/matplotlib/issues/4476#issuecomment-105627334
def fix_scientific_offset_position(axis, func):
axis._update_offset_text_position = types.MethodType(func, axis)
def y_update_offset_text_position(self, bboxes, bboxes2): # noqa: ARG001
x, y = self.offsetText.get_position()
# y in axes coords, x in display coords
self.offsetText.set_transform(
mtransforms.blended_transform_factory(
self.axes.transAxes, mtransforms.IdentityTransform()
)
)
top = self.axes.bbox.ymax
y = top + 2 * self.OFFSETTEXTPAD * self.figure.dpi / 72.0
self.offsetText.set_position((x, y))
def add_colorbars( def add_colorbars(
fig: mfigure.Figure, fig: mfigure.Figure,
ax: Union[plt.Axes, list[plt.Axes]], ax: Union[plt.Axes, list[plt.Axes]],
...@@ -131,60 +110,63 @@ def add_colorbars( ...@@ -131,60 +110,63 @@ def add_colorbars(
labelsize: Optional[float] = None, labelsize: Optional[float] = None,
) -> None: ) -> None:
"""Add a colorbar to the matplotlib figure.""" """Add a colorbar to the matplotlib figure."""
cmap, norm = get_cmap_norm(levels, property) ticks = levels
cm = mcm.ScalarMappable(norm=norm, cmap=cmap) if property.categoric or (len(levels) == 2):
categoric = property.categoric or (len(levels) == 2)
if categoric:
bounds = get_level_boundaries(levels) bounds = get_level_boundaries(levels)
ticks = bounds[:-1] + 0.5 * np.diff(bounds) ticks = bounds[:-1] + 0.5 * np.diff(bounds)
else:
ticks = levels cmap, norm = get_cmap_norm(levels, property)
cm = mcm.ScalarMappable(norm=norm, cmap=cmap)
cb = fig.colorbar( cb = fig.colorbar(
cm, norm=norm, ax=ax, ticks=ticks, drawedges=True, location="right", cm, norm=norm, ax=ax, ticks=ticks, drawedges=True, location="right",
spacing="uniform", pad=pad, format="%.3g" # fmt: skip spacing="uniform", pad=pad, format="%.3g" # fmt: skip
) )
if setup.invert_colorbar: # Formatting the colorbar label
cb.ax.invert_yaxis() cb_label = property.output_name.replace("_", " ") + " / "
if property.is_mask(): POWER_LIMIT = 3
cb.ax.add_patch(Rect((0, 0.5), 1, -1, lw=0, fc="none", hatch="/")) if (
if not categoric and setup.log_scaled: abs(median_exponent := get_median_exponent(ticks)) > POWER_LIMIT
levels = 10**levels ) and not (np.isclose(ticks[0], ticks[-1]) or setup.log_scaled):
ticks = ticks * 10**-median_exponent
unit_str = ( cb_label += f"10$^{{{median_exponent}}}$ "
f" / {property.get_output_unit()}" if property.get_output_unit() else "" if unit := property.get_output_unit():
) cb_label += f"{unit}"
if cb_label[-3:] == " / ":
cb_label = cb_label[:-3]
if setup.log_scaled:
cb_label = f"log$_{{10}}$( {cb_label} )"
labelsize = ( labelsize = (
setup.rcParams_scaled["font.size"] if labelsize is None else labelsize setup.rcParams_scaled["font.size"] if labelsize is None else labelsize
) )
cb.set_label( cb.set_label(cb_label, size=labelsize)
property.output_name.replace("_", " ") + unit_str, size=labelsize
) # formatting the colorbar ticks
cb.ax.tick_params(labelsize=labelsize, direction="out")
mf = mticker.ScalarFormatter(useMathText=True, useOffset=True)
mf.set_scientific(True)
mf.set_powerlimits([-3, 3])
fix_scientific_offset_position(cb.ax.yaxis, y_update_offset_text_position)
cb.ax.yaxis.set_offset_position("left")
cb.ax.yaxis.set_major_formatter(mf)
if _q_zero_line(property, levels):
cb.ax.axhline(
y=0, color="w", lw=2 * setup.rcParams_scaled["lines.linewidth"]
)
if setup.log_scaled:
cb.ax.set_yticklabels(10**ticks)
cb.ax.tick_params(labelsize=labelsize, direction="out")
cb.ax.yaxis.set_major_formatter(mticker.FormatStrFormatter("%.3g"))
tick_labels = [
f"{round(tick, POWER_LIMIT):.{POWER_LIMIT}g}" for tick in ticks
]
if property.data_name == "MaterialIDs" and setup.material_names is not None: if property.data_name == "MaterialIDs" and setup.material_names is not None:
region_names = [] tick_labels = [
for mat_id in levels: setup.material_names.get(mat_id, mat_id) for mat_id in levels
if mat_id in setup.material_names: ]
region_names += [setup.material_names[mat_id]]
else:
region_names += [mat_id]
cb.ax.set_yticklabels(region_names)
cb.ax.set_ylabel("") cb.ax.set_ylabel("")
elif property.categoric: elif property.categoric:
cb.ax.set_yticklabels(levels.astype(int)) tick_labels = [str(level) for level in levels.astype(int)]
cb.ax.set_yticklabels(tick_labels)
# miscellaneous
if property.is_mask():
cb.ax.add_patch(Rect((0, 0.5), 1, -1, lw=0, fc="none", hatch="/"))
if setup.invert_colorbar:
cb.ax.invert_yaxis()
if _q_zero_line(property, ticks):
cb.ax.axhline(
y=0, color="w", lw=2 * setup.rcParams_scaled["lines.linewidth"]
)
def subplot( def subplot(
......
"""Utilities to create nicely spaced levels.""" """Utilities to create nicely spaced levels."""
from math import nextafter
import numpy as np import numpy as np
...@@ -21,29 +23,14 @@ def nice_range(lower: float, upper: float, n_ticks: float) -> np.ndarray: ...@@ -21,29 +23,14 @@ def nice_range(lower: float, upper: float, n_ticks: float) -> np.ndarray:
The length of the arrays will be close to n_ticks. The length of the arrays will be close to n_ticks.
""" """
nice_range = nice_num(upper - lower) base = nice_num(upper - lower)
tick_spacing = nice_num(nice_range / (n_ticks - 1.0)) tick_spacing = nice_num(base / (n_ticks - 1.0))
nice_lower = np.ceil(lower / tick_spacing) * tick_spacing nice_lower = np.ceil(lower / tick_spacing) * tick_spacing
nice_upper = np.ceil(upper / tick_spacing) * tick_spacing nice_upper = np.ceil(upper / tick_spacing) * tick_spacing
res = np.arange(nice_lower, nice_upper, tick_spacing) res = np.arange(nice_lower, nice_upper, tick_spacing)
return res[(res > lower) & (res < upper)] return res[(res > lower) & (res < upper)]
def adaptive_rounding(vals: np.ndarray, precision: int) -> np.ndarray:
"""
Return the given values rounded to significant digits.
The significant digits are based of the median decimal exponent and the
given precision.
"""
if vals.size == 0:
return vals
log = np.log10(np.abs(vals), out=np.zeros_like(vals), where=(vals != 0.0))
exponents = np.floor(log).astype(int)
median_exp = int(np.median(exponents))
return np.stack([np.round(v, precision - median_exp) for v in vals])
def get_levels(lower: float, upper: float, n_ticks: int) -> np.ndarray: def get_levels(lower: float, upper: float, n_ticks: int) -> np.ndarray:
""" """
Return an array in the interval [lower, upper] with terminating decimals. Return an array in the interval [lower, upper] with terminating decimals.
...@@ -51,10 +38,14 @@ def get_levels(lower: float, upper: float, n_ticks: int) -> np.ndarray: ...@@ -51,10 +38,14 @@ def get_levels(lower: float, upper: float, n_ticks: int) -> np.ndarray:
The length of the arrays will be close to n_ticks. The length of the arrays will be close to n_ticks.
At the boundaries the tickspacing may differ from the remaining array. At the boundaries the tickspacing may differ from the remaining array.
""" """
n_digits = 6 if lower == upper:
if lower * (1 - 10**-n_digits) <= upper <= lower * (1 + 10**-n_digits): return np.asarray([lower, nextafter(lower, np.inf)])
return lower + np.array([0.0, 10**-n_digits])
levels = nice_range(lower, upper, n_ticks) levels = nice_range(lower, upper, n_ticks)
return np.unique( return np.unique(np.append(np.append(lower, levels), upper))
np.append(np.append(lower, adaptive_rounding(levels, n_digits)), upper)
)
def get_median_exponent(vals: np.ndarray) -> int:
"Get the median exponent from an array of numbers."
log = np.log10(np.abs(vals), out=np.zeros_like(vals), where=(vals != 0.0))
exponents = np.floor(log).astype(int)
return int(np.median(exponents))
...@@ -21,7 +21,9 @@ from ogstools.meshplotlib.plot_features import plot_on_top ...@@ -21,7 +21,9 @@ from ogstools.meshplotlib.plot_features import plot_on_top
from ogstools.meshplotlib.utils import justified_labels from ogstools.meshplotlib.utils import justified_labels
from ogstools.propertylib import Scalar, presets from ogstools.propertylib import Scalar, presets
equality = partial(np.testing.assert_allclose, rtol=1e-7, verbose=True) equality = partial(
np.testing.assert_allclose, rtol=1e-7, atol=1e-100, verbose=True
)
class MeshplotlibTest(unittest.TestCase): class MeshplotlibTest(unittest.TestCase):
...@@ -42,8 +44,8 @@ class MeshplotlibTest(unittest.TestCase): ...@@ -42,8 +44,8 @@ class MeshplotlibTest(unittest.TestCase):
equality(get_levels(1e-3, 1.2, 5), [1e-3, *np.arange(0.2, 1.4, 0.2)]) equality(get_levels(1e-3, 1.2, 5), [1e-3, *np.arange(0.2, 1.4, 0.2)])
equality(get_levels(1e5, 9e6, 20), [1e5, *np.arange(5e5, 9.5e6, 5e5)]) equality(get_levels(1e5, 9e6, 20), [1e5, *np.arange(5e5, 9.5e6, 5e5)])
equality(get_levels(1, 40, 20), [1, *range(2, 42, 2)]) equality(get_levels(1, 40, 20), [1, *range(2, 42, 2)])
equality(get_levels(0.0, 0.0, 10), [0.0, 1e-6]) equality(get_levels(0.0, 0.0, 10), [0.0, 0.0])
equality(get_levels(1e9, 1e9, 10), [1e9, 1e9 + 1e-6]) equality(get_levels(1e9, 1e9, 10), [1e9, 1e9])
def test_justified_labels(self): def test_justified_labels(self):
points = np.asarray( points = np.asarray(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment