From 578cd98f089e37f6eb49cdabdadca5b60f490ec3 Mon Sep 17 00:00:00 2001 From: FZill <florian.zill@ufz.de> Date: Wed, 28 Feb 2024 15:04:21 +0100 Subject: [PATCH] [meshplotlib] refactor of colobar ticklabels --- ogstools/meshplotlib/core.py | 81 ++++++++++++++++++++++++------------ tests/test_meshplotlib.py | 38 +++++++++++++++++ 2 files changed, 93 insertions(+), 26 deletions(-) diff --git a/ogstools/meshplotlib/core.py b/ogstools/meshplotlib/core.py index 08967cb2b..79e70f2ec 100644 --- a/ogstools/meshplotlib/core.py +++ b/ogstools/meshplotlib/core.py @@ -84,6 +84,41 @@ def get_cmap_norm( return cmap, norm +def get_ticklabels(ticks: np.ndarray) -> tuple[list[str], Optional[str]]: + """Get formatted tick labels and optional offset str. + + If all values in ticks are too close together offset notation is used. + """ + fmt = ".3g" + # "+ 0" prevents output of negative zero, i.e. "-0" + tick_labels = [f"{0.0 + round(tick, 12):{fmt}}" for tick in ticks] + if len(tick_labels[1:-1]) != len(set(tick_labels[1:-1])) and len(ticks) > 2: + label_lens = np.asarray([len(f"{tick:{fmt}}") for tick in ticks]) + offset = ticks[np.argmin(label_lens)] + new_fmt = ( + "g" if abs(get_median_exponent(ticks - offset)) <= 2 else ".1e" + ) + tick_labels = [ + f"{0.0 + round(tick, 14):{new_fmt}}" for tick in ticks - offset + ] + return (tick_labels, f"{offset:{fmt}}") + + # pretty hacky but seems to do the job + def _get_label(x, precision): + return f"{0.0 + round(x, precision)}" + + for idx, adj in [(0, 1), (-1, -2)]: + if tick_labels[idx] != tick_labels[adj]: + continue + for precision in range(12): + new_ticklabel = _get_label(ticks[idx], precision) + adj_ticklabel = _get_label(ticks[adj], precision) + if float(new_ticklabel) != float(adj_ticklabel): + tick_labels[idx] = new_ticklabel + break + return tick_labels, None + + def add_colorbars( fig: mfigure.Figure, ax: Union[plt.Axes, list[plt.Axes]], @@ -105,18 +140,17 @@ def add_colorbars( cm, norm=norm, ax=ax, ticks=ticks, drawedges=True, location="right", spacing="uniform", pad=pad # fmt: skip ) - # Formatting the colorbar label - cb_label = mesh_property.output_name.replace("_", " ") + " / " - POWER_LIMIT = 3 - if ( - abs(median_exponent := get_median_exponent(ticks)) > POWER_LIMIT - ) and not (np.isclose(ticks[0], ticks[-1]) or setup.log_scaled): - ticks = ticks * 10**-median_exponent - cb_label += f"10$^{{{median_exponent}}}$ " - if unit := mesh_property.get_output_unit(): - cb_label += f"{unit}" - if cb_label[-3:] == " / ": - cb_label = cb_label[:-3] + # Formatting the colorbar label and ticks + + tick_labels, offset = get_ticklabels(ticks) + cb_label = mesh_property.output_name.replace("_", " ") + if (unit := mesh_property.get_output_unit()) != "": + cb_label += " / " + unit + if offset is not None: + if offset[0] == "-": + cb_label += " + " + offset[1:] + else: + cb_label += " - " + offset if setup.log_scaled: cb_label = f"log$_{{10}}$( {cb_label} )" labelsize = ( @@ -124,17 +158,7 @@ def add_colorbars( ) cb.set_label(cb_label, size=labelsize) - # formatting the colorbar ticks - - cb.ax.tick_params(labelsize=labelsize, direction="out") - # "+ 0" prevents output of negative zero, i.e. "-0" - tick_labels = [ - f"{round(tick, POWER_LIMIT) + 0:.{POWER_LIMIT}g}" for tick in ticks - ] - if tick_labels[0] == tick_labels[1]: - tick_labels[0] = f"{ticks[0]:.2e}" - if tick_labels[-2] == tick_labels[-1]: - tick_labels[-1] = f"{ticks[-1]:.1e}" + # special formatting for MaterialIDs if ( mesh_property.data_name == "MaterialIDs" and setup.material_names is not None @@ -145,6 +169,7 @@ def add_colorbars( cb.ax.set_ylabel("") elif mesh_property.categoric: tick_labels = [str(level) for level in levels.astype(int)] + cb.ax.tick_params(labelsize=labelsize, direction="out") cb.ax.set_yticklabels(tick_labels) # miscellaneous @@ -453,14 +478,18 @@ def plot_diff( diff_mesh = deepcopy(mesh1) diff_mesh[mesh_property.data_name] -= mesh2[mesh_property.data_name] data_property = mesh_property.replace(output_unit=mesh_property.data_unit) - diff_unit = ( - data_property(1, strip_unit=False) - data_property(1, strip_unit=False) - ).units + diff_unit = str( + ( + data_property(1, strip_unit=False) + - data_property(1, strip_unit=False) + ).units + ) diff_property = mesh_property.replace( data_unit=diff_unit, output_unit=diff_unit, output_name=mesh_property.output_name + " difference", bilinear_cmap=True, + cmap=mesh_property.cmap if mesh_property.bilinear_cmap else "coolwarm", ) return plot(diff_mesh, diff_property) diff --git a/tests/test_meshplotlib.py b/tests/test_meshplotlib.py index bab263036..4822fd1ee 100644 --- a/tests/test_meshplotlib.py +++ b/tests/test_meshplotlib.py @@ -16,6 +16,7 @@ from ogstools.meshplotlib import ( setup, ) from ogstools.meshplotlib.animation import animate, save_animation +from ogstools.meshplotlib.core import get_ticklabels from ogstools.meshplotlib.levels import get_levels from ogstools.meshplotlib.plot_features import plot_on_top from ogstools.meshplotlib.utils import justified_labels @@ -47,6 +48,43 @@ class MeshplotlibTest(unittest.TestCase): equality(get_levels(0.0, 0.0, 10), [0.0, 0.0]) equality(get_levels(1e9, 1e9, 10), [1e9, 1e9]) + def test_ticklabels(self): + def compare(array: np.ndarray, ref_labels: list[str], ref_offset=None): + labels, offset = get_ticklabels(np.asarray(array)) + self.assertTrue(np.all(labels == ref_labels)) + self.assertEqual(offset, ref_offset) + + compare([1, 4, 7, 10], ["1", "4", "7", "10"]) + compare([1, 4, 7, 10, 10.01], ["1", "4", "7", "10", "10.01"]) + compare([1, 4, 7, 10, 10.001], ["1", "4", "7", "10", "10.001"]) + compare([100, 150, 200, 200.1], ["100", "150", "200", "200.1"]) + compare( + [-0.00012345, 1e-15, 1, 2, 2.000012345], + ["-0.000123", "0", "1", "2", "2.00001"], + ) + compare( + [0.99987655, 1, 2, 2.000012345], ["0.9999", "1", "2", "2.00001"] + ) + compare( + [100, 100.004, 100.008, 100.012], + ["0", "0.004", "0.008", "0.012"], + "100", + ) + compare( + [100, 100.0004, 100.0008, 100.0012], + ["0.0e+00", "4.0e-04", "8.0e-04", "1.2e-03"], + "100", + ) + compare( + [110000.0, 140000.0, 1700000.0, 1900010.0], + ["1.1e+05", "1.4e+05", "1.7e+06", "1.9e+06"], + ) + compare( + [1000000.0, 1000004.0, 1000008.0, 1000012.0], + ["0", "4", "8", "12"], + "1e+06", + ) + def test_justified_labels(self): points = np.asarray( [ -- GitLab