From 0e44a5a4bc415dff9bf63aae7a2d41edab86e1ab Mon Sep 17 00:00:00 2001
From: FZill <florian.zill@ufz.de>
Date: Mon, 26 Feb 2024 10:27:24 +0100
Subject: [PATCH] [studies] use plot_diff in convergence study

---
 .../plot_convergence_study_nuclear_decay.py       |  2 +-
 ogstools/studies/convergence/convergence.py       |  2 +-
 ogstools/studies/templates/convergence_study.py   | 15 ++++-----------
 3 files changed, 6 insertions(+), 13 deletions(-)

diff --git a/docs/examples/howto_studies/plot_convergence_study_nuclear_decay.py b/docs/examples/howto_studies/plot_convergence_study_nuclear_decay.py
index a880b9ae2..303a01f88 100644
--- a/docs/examples/howto_studies/plot_convergence_study_nuclear_decay.py
+++ b/docs/examples/howto_studies/plot_convergence_study_nuclear_decay.py
@@ -91,7 +91,7 @@ for sim_result, dt in zip(sim_results, time_step_sizes):
     for ts in mesh_series.timesteps:
         mesh = mesh_series.read(ts)
         results["temperature"] += [np.max(mesh.point_data["temperature"])]
-    max_T = propertylib.presets.temperature(results["temperature"]).magnitude
+    max_T = propertylib.presets.temperature(results["temperature"])
     # times 2 due to symmetry, area of repo, to kW
     results["heat_flux"] += [np.max(mesh.point_data["heat_flux"][:, 0])]
     tv = np.asarray(mesh_series.timevalues) / sec_per_yr
diff --git a/ogstools/studies/convergence/convergence.py b/ogstools/studies/convergence/convergence.py
index ba79673d6..6cfb173e4 100644
--- a/ogstools/studies/convergence/convergence.py
+++ b/ogstools/studies/convergence/convergence.py
@@ -57,7 +57,7 @@ def grid_convergence(
     returns:                    Grid convergence field of the given property.
     """
     assert len(meshes) >= 3
-    cast = property.magnitude.strip_units
+    cast = property.magnitude
     result = deepcopy(topology)
     result.clear_point_data()
     result.clear_cell_data()
diff --git a/ogstools/studies/templates/convergence_study.py b/ogstools/studies/templates/convergence_study.py
index 8afbefb91..fdbf7b3f5 100644
--- a/ogstools/studies/templates/convergence_study.py
+++ b/ogstools/studies/templates/convergence_study.py
@@ -81,21 +81,14 @@ fig = meshplotlib.plot(richardson, mesh_property)
 
 data_key = mesh_property.data_name
 if reference_solution_path is None:
-    diff = richardson[data_key] - topology.sample(meshes[-1])[data_key]
+    fig = meshplotlib.plot_diff(
+        richardson, topology.sample(meshes[-1]), mesh_property
+    )
 else:
     reference_solution = topology.sample(
         meshlib.MeshSeries(reference_solution_path).read_closest(timevalue)
     )
-    diff = reference_solution[data_key] - richardson[data_key]
-richardson["difference"] = diff
-diff_unit = (mesh_property(1) - mesh_property(1)).units
-diff_property = type(mesh_property)(
-    data_name="difference",
-    data_unit=diff_unit,
-    output_unit=diff_unit,
-    output_name=mesh_property.output_name + " difference",
-)
-fig = meshplotlib.plot(richardson, diff_property)
+    fig = meshplotlib.plot_diff(reference_solution, richardson, mesh_property)
 
 # %% [markdown]
 # ## Convergence metrics
-- 
GitLab