From b441a1afeca68efadf6fc210606515c876cdb52d Mon Sep 17 00:00:00 2001
From: FZill <florian.zill@ufz.de>
Date: Thu, 21 Mar 2024 12:45:26 +0100
Subject: [PATCH] [meshlib] added aggregate function to MeshSeries

---
 ogstools/meshlib/mesh_series.py | 67 +++++++++++++++++++++++++++++++++
 1 file changed, 67 insertions(+)

diff --git a/ogstools/meshlib/mesh_series.py b/ogstools/meshlib/mesh_series.py
index d43f39654..4420b6739 100644
--- a/ogstools/meshlib/mesh_series.py
+++ b/ogstools/meshlib/mesh_series.py
@@ -10,6 +10,9 @@ import vtuIO
 from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator
 from tqdm.auto import tqdm
 
+from ogstools.propertylib import Property, presets
+from ogstools.propertylib.tensor_math import identity
+
 from .xdmf_reader import XDMFReader
 
 
@@ -164,6 +167,70 @@ class MeshSeries:
             )
         return mesh[data_name]
 
+    def aggregate(
+        self,
+        mesh_property: Union[Property, str],
+        func: Literal["min", "max", "mean", "median", "sum", "std", "var"],
+    ) -> tuple[pv.DataSet, Property]:
+        """Aggregate data over all timesteps using a specified function.
+
+        :param mesh_property:
+            The mesh property to be aggregated. If given as type `Property`, the
+            :meth:`~ogstools.propertylib.property.Property.transform` function
+            will be applied on each timestep and aggregation afterwards.
+        :param func:
+            The aggregation function to apply. It must be one of "min", "max",
+            "mean", "median", "sum", "std", "var". The equally named numpy
+            function will be used to aggregate over all timesteps.
+        :returns:   A mesh with aggregated data according to the given function
+                    and a Property corresponding to this data.
+
+        """
+        np_func = {
+            "min": np.min,
+            "max": np.max,
+            "mean": np.mean,
+            "median": np.median,
+            "sum": np.sum,
+            "std": np.std,
+            "var": np.var,
+        }[func]
+        mesh = self.read(0).copy(deep=True)
+        mesh.clear_data()
+        if isinstance(mesh_property, Property):
+            if mesh_property.mesh_dependent:
+                vals = np.asarray(
+                    [
+                        mesh_property.transform(self.read(t))
+                        for t in tqdm(self.timesteps)
+                    ]
+                )
+            else:
+                vals = mesh_property.transform(
+                    self.values(mesh_property.data_name)
+                )
+        else:
+            vals = self.values(mesh_property)
+        output_name = (
+            f"{mesh_property.output_name}_{func}"
+            if isinstance(mesh_property, Property)
+            else f"{mesh_property}_{func}"
+        )
+        mesh[output_name] = np.empty(vals.shape[1:])
+        np_func(vals, out=mesh[output_name], axis=0)
+        if isinstance(mesh_property, Property):
+            agg_property = mesh_property.replace(
+                data_name=output_name,
+                data_unit=mesh_property.output_unit,
+                output_unit=mesh_property.output_unit,
+                output_name=output_name,
+                func=identity,
+                mesh_dependent=False,
+            )
+        else:
+            agg_property = presets.get_preset(output_name, mesh)
+        return mesh, agg_property
+
     def _probe_pvd(
         self,
         points: np.ndarray,
-- 
GitLab