Skip to content
Snippets Groups Projects
Commit 3f8431ec authored by Florian Zill's avatar Florian Zill
Browse files

[meshplotlib] refactored levels and colorbar

parent b8bfc342
No related branches found
No related tags found
No related merge requests found
......@@ -38,6 +38,14 @@ fig = mpl.plot(mesh, presets.pressure.get_mask())
fig = mpl.plot(mesh, presets.velocity)
# %%
# Let's plot it again, this time log-scaled.
mpl.setup.log_scaled = True
mpl.setup.p_min = -8
fig = mpl.plot(mesh, presets.velocity)
mpl.setup.p_min = None
mpl.setup.log_scaled = False
# %%
# We can also plot components of vector property:
......
"""Meshplotlib core utilitites."""
import types
from copy import deepcopy
from math import nextafter
from typing import Literal, Optional, Union
......@@ -13,7 +12,6 @@ from matplotlib import colors as mcolors
from matplotlib import figure as mfigure
from matplotlib import pyplot as plt
from matplotlib import ticker as mticker
from matplotlib import transforms as mtransforms
from matplotlib.patches import Rectangle as Rect
from ogstools.meshlib import MeshSeries
......@@ -23,7 +21,7 @@ from ogstools.propertylib.unit_registry import u_reg
from . import plot_features as pf
from . import setup
from .levels import get_levels
from .levels import get_levels, get_median_exponent
from .utils import get_style_cycler
# TODO: define default data_name for regions in setup
......@@ -51,7 +49,7 @@ def get_data(
property.data_name not in mesh.point_data
and property.data_name not in mesh.cell_data
):
msg = f"Property not found in mesh {mesh}."
msg = f"Property {property.data_name} not found in mesh."
raise IndexError(msg)
if masked:
return mesh.ctp(True).threshold(value=[1, 1], scalars=property.mask)[
......@@ -103,25 +101,6 @@ def get_cmap_norm(
return cmap, norm
# to fix scientific offset position
# https://github.com/matplotlib/matplotlib/issues/4476#issuecomment-105627334
def fix_scientific_offset_position(axis, func):
axis._update_offset_text_position = types.MethodType(func, axis)
def y_update_offset_text_position(self, bboxes, bboxes2): # noqa: ARG001
x, y = self.offsetText.get_position()
# y in axes coords, x in display coords
self.offsetText.set_transform(
mtransforms.blended_transform_factory(
self.axes.transAxes, mtransforms.IdentityTransform()
)
)
top = self.axes.bbox.ymax
y = top + 2 * self.OFFSETTEXTPAD * self.figure.dpi / 72.0
self.offsetText.set_position((x, y))
def add_colorbars(
fig: mfigure.Figure,
ax: Union[plt.Axes, list[plt.Axes]],
......@@ -131,60 +110,63 @@ def add_colorbars(
labelsize: Optional[float] = None,
) -> None:
"""Add a colorbar to the matplotlib figure."""
cmap, norm = get_cmap_norm(levels, property)
cm = mcm.ScalarMappable(norm=norm, cmap=cmap)
categoric = property.categoric or (len(levels) == 2)
if categoric:
ticks = levels
if property.categoric or (len(levels) == 2):
bounds = get_level_boundaries(levels)
ticks = bounds[:-1] + 0.5 * np.diff(bounds)
else:
ticks = levels
cmap, norm = get_cmap_norm(levels, property)
cm = mcm.ScalarMappable(norm=norm, cmap=cmap)
cb = fig.colorbar(
cm, norm=norm, ax=ax, ticks=ticks, drawedges=True, location="right",
spacing="uniform", pad=pad, format="%.3g" # fmt: skip
)
if setup.invert_colorbar:
cb.ax.invert_yaxis()
if property.is_mask():
cb.ax.add_patch(Rect((0, 0.5), 1, -1, lw=0, fc="none", hatch="/"))
if not categoric and setup.log_scaled:
levels = 10**levels
unit_str = (
f" / {property.get_output_unit()}" if property.get_output_unit() else ""
)
# Formatting the colorbar label
cb_label = property.output_name.replace("_", " ") + " / "
POWER_LIMIT = 3
if (
abs(median_exponent := get_median_exponent(ticks)) > POWER_LIMIT
) and not (np.isclose(ticks[0], ticks[-1]) or setup.log_scaled):
ticks = ticks * 10**-median_exponent
cb_label += f"10$^{{{median_exponent}}}$ "
if unit := property.get_output_unit():
cb_label += f"{unit}"
if cb_label[-3:] == " / ":
cb_label = cb_label[:-3]
if setup.log_scaled:
cb_label = f"log$_{{10}}$( {cb_label} )"
labelsize = (
setup.rcParams_scaled["font.size"] if labelsize is None else labelsize
)
cb.set_label(
property.output_name.replace("_", " ") + unit_str, size=labelsize
)
cb.ax.tick_params(labelsize=labelsize, direction="out")
mf = mticker.ScalarFormatter(useMathText=True, useOffset=True)
mf.set_scientific(True)
mf.set_powerlimits([-3, 3])
fix_scientific_offset_position(cb.ax.yaxis, y_update_offset_text_position)
cb.ax.yaxis.set_offset_position("left")
cb.ax.yaxis.set_major_formatter(mf)
if _q_zero_line(property, levels):
cb.ax.axhline(
y=0, color="w", lw=2 * setup.rcParams_scaled["lines.linewidth"]
)
if setup.log_scaled:
cb.ax.set_yticklabels(10**ticks)
cb.set_label(cb_label, size=labelsize)
# formatting the colorbar ticks
cb.ax.tick_params(labelsize=labelsize, direction="out")
cb.ax.yaxis.set_major_formatter(mticker.FormatStrFormatter("%.3g"))
tick_labels = [
f"{round(tick, POWER_LIMIT):.{POWER_LIMIT}g}" for tick in ticks
]
if property.data_name == "MaterialIDs" and setup.material_names is not None:
region_names = []
for mat_id in levels:
if mat_id in setup.material_names:
region_names += [setup.material_names[mat_id]]
else:
region_names += [mat_id]
cb.ax.set_yticklabels(region_names)
tick_labels = [
setup.material_names.get(mat_id, mat_id) for mat_id in levels
]
cb.ax.set_ylabel("")
elif property.categoric:
cb.ax.set_yticklabels(levels.astype(int))
tick_labels = [str(level) for level in levels.astype(int)]
cb.ax.set_yticklabels(tick_labels)
# miscellaneous
if property.is_mask():
cb.ax.add_patch(Rect((0, 0.5), 1, -1, lw=0, fc="none", hatch="/"))
if setup.invert_colorbar:
cb.ax.invert_yaxis()
if _q_zero_line(property, ticks):
cb.ax.axhline(
y=0, color="w", lw=2 * setup.rcParams_scaled["lines.linewidth"]
)
def subplot(
......
"""Utilities to create nicely spaced levels."""
from math import nextafter
import numpy as np
......@@ -21,29 +23,14 @@ def nice_range(lower: float, upper: float, n_ticks: float) -> np.ndarray:
The length of the arrays will be close to n_ticks.
"""
nice_range = nice_num(upper - lower)
tick_spacing = nice_num(nice_range / (n_ticks - 1.0))
base = nice_num(upper - lower)
tick_spacing = nice_num(base / (n_ticks - 1.0))
nice_lower = np.ceil(lower / tick_spacing) * tick_spacing
nice_upper = np.ceil(upper / tick_spacing) * tick_spacing
res = np.arange(nice_lower, nice_upper, tick_spacing)
return res[(res > lower) & (res < upper)]
def adaptive_rounding(vals: np.ndarray, precision: int) -> np.ndarray:
"""
Return the given values rounded to significant digits.
The significant digits are based of the median decimal exponent and the
given precision.
"""
if vals.size == 0:
return vals
log = np.log10(np.abs(vals), out=np.zeros_like(vals), where=(vals != 0.0))
exponents = np.floor(log).astype(int)
median_exp = int(np.median(exponents))
return np.stack([np.round(v, precision - median_exp) for v in vals])
def get_levels(lower: float, upper: float, n_ticks: int) -> np.ndarray:
"""
Return an array in the interval [lower, upper] with terminating decimals.
......@@ -51,10 +38,14 @@ def get_levels(lower: float, upper: float, n_ticks: int) -> np.ndarray:
The length of the arrays will be close to n_ticks.
At the boundaries the tickspacing may differ from the remaining array.
"""
n_digits = 6
if lower * (1 - 10**-n_digits) <= upper <= lower * (1 + 10**-n_digits):
return lower + np.array([0.0, 10**-n_digits])
if lower == upper:
return np.asarray([lower, nextafter(lower, np.inf)])
levels = nice_range(lower, upper, n_ticks)
return np.unique(
np.append(np.append(lower, adaptive_rounding(levels, n_digits)), upper)
)
return np.unique(np.append(np.append(lower, levels), upper))
def get_median_exponent(vals: np.ndarray) -> int:
"Get the median exponent from an array of numbers."
log = np.log10(np.abs(vals), out=np.zeros_like(vals), where=(vals != 0.0))
exponents = np.floor(log).astype(int)
return int(np.median(exponents))
......@@ -21,7 +21,9 @@ from ogstools.meshplotlib.plot_features import plot_on_top
from ogstools.meshplotlib.utils import justified_labels
from ogstools.propertylib import Scalar, presets
equality = partial(np.testing.assert_allclose, rtol=1e-7, verbose=True)
equality = partial(
np.testing.assert_allclose, rtol=1e-7, atol=1e-100, verbose=True
)
class MeshplotlibTest(unittest.TestCase):
......@@ -42,8 +44,8 @@ class MeshplotlibTest(unittest.TestCase):
equality(get_levels(1e-3, 1.2, 5), [1e-3, *np.arange(0.2, 1.4, 0.2)])
equality(get_levels(1e5, 9e6, 20), [1e5, *np.arange(5e5, 9.5e6, 5e5)])
equality(get_levels(1, 40, 20), [1, *range(2, 42, 2)])
equality(get_levels(0.0, 0.0, 10), [0.0, 1e-6])
equality(get_levels(1e9, 1e9, 10), [1e9, 1e9 + 1e-6])
equality(get_levels(0.0, 0.0, 10), [0.0, 0.0])
equality(get_levels(1e9, 1e9, 10), [1e9, 1e9])
def test_justified_labels(self):
points = np.asarray(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment