diff --git a/ogstools/meshplotlib/utils.py b/ogstools/meshplotlib/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..748732a3334b865efea37c5df0bac6f3fdb525f5 --- /dev/null +++ b/ogstools/meshplotlib/utils.py @@ -0,0 +1,38 @@ +from typing import Optional + +import matplotlib.pyplot as plt +import numpy as np + + +def justified_labels(points: np.ndarray) -> list[str]: + "Formats an array of points to a list of aligned str." + + def fmt(val: float): + return f"{val:.2f}".rstrip("0").rstrip(".") + + col_lens = np.max( + [[len(fmt(coord)) for coord in point] for point in points], axis=0 + ) + return [ + ",".join(fmt(point[i]).rjust(col_lens[i]) for i in range(3)) + for point in points + ] + + +def get_style_cycler( + min_number_of_styles: int, + colors: Optional[Optional[list]] = None, + linestyles: Optional[list] = None, +) -> plt.cycler: + if colors is None: + colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] + if linestyles is None: + linestyles = ["-", "--", ":", "-."] + styles_len = min(len(colors), len(linestyles)) + c_cycler = plt.cycler(color=colors) + ls_cycler = plt.cycler(linestyle=linestyles) + if min_number_of_styles <= styles_len: + style_cycler = c_cycler[:styles_len] + ls_cycler[:styles_len] + else: + style_cycler = ls_cycler * c_cycler + return style_cycler diff --git a/tests/test_meshplotlib.py b/tests/test_meshplotlib.py index 6113d32282e643101961abfcb29c03195e25fdb2..68fe733a9bcd48dfab5a00d81f99b249d82359ab 100644 --- a/tests/test_meshplotlib.py +++ b/tests/test_meshplotlib.py @@ -11,6 +11,7 @@ from ogstools.meshplotlib import examples, plot, plot_diff, plot_limit, setup from ogstools.meshplotlib.animation import animate, save_animation from ogstools.meshplotlib.levels import get_levels from ogstools.meshplotlib.plot_features import plot_on_top +from ogstools.meshplotlib.utils import justified_labels from ogstools.propertylib import Scalar, presets equality = partial(np.testing.assert_allclose, rtol=1e-7, verbose=True) @@ -37,6 +38,19 @@ class MeshplotlibTest(unittest.TestCase): equality(get_levels(0.0, 0.0, 10), [0.0, 1e-6]) equality(get_levels(1e9, 1e9, 10), [1e9, 1e9 + 1e-6]) + def test_justified_labels(self): + points = np.asarray( + [ + [x, y, z] + for x in np.linspace(-1, 0, 3) + for y in np.linspace(-10, 10, 5) + for z in np.linspace(1e-6, 1e6, 7) + ] + ) + labels = justified_labels(points) + str_lens = np.asarray([len(label) for label in labels]) + self.assertTrue(np.all(str_lens == str_lens[0])) + def test_missing_data(self): """Test missing data in mesh.""" mesh = pv_examples.load_uniform()