From 578cd98f089e37f6eb49cdabdadca5b60f490ec3 Mon Sep 17 00:00:00 2001
From: FZill <florian.zill@ufz.de>
Date: Wed, 28 Feb 2024 15:04:21 +0100
Subject: [PATCH] [meshplotlib] refactor of colobar ticklabels

---
 ogstools/meshplotlib/core.py | 81 ++++++++++++++++++++++++------------
 tests/test_meshplotlib.py    | 38 +++++++++++++++++
 2 files changed, 93 insertions(+), 26 deletions(-)

diff --git a/ogstools/meshplotlib/core.py b/ogstools/meshplotlib/core.py
index 08967cb2b..79e70f2ec 100644
--- a/ogstools/meshplotlib/core.py
+++ b/ogstools/meshplotlib/core.py
@@ -84,6 +84,41 @@ def get_cmap_norm(
     return cmap, norm
 
 
+def get_ticklabels(ticks: np.ndarray) -> tuple[list[str], Optional[str]]:
+    """Get formatted tick labels and optional offset str.
+
+    If all values in ticks are too close together offset notation is used.
+    """
+    fmt = ".3g"
+    # "+ 0" prevents output of negative zero, i.e. "-0"
+    tick_labels = [f"{0.0 + round(tick, 12):{fmt}}" for tick in ticks]
+    if len(tick_labels[1:-1]) != len(set(tick_labels[1:-1])) and len(ticks) > 2:
+        label_lens = np.asarray([len(f"{tick:{fmt}}") for tick in ticks])
+        offset = ticks[np.argmin(label_lens)]
+        new_fmt = (
+            "g" if abs(get_median_exponent(ticks - offset)) <= 2 else ".1e"
+        )
+        tick_labels = [
+            f"{0.0 +  round(tick, 14):{new_fmt}}" for tick in ticks - offset
+        ]
+        return (tick_labels, f"{offset:{fmt}}")
+
+    # pretty hacky but seems to do the job
+    def _get_label(x, precision):
+        return f"{0.0 + round(x, precision)}"
+
+    for idx, adj in [(0, 1), (-1, -2)]:
+        if tick_labels[idx] != tick_labels[adj]:
+            continue
+        for precision in range(12):
+            new_ticklabel = _get_label(ticks[idx], precision)
+            adj_ticklabel = _get_label(ticks[adj], precision)
+            if float(new_ticklabel) != float(adj_ticklabel):
+                tick_labels[idx] = new_ticklabel
+                break
+    return tick_labels, None
+
+
 def add_colorbars(
     fig: mfigure.Figure,
     ax: Union[plt.Axes, list[plt.Axes]],
@@ -105,18 +140,17 @@ def add_colorbars(
         cm, norm=norm, ax=ax, ticks=ticks, drawedges=True, location="right",
         spacing="uniform", pad=pad  # fmt: skip
     )
-    # Formatting the colorbar label
-    cb_label = mesh_property.output_name.replace("_", " ") + " / "
-    POWER_LIMIT = 3
-    if (
-        abs(median_exponent := get_median_exponent(ticks)) > POWER_LIMIT
-    ) and not (np.isclose(ticks[0], ticks[-1]) or setup.log_scaled):
-        ticks = ticks * 10**-median_exponent
-        cb_label += f"10$^{{{median_exponent}}}$ "
-    if unit := mesh_property.get_output_unit():
-        cb_label += f"{unit}"
-    if cb_label[-3:] == " / ":
-        cb_label = cb_label[:-3]
+    # Formatting the colorbar label and ticks
+
+    tick_labels, offset = get_ticklabels(ticks)
+    cb_label = mesh_property.output_name.replace("_", " ")
+    if (unit := mesh_property.get_output_unit()) != "":
+        cb_label += " / " + unit
+    if offset is not None:
+        if offset[0] == "-":
+            cb_label += " + " + offset[1:]
+        else:
+            cb_label += " - " + offset
     if setup.log_scaled:
         cb_label = f"log$_{{10}}$( {cb_label} )"
     labelsize = (
@@ -124,17 +158,7 @@ def add_colorbars(
     )
     cb.set_label(cb_label, size=labelsize)
 
-    # formatting the colorbar ticks
-
-    cb.ax.tick_params(labelsize=labelsize, direction="out")
-    # "+ 0" prevents output of negative zero, i.e. "-0"
-    tick_labels = [
-        f"{round(tick, POWER_LIMIT) + 0:.{POWER_LIMIT}g}" for tick in ticks
-    ]
-    if tick_labels[0] == tick_labels[1]:
-        tick_labels[0] = f"{ticks[0]:.2e}"
-    if tick_labels[-2] == tick_labels[-1]:
-        tick_labels[-1] = f"{ticks[-1]:.1e}"
+    # special formatting for MaterialIDs
     if (
         mesh_property.data_name == "MaterialIDs"
         and setup.material_names is not None
@@ -145,6 +169,7 @@ def add_colorbars(
         cb.ax.set_ylabel("")
     elif mesh_property.categoric:
         tick_labels = [str(level) for level in levels.astype(int)]
+    cb.ax.tick_params(labelsize=labelsize, direction="out")
     cb.ax.set_yticklabels(tick_labels)
 
     # miscellaneous
@@ -453,14 +478,18 @@ def plot_diff(
     diff_mesh = deepcopy(mesh1)
     diff_mesh[mesh_property.data_name] -= mesh2[mesh_property.data_name]
     data_property = mesh_property.replace(output_unit=mesh_property.data_unit)
-    diff_unit = (
-        data_property(1, strip_unit=False) - data_property(1, strip_unit=False)
-    ).units
+    diff_unit = str(
+        (
+            data_property(1, strip_unit=False)
+            - data_property(1, strip_unit=False)
+        ).units
+    )
     diff_property = mesh_property.replace(
         data_unit=diff_unit,
         output_unit=diff_unit,
         output_name=mesh_property.output_name + " difference",
         bilinear_cmap=True,
+        cmap=mesh_property.cmap if mesh_property.bilinear_cmap else "coolwarm",
     )
     return plot(diff_mesh, diff_property)
 
diff --git a/tests/test_meshplotlib.py b/tests/test_meshplotlib.py
index bab263036..4822fd1ee 100644
--- a/tests/test_meshplotlib.py
+++ b/tests/test_meshplotlib.py
@@ -16,6 +16,7 @@ from ogstools.meshplotlib import (
     setup,
 )
 from ogstools.meshplotlib.animation import animate, save_animation
+from ogstools.meshplotlib.core import get_ticklabels
 from ogstools.meshplotlib.levels import get_levels
 from ogstools.meshplotlib.plot_features import plot_on_top
 from ogstools.meshplotlib.utils import justified_labels
@@ -47,6 +48,43 @@ class MeshplotlibTest(unittest.TestCase):
         equality(get_levels(0.0, 0.0, 10), [0.0, 0.0])
         equality(get_levels(1e9, 1e9, 10), [1e9, 1e9])
 
+    def test_ticklabels(self):
+        def compare(array: np.ndarray, ref_labels: list[str], ref_offset=None):
+            labels, offset = get_ticklabels(np.asarray(array))
+            self.assertTrue(np.all(labels == ref_labels))
+            self.assertEqual(offset, ref_offset)
+
+        compare([1, 4, 7, 10], ["1", "4", "7", "10"])
+        compare([1, 4, 7, 10, 10.01], ["1", "4", "7", "10", "10.01"])
+        compare([1, 4, 7, 10, 10.001], ["1", "4", "7", "10", "10.001"])
+        compare([100, 150, 200, 200.1], ["100", "150", "200", "200.1"])
+        compare(
+            [-0.00012345, 1e-15, 1, 2, 2.000012345],
+            ["-0.000123", "0", "1", "2", "2.00001"],
+        )
+        compare(
+            [0.99987655, 1, 2, 2.000012345], ["0.9999", "1", "2", "2.00001"]
+        )
+        compare(
+            [100, 100.004, 100.008, 100.012],
+            ["0", "0.004", "0.008", "0.012"],
+            "100",
+        )
+        compare(
+            [100, 100.0004, 100.0008, 100.0012],
+            ["0.0e+00", "4.0e-04", "8.0e-04", "1.2e-03"],
+            "100",
+        )
+        compare(
+            [110000.0, 140000.0, 1700000.0, 1900010.0],
+            ["1.1e+05", "1.4e+05", "1.7e+06", "1.9e+06"],
+        )
+        compare(
+            [1000000.0, 1000004.0, 1000008.0, 1000012.0],
+            ["0", "4", "8", "12"],
+            "1e+06",
+        )
+
     def test_justified_labels(self):
         points = np.asarray(
             [
-- 
GitLab