From 489ba265f668a3cf3ea326ed687f4b7949341e74 Mon Sep 17 00:00:00 2001
From: FZill <florian.zill@ufz.de>
Date: Thu, 21 Mar 2024 12:45:26 +0100
Subject: [PATCH] [tests] difference and aggregate

agg

test meshlib

test meshplotlib
---
 ogstools/meshlib/mesh_series.py | 21 +++--------
 tests/test_meshlib.py           | 64 ++++++++++++++++-----------------
 tests/test_meshplotlib.py       | 55 ++++++++++++++++------------
 3 files changed, 69 insertions(+), 71 deletions(-)

diff --git a/ogstools/meshlib/mesh_series.py b/ogstools/meshlib/mesh_series.py
index 4420b6739..2b13722c9 100644
--- a/ogstools/meshlib/mesh_series.py
+++ b/ogstools/meshlib/mesh_series.py
@@ -10,8 +10,7 @@ import vtuIO
 from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator
 from tqdm.auto import tqdm
 
-from ogstools.propertylib import Property, presets
-from ogstools.propertylib.tensor_math import identity
+from ogstools.propertylib import Property
 
 from .xdmf_reader import XDMFReader
 
@@ -171,7 +170,7 @@ class MeshSeries:
         self,
         mesh_property: Union[Property, str],
         func: Literal["min", "max", "mean", "median", "sum", "std", "var"],
-    ) -> tuple[pv.DataSet, Property]:
+    ) -> pv.UnstructuredGrid:
         """Aggregate data over all timesteps using a specified function.
 
         :param mesh_property:
@@ -182,8 +181,7 @@ class MeshSeries:
             The aggregation function to apply. It must be one of "min", "max",
             "mean", "median", "sum", "std", "var". The equally named numpy
             function will be used to aggregate over all timesteps.
-        :returns:   A mesh with aggregated data according to the given function
-                    and a Property corresponding to this data.
+        :returns:   A mesh with aggregated data according to the given function.
 
         """
         np_func = {
@@ -218,18 +216,7 @@ class MeshSeries:
         )
         mesh[output_name] = np.empty(vals.shape[1:])
         np_func(vals, out=mesh[output_name], axis=0)
-        if isinstance(mesh_property, Property):
-            agg_property = mesh_property.replace(
-                data_name=output_name,
-                data_unit=mesh_property.output_unit,
-                output_unit=mesh_property.output_unit,
-                output_name=output_name,
-                func=identity,
-                mesh_dependent=False,
-            )
-        else:
-            agg_property = presets.get_preset(output_name, mesh)
-        return mesh, agg_property
+        return mesh
 
     def _probe_pvd(
         self,
diff --git a/tests/test_meshlib.py b/tests/test_meshlib.py
index 2f46c458a..89aa577ef 100644
--- a/tests/test_meshlib.py
+++ b/tests/test_meshlib.py
@@ -39,6 +39,26 @@ class UtilsTest(unittest.TestCase):
             )
             mesh_series.clear()
 
+    def test_aggregate(self):
+        "Test aggregation of meshseries."
+        mesh_series = MeshSeries(examples.xdmf_file)
+        funcs = ["min", "max", "mean", "median", "sum", "std", "var"]
+        for func in funcs:
+            agg_mesh = mesh_series.aggregate("temperature", func)
+            self.assertTrue(
+                not np.any(np.isnan(agg_mesh["temperature_" + func]))
+            )
+
+    def test_aggregate_mesh_dependent(self):
+        "Test aggregation of mesh_dependent property on meshseries."
+        mesh_series = MeshSeries(examples.pvd_file)
+        agg_mesh = mesh_series.aggregate(presets.dilatancy_alkan, "max")
+        self.assertTrue(
+            not np.any(
+                np.isnan(agg_mesh[presets.dilatancy_alkan.output_name + "_max"])
+            )
+        )
+
     def test_probe_pvd(self):
         "Test point probing on pvd."
         mesh_series = MeshSeries(examples.pvd_file)
@@ -57,63 +77,43 @@ class UtilsTest(unittest.TestCase):
 
     def test_diff_two_meshes(self):
         meshseries = examples_mpl.meshseries_THM_2D
-        mesh_property = presets.temperature
         mesh1 = meshseries.read(0)
         mesh2 = meshseries.read(-1)
-        mesh_diff = difference(mesh_property, mesh1, mesh2)
+        mesh_diff = difference(mesh1, mesh2, "temperature")
+        mesh_diff = difference(mesh1, mesh2, presets.temperature)
         self.assertTrue(isinstance(mesh_diff, UnstructuredGrid))
+        mesh_diff = difference(mesh1, mesh2)
 
     def test_diff_pairwise(self):
         n = 5
         meshseries = examples_mpl.meshseries_THM_2D
-        mesh_property = presets.temperature
         meshes1 = [meshseries.read(0)] * n
         meshes2 = [meshseries.read(-1)] * n
-        meshes_diff = difference_pairwise(mesh_property, meshes1, meshes2)
+        meshes_diff = difference_pairwise(meshes1, meshes2, presets.temperature)
         self.assertTrue(
             isinstance(meshes_diff, np.ndarray) and len(meshes_diff) == n
         )
+        meshes_diff = difference_pairwise(meshes1, meshes2)
 
-    def test_diff_matrix_single_list(self):
+    def test_diff_matrix_single(self):
         meshseries = examples_mpl.meshseries_THM_2D
-        mesh_property = presets.temperature
         meshes1 = [meshseries.read(0), meshseries.read(-1)]
-        meshes_diff = difference_matrix(mesh_property, meshes1)
-        self.assertTrue(
-            isinstance(meshes_diff, np.ndarray)
-            and meshes_diff.shape == (len(meshes1), len(meshes1))
+        meshes_diff = difference_matrix(
+            meshes1, mesh_property=presets.temperature
         )
-
-    def test_diff_matrix_single_numpy(self):
-        meshseries = examples_mpl.meshseries_THM_2D
-        mesh_property = presets.temperature
-        meshes1 = np.array([meshseries.read(0), meshseries.read(-1)])
-        meshes_diff = difference_matrix(mesh_property, meshes1)
         self.assertTrue(
             isinstance(meshes_diff, np.ndarray)
             and meshes_diff.shape == (len(meshes1), len(meshes1))
         )
+        meshes_diff = difference_matrix(meshes1)
 
-    def test_diff_matrix_unequal_list(self):
+    def test_diff_matrix_unequal(self):
         meshseries = examples_mpl.meshseries_THM_2D
-        mesh_property = presets.temperature
         meshes1 = [meshseries.read(0), meshseries.read(-1)]
         meshes2 = [meshseries.read(0), meshseries.read(-1), meshseries.read(-1)]
-        meshes_diff = difference_matrix(mesh_property, meshes1, meshes2)
-        self.assertTrue(
-            isinstance(meshes_diff, np.ndarray)
-            and meshes_diff.shape == (len(meshes1), len(meshes2))
-        )
-
-    def test_diff_matrix_unequal_numpy(self):
-        meshseries = examples_mpl.meshseries_THM_2D
-        mesh_property = presets.temperature
-        meshes1 = np.array([meshseries.read(0), meshseries.read(-1)])
-        meshes2 = np.array(
-            [meshseries.read(0), meshseries.read(-1), meshseries.read(-1)]
-        )
-        meshes_diff = difference_matrix(mesh_property, meshes1, meshes2)
+        meshes_diff = difference_matrix(meshes1, meshes2, presets.temperature)
         self.assertTrue(
             isinstance(meshes_diff, np.ndarray)
             and meshes_diff.shape == (len(meshes1), len(meshes2))
         )
+        meshes_diff = difference_matrix(meshes1, meshes2)
diff --git a/tests/test_meshplotlib.py b/tests/test_meshplotlib.py
index 98f7271bf..2455e3dec 100644
--- a/tests/test_meshplotlib.py
+++ b/tests/test_meshplotlib.py
@@ -14,7 +14,6 @@ from ogstools.meshplotlib import (
     examples,
     label_spatial_axes,
     plot,
-    plot_limit,
     plot_probe,
     setup,
     update_font_sizes,
@@ -32,7 +31,9 @@ assert_allclose = partial(
 
 
 class MeshplotlibTest(unittest.TestCase):
-    """Test case for meshplotlib."""
+    """Test case for meshplotlib.
+
+    Most of these tests only test for no-throw, currently."""
 
     def test_pyvista_offscreen(self):
         import pyvista as pv
@@ -134,8 +135,18 @@ class MeshplotlibTest(unittest.TestCase):
 
     def test_diff_plots(self):
         """Test creation of difference plots."""
-        meshseries = examples.meshseries_CT_2D
-        plot(difference("Si", meshseries.read(0), meshseries.read(1)), "Si")
+        meshseries = examples.meshseries_THM_2D
+        mesh0 = meshseries.read(0)
+        mesh1 = meshseries.read(1)
+        plot(difference(mesh1, mesh0, "temperature"), "temperature_difference")
+        for prop in [
+            presets.temperature,
+            presets.displacement,
+            presets.stress,
+            presets.stress.von_Mises,
+        ]:
+            plot(difference(mesh1, mesh0, prop), prop)
+        plt.close()
 
     def test_user_defined_ax(self):
         """Test creating plot with subfigures and user provided ax"""
@@ -146,9 +157,9 @@ class MeshplotlibTest(unittest.TestCase):
         plot(meshseries.read(1), presets.temperature, fig=fig, ax=ax[1])
         ax[1].set_title(r"$T(\mathrm{t}_{end})$")
         diff_mesh = difference(
-            presets.temperature, meshseries.read(0), meshseries.read(1)
+            meshseries.read(0), meshseries.read(1), presets.temperature
         )
-        plot(diff_mesh, presets.temperature.delta, fig=fig, ax=ax[2])
+        plot(diff_mesh, presets.temperature, fig=fig, ax=ax[2])
         ax[2].set_title(r"$T(\mathrm{t}_{end})$-$T(\mathrm{t}_{0})$")
         # fig.suptitle("Test user defined ax")
         fig.tight_layout()
@@ -197,11 +208,11 @@ class MeshplotlibTest(unittest.TestCase):
         ax = ax.flatten()
         plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0])
         plot(meshseries.read(1), presets.temperature, fig=fig, ax=ax[1])
-        diff_ab = difference(presets.temperature, mesh_a, mesh_b)
-        diff_ba = difference(presets.temperature, mesh_b, mesh_a)
-        plot(diff_ab, presets.temperature.delta, fig=fig, ax=ax[2])
-        plot(diff_ba, presets.temperature.delta, fig=fig, ax=ax[3])
-        fig.tight_layout()
+        diff_ab = difference(mesh_a, mesh_b, presets.temperature)
+        diff_ba = difference(mesh_b, mesh_a, presets.temperature)
+        plot(diff_ab, presets.temperature, fig=fig, ax=ax[2])
+        plot(diff_ba, presets.temperature, fig=fig, ax=ax[3])
+        plt.close()
 
     def test_label_sharedxy(self):
         """Test labeling shared x and y axes"""
@@ -211,29 +222,29 @@ class MeshplotlibTest(unittest.TestCase):
         fig, ax = plt.subplots(2, 2, sharex=True, sharey=True)
         plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0][0])
         plot(meshseries.read(1), presets.temperature, fig=fig, ax=ax[1][0])
-        diff_ab = difference(presets.temperature, mesh_a, mesh_b)
-        diff_ba = difference(presets.temperature, mesh_b, mesh_a)
-        plot(diff_ab, presets.temperature.delta, fig=fig, ax=ax[0][1])
-        plot(diff_ba, presets.temperature.delta, fig=fig, ax=ax[1][1])
-        ax = label_spatial_axes(ax, np.array([0, 1]))
-        fig.tight_layout()
+        diff_ab = difference(mesh_a, mesh_b, presets.temperature)
+        diff_ba = difference(mesh_b, mesh_a, presets.temperature)
+        plot(diff_ab, presets.temperature, fig=fig, ax=ax[0][1])
+        plot(diff_ba, presets.temperature, fig=fig, ax=ax[1][1])
+        label_spatial_axes(ax, "x", "y")
+        plt.close()
 
     def test_spatial_label(self):
-        """Test if labels are added to x and y axes"""
+        """Test axes labeling"""
         fig, ax = plt.subplots(2, 2)
         ax = label_spatial_axes(ax, np.array([0, 1]))
 
     def test_spatial_label_clear(self):
-        """Test if labels are added to x and y axes"""
+        """Test axes labels clearing"""
         fig, ax = plt.subplots(2, 2)
         ax = label_spatial_axes(ax, np.array([0, 1]))
         ax = clear_labels(ax)
 
     def test_limit_plots(self):
         """Test creation of limit plots."""
-        meshseries = examples.meshseries_CT_2D
-        plot_limit(meshseries, "Si", "min")
-        plot_limit(meshseries, "Si", "max")
+        mesh = examples.meshseries_CT_2D.aggregate("Si", "var")
+        plot(mesh, "Si_var")
+        plt.close()
 
     def test_plot_probe(self):
         """Test creation of probe plots."""
-- 
GitLab