From a752773a3c4db166c2a78b8ce0dda6d147fe17fd Mon Sep 17 00:00:00 2001
From: FZill <florian.zill@ufz.de>
Date: Fri, 22 Mar 2024 14:09:13 +0100
Subject: [PATCH] [propertylib] get_preset finds derived props

---
 ogstools/propertylib/presets.py | 35 +++++++++++++++++++++++++++++----
 1 file changed, 31 insertions(+), 4 deletions(-)

diff --git a/ogstools/propertylib/presets.py b/ogstools/propertylib/presets.py
index 8935233b5..9da9a7f02 100644
--- a/ogstools/propertylib/presets.py
+++ b/ogstools/propertylib/presets.py
@@ -14,6 +14,7 @@ from . import mesh_dependent, tensor_math
 from .custom_colormaps import integrity_cmap, temperature_cmap
 from .matrix import Matrix
 from .property import Property, Scalar
+from .tensor_math import identity
 from .vector import Vector
 
 T_MASK = "temperature_active"
@@ -148,6 +149,8 @@ def get_preset(
     Returns a Property preset or creates one with correct type.
 
     Searches for presets by data_name and output_name and returns if found.
+    If 'mesh_property' is given as type Property this will also look for
+    derived properties (difference, aggregate).
     Otherwise create Scalar, Vector, or Matrix Property depending on the shape
     of data in mesh.
 
@@ -155,8 +158,30 @@ def get_preset(
     :param mesh:            The mesh containing the property data.
     :returns: A corresponding Property preset or a new Property of correct type.
     """
+    data_keys: list[str] = list(set().union(mesh.point_data, mesh.cell_data))
+    error_msg = (
+        f"Data not found in mesh. Available data names are {data_keys}. "
+    )
+
     if isinstance(mesh_property, Property):
-        return mesh_property
+        if mesh_property.data_name in data_keys:
+            return mesh_property
+        matches = [
+            mesh_property.output_name in data_key for data_key in data_keys
+        ]
+        if not any(matches):
+            raise KeyError(error_msg)
+        data_key = data_keys[matches.index(True)]
+        if data_key == f"{mesh_property.output_name}_difference":
+            return mesh_property.difference
+        return mesh_property.replace(
+            data_name=data_key,
+            data_unit=mesh_property.output_unit,
+            output_unit=mesh_property.output_unit,
+            output_name=data_key,
+            func=identity,
+            mesh_dependent=False,
+        )
 
     for prop in all_properties:
         if prop.output_name == mesh_property:
@@ -164,9 +189,11 @@ def get_preset(
     for prop in all_properties:
         if prop.data_name == mesh_property:
             return prop
-    if mesh_property not in set().union(mesh.point_data, mesh.cell_data):
-        msg = f"Property {mesh_property} not found in mesh."
-        raise KeyError(msg)
+
+    matches = [mesh_property in data_key for data_key in data_keys]
+    if not any(matches):
+        raise KeyError(error_msg)
+
     data_shape = mesh[mesh_property].shape
     if len(data_shape) == 1:
         return Scalar(mesh_property)
-- 
GitLab