From 489ba265f668a3cf3ea326ed687f4b7949341e74 Mon Sep 17 00:00:00 2001 From: FZill <florian.zill@ufz.de> Date: Thu, 21 Mar 2024 12:45:26 +0100 Subject: [PATCH] [tests] difference and aggregate agg test meshlib test meshplotlib --- ogstools/meshlib/mesh_series.py | 21 +++-------- tests/test_meshlib.py | 64 ++++++++++++++++----------------- tests/test_meshplotlib.py | 55 ++++++++++++++++------------ 3 files changed, 69 insertions(+), 71 deletions(-) diff --git a/ogstools/meshlib/mesh_series.py b/ogstools/meshlib/mesh_series.py index 4420b6739..2b13722c9 100644 --- a/ogstools/meshlib/mesh_series.py +++ b/ogstools/meshlib/mesh_series.py @@ -10,8 +10,7 @@ import vtuIO from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator from tqdm.auto import tqdm -from ogstools.propertylib import Property, presets -from ogstools.propertylib.tensor_math import identity +from ogstools.propertylib import Property from .xdmf_reader import XDMFReader @@ -171,7 +170,7 @@ class MeshSeries: self, mesh_property: Union[Property, str], func: Literal["min", "max", "mean", "median", "sum", "std", "var"], - ) -> tuple[pv.DataSet, Property]: + ) -> pv.UnstructuredGrid: """Aggregate data over all timesteps using a specified function. :param mesh_property: @@ -182,8 +181,7 @@ class MeshSeries: The aggregation function to apply. It must be one of "min", "max", "mean", "median", "sum", "std", "var". The equally named numpy function will be used to aggregate over all timesteps. - :returns: A mesh with aggregated data according to the given function - and a Property corresponding to this data. + :returns: A mesh with aggregated data according to the given function. """ np_func = { @@ -218,18 +216,7 @@ class MeshSeries: ) mesh[output_name] = np.empty(vals.shape[1:]) np_func(vals, out=mesh[output_name], axis=0) - if isinstance(mesh_property, Property): - agg_property = mesh_property.replace( - data_name=output_name, - data_unit=mesh_property.output_unit, - output_unit=mesh_property.output_unit, - output_name=output_name, - func=identity, - mesh_dependent=False, - ) - else: - agg_property = presets.get_preset(output_name, mesh) - return mesh, agg_property + return mesh def _probe_pvd( self, diff --git a/tests/test_meshlib.py b/tests/test_meshlib.py index 2f46c458a..89aa577ef 100644 --- a/tests/test_meshlib.py +++ b/tests/test_meshlib.py @@ -39,6 +39,26 @@ class UtilsTest(unittest.TestCase): ) mesh_series.clear() + def test_aggregate(self): + "Test aggregation of meshseries." + mesh_series = MeshSeries(examples.xdmf_file) + funcs = ["min", "max", "mean", "median", "sum", "std", "var"] + for func in funcs: + agg_mesh = mesh_series.aggregate("temperature", func) + self.assertTrue( + not np.any(np.isnan(agg_mesh["temperature_" + func])) + ) + + def test_aggregate_mesh_dependent(self): + "Test aggregation of mesh_dependent property on meshseries." + mesh_series = MeshSeries(examples.pvd_file) + agg_mesh = mesh_series.aggregate(presets.dilatancy_alkan, "max") + self.assertTrue( + not np.any( + np.isnan(agg_mesh[presets.dilatancy_alkan.output_name + "_max"]) + ) + ) + def test_probe_pvd(self): "Test point probing on pvd." mesh_series = MeshSeries(examples.pvd_file) @@ -57,63 +77,43 @@ class UtilsTest(unittest.TestCase): def test_diff_two_meshes(self): meshseries = examples_mpl.meshseries_THM_2D - mesh_property = presets.temperature mesh1 = meshseries.read(0) mesh2 = meshseries.read(-1) - mesh_diff = difference(mesh_property, mesh1, mesh2) + mesh_diff = difference(mesh1, mesh2, "temperature") + mesh_diff = difference(mesh1, mesh2, presets.temperature) self.assertTrue(isinstance(mesh_diff, UnstructuredGrid)) + mesh_diff = difference(mesh1, mesh2) def test_diff_pairwise(self): n = 5 meshseries = examples_mpl.meshseries_THM_2D - mesh_property = presets.temperature meshes1 = [meshseries.read(0)] * n meshes2 = [meshseries.read(-1)] * n - meshes_diff = difference_pairwise(mesh_property, meshes1, meshes2) + meshes_diff = difference_pairwise(meshes1, meshes2, presets.temperature) self.assertTrue( isinstance(meshes_diff, np.ndarray) and len(meshes_diff) == n ) + meshes_diff = difference_pairwise(meshes1, meshes2) - def test_diff_matrix_single_list(self): + def test_diff_matrix_single(self): meshseries = examples_mpl.meshseries_THM_2D - mesh_property = presets.temperature meshes1 = [meshseries.read(0), meshseries.read(-1)] - meshes_diff = difference_matrix(mesh_property, meshes1) - self.assertTrue( - isinstance(meshes_diff, np.ndarray) - and meshes_diff.shape == (len(meshes1), len(meshes1)) + meshes_diff = difference_matrix( + meshes1, mesh_property=presets.temperature ) - - def test_diff_matrix_single_numpy(self): - meshseries = examples_mpl.meshseries_THM_2D - mesh_property = presets.temperature - meshes1 = np.array([meshseries.read(0), meshseries.read(-1)]) - meshes_diff = difference_matrix(mesh_property, meshes1) self.assertTrue( isinstance(meshes_diff, np.ndarray) and meshes_diff.shape == (len(meshes1), len(meshes1)) ) + meshes_diff = difference_matrix(meshes1) - def test_diff_matrix_unequal_list(self): + def test_diff_matrix_unequal(self): meshseries = examples_mpl.meshseries_THM_2D - mesh_property = presets.temperature meshes1 = [meshseries.read(0), meshseries.read(-1)] meshes2 = [meshseries.read(0), meshseries.read(-1), meshseries.read(-1)] - meshes_diff = difference_matrix(mesh_property, meshes1, meshes2) - self.assertTrue( - isinstance(meshes_diff, np.ndarray) - and meshes_diff.shape == (len(meshes1), len(meshes2)) - ) - - def test_diff_matrix_unequal_numpy(self): - meshseries = examples_mpl.meshseries_THM_2D - mesh_property = presets.temperature - meshes1 = np.array([meshseries.read(0), meshseries.read(-1)]) - meshes2 = np.array( - [meshseries.read(0), meshseries.read(-1), meshseries.read(-1)] - ) - meshes_diff = difference_matrix(mesh_property, meshes1, meshes2) + meshes_diff = difference_matrix(meshes1, meshes2, presets.temperature) self.assertTrue( isinstance(meshes_diff, np.ndarray) and meshes_diff.shape == (len(meshes1), len(meshes2)) ) + meshes_diff = difference_matrix(meshes1, meshes2) diff --git a/tests/test_meshplotlib.py b/tests/test_meshplotlib.py index 98f7271bf..2455e3dec 100644 --- a/tests/test_meshplotlib.py +++ b/tests/test_meshplotlib.py @@ -14,7 +14,6 @@ from ogstools.meshplotlib import ( examples, label_spatial_axes, plot, - plot_limit, plot_probe, setup, update_font_sizes, @@ -32,7 +31,9 @@ assert_allclose = partial( class MeshplotlibTest(unittest.TestCase): - """Test case for meshplotlib.""" + """Test case for meshplotlib. + + Most of these tests only test for no-throw, currently.""" def test_pyvista_offscreen(self): import pyvista as pv @@ -134,8 +135,18 @@ class MeshplotlibTest(unittest.TestCase): def test_diff_plots(self): """Test creation of difference plots.""" - meshseries = examples.meshseries_CT_2D - plot(difference("Si", meshseries.read(0), meshseries.read(1)), "Si") + meshseries = examples.meshseries_THM_2D + mesh0 = meshseries.read(0) + mesh1 = meshseries.read(1) + plot(difference(mesh1, mesh0, "temperature"), "temperature_difference") + for prop in [ + presets.temperature, + presets.displacement, + presets.stress, + presets.stress.von_Mises, + ]: + plot(difference(mesh1, mesh0, prop), prop) + plt.close() def test_user_defined_ax(self): """Test creating plot with subfigures and user provided ax""" @@ -146,9 +157,9 @@ class MeshplotlibTest(unittest.TestCase): plot(meshseries.read(1), presets.temperature, fig=fig, ax=ax[1]) ax[1].set_title(r"$T(\mathrm{t}_{end})$") diff_mesh = difference( - presets.temperature, meshseries.read(0), meshseries.read(1) + meshseries.read(0), meshseries.read(1), presets.temperature ) - plot(diff_mesh, presets.temperature.delta, fig=fig, ax=ax[2]) + plot(diff_mesh, presets.temperature, fig=fig, ax=ax[2]) ax[2].set_title(r"$T(\mathrm{t}_{end})$-$T(\mathrm{t}_{0})$") # fig.suptitle("Test user defined ax") fig.tight_layout() @@ -197,11 +208,11 @@ class MeshplotlibTest(unittest.TestCase): ax = ax.flatten() plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0]) plot(meshseries.read(1), presets.temperature, fig=fig, ax=ax[1]) - diff_ab = difference(presets.temperature, mesh_a, mesh_b) - diff_ba = difference(presets.temperature, mesh_b, mesh_a) - plot(diff_ab, presets.temperature.delta, fig=fig, ax=ax[2]) - plot(diff_ba, presets.temperature.delta, fig=fig, ax=ax[3]) - fig.tight_layout() + diff_ab = difference(mesh_a, mesh_b, presets.temperature) + diff_ba = difference(mesh_b, mesh_a, presets.temperature) + plot(diff_ab, presets.temperature, fig=fig, ax=ax[2]) + plot(diff_ba, presets.temperature, fig=fig, ax=ax[3]) + plt.close() def test_label_sharedxy(self): """Test labeling shared x and y axes""" @@ -211,29 +222,29 @@ class MeshplotlibTest(unittest.TestCase): fig, ax = plt.subplots(2, 2, sharex=True, sharey=True) plot(meshseries.read(0), presets.temperature, fig=fig, ax=ax[0][0]) plot(meshseries.read(1), presets.temperature, fig=fig, ax=ax[1][0]) - diff_ab = difference(presets.temperature, mesh_a, mesh_b) - diff_ba = difference(presets.temperature, mesh_b, mesh_a) - plot(diff_ab, presets.temperature.delta, fig=fig, ax=ax[0][1]) - plot(diff_ba, presets.temperature.delta, fig=fig, ax=ax[1][1]) - ax = label_spatial_axes(ax, np.array([0, 1])) - fig.tight_layout() + diff_ab = difference(mesh_a, mesh_b, presets.temperature) + diff_ba = difference(mesh_b, mesh_a, presets.temperature) + plot(diff_ab, presets.temperature, fig=fig, ax=ax[0][1]) + plot(diff_ba, presets.temperature, fig=fig, ax=ax[1][1]) + label_spatial_axes(ax, "x", "y") + plt.close() def test_spatial_label(self): - """Test if labels are added to x and y axes""" + """Test axes labeling""" fig, ax = plt.subplots(2, 2) ax = label_spatial_axes(ax, np.array([0, 1])) def test_spatial_label_clear(self): - """Test if labels are added to x and y axes""" + """Test axes labels clearing""" fig, ax = plt.subplots(2, 2) ax = label_spatial_axes(ax, np.array([0, 1])) ax = clear_labels(ax) def test_limit_plots(self): """Test creation of limit plots.""" - meshseries = examples.meshseries_CT_2D - plot_limit(meshseries, "Si", "min") - plot_limit(meshseries, "Si", "max") + mesh = examples.meshseries_CT_2D.aggregate("Si", "var") + plot(mesh, "Si_var") + plt.close() def test_plot_probe(self): """Test creation of probe plots.""" -- GitLab