diff --git a/docs/examples/howto_conversions/plot_D_feflowlib_CT_simulation.py b/docs/examples/howto_conversions/plot_D_feflowlib_CT_simulation.py index 27766bbc2cf72a6a83cf1f8b2a2c0ad3003af786..99fa60fd06fa6db4b1063476cd6730670d7ca3ac 100644 --- a/docs/examples/howto_conversions/plot_D_feflowlib_CT_simulation.py +++ b/docs/examples/howto_conversions/plot_D_feflowlib_CT_simulation.py @@ -3,8 +3,9 @@ Workflow: Component-transport model - conversion, simulation, postprocessing ============================================================================ .. sectionauthor:: Julian Heinze (Helmholtz Centre for Environmental Research GmbH - UFZ) -In this example we show how a simple mass transport FEFLOW model can be converted to a pyvista.UnstructuredGrid and then -be simulated in OGS with the component transport process. +In this example we show how a simple mass transport FEFLOW model can be +converted to a pyvista.UnstructuredGrid and then be simulated in OGS with the +component transport process. """ # %% @@ -14,11 +15,11 @@ import xml.etree.ElementTree as ET from pathlib import Path import matplotlib.pyplot as plt -import numpy as np import ogstools as ot from ogstools.examples import feflow_model_2D_CT_t_560 -from ogstools.meshlib import Mesh + +ot.plot.setup.show_element_edges = True # %% # 1. Load a FEFLOW model (.fem) as a FeflowModel object to further work it. @@ -27,29 +28,25 @@ temp_dir = Path(tempfile.mkdtemp("feflow_test_simulation")) feflow_model = ot.FeflowModel( feflow_model_2D_CT_t_560, temp_dir / "2D_CT_model" ) -feflow_concentration = ot.variables.Scalar( - data_name="single_species_P_CONC", - output_name="concentration", - data_unit="mg/l", - output_unit="mg/l", -) -# The original mesh is clipped to focus on the relevant part of it, where concentration is larger -# than 1e-9 mg/l. The rest of the mesh has concentration values of 0. -ot.plot.setup.show_element_edges = True -ot.plot.contourf( - feflow_model.mesh.clip_scalar( - scalars="single_species_P_CONC", invert=False, value=1.0e-9 - ), - feflow_concentration, +# name the feflow concentratiob result the same as in OGS for easier comparison +feflow_model.mesh["single_species"] = feflow_model.mesh["single_species_P_CONC"] +concentration = ot.variables.Scalar( + data_name="single_species", output_name="concentration", + data_unit="mg/l", output_unit="mg/l", +) # fmt: skip +# The original mesh is clipped to focus on the relevant part of it, where +# concentration is larger than 1e-9 mg/l. The rest of the mesh has concentration +# values of 0. +clipped_mesh = feflow_model.mesh.clip_scalar( + scalars="single_species", invert=False, value=1.0e-9 ) +ot.plot.contourf(clipped_mesh, concentration) # %% # 2. Setup a prj-file to run a OGS-simulation. -feflow_model.setup_prj( - end_time=int(4.8384e07), - time_stepping=list( - zip([10] * 8, [8.64 * 10**i for i in range(8)], strict=False) - ), +time_steps = list( + zip([10] * 8, [8.64 * 10**i for i in range(8)], strict=False) ) +feflow_model.setup_prj(end_time=int(4.8384e07), time_stepping=time_steps) # Save the model (mesh, boundary meshes and project file). feflow_model.save() # Print the prj-file as an example. @@ -58,80 +55,28 @@ ET.dump(ET.parse(feflow_model.mesh_path.with_suffix(".prj"))) # 3. Run the model. feflow_model.run() # %% -# 4. Read the results along a line on the upper edge of the mesh parallel to the x-axis and plot them. -ms = ot.MeshSeries(temp_dir / "2D_CT_model.pvd") -# Read the last timestep: -ogs_sim_res = ms.mesh(ms.timesteps[-1]) -""" -It is also possible to read the file directly with pyvista: -ogs_sim_res = pv.read( - temp_dir / "2D_CT_model_ts_65_t_48384000.000000.vtu" -) -""" -profile = np.array([[0.038 + 1.0e-8, 0.005, 0], [0.045, 0.005, 0]]) -fig, ax = plt.subplots(1, 1, figsize=(7, 5)) -ogs_sim_res.plot_linesample( - "dist", - ot.variables.Scalar( - data_name="single_species", - output_name="concentration", - data_unit="mg/l", - output_unit="mg/l", - ), - profile_points=profile, - ax=ax, - resolution=1000, - grid="major", - fontsize=18, - label="OGS", - color="black", - linewidth=2, -) -Mesh(feflow_model.mesh).plot_linesample( - "dist", - feflow_concentration, - profile_points=profile, - ax=ax, - resolution=1000, - fontsize=16, - label="FEFLOW", - ls=":", - linewidth=2, - color="red", -) -ax.legend(loc="best", fontsize=16) +# 4. Read the last timestep and plot the results along a line on the upper edge +# of the mesh parallel to the x-axis. +ogs_sim_res = ot.MeshSeries(temp_dir / "2D_CT_model.pvd")[-1] +fig, ax = plt.subplots(1, 1, figsize=(16, 10)) +pts = [[0.038 + 1.0e-8, 0.005, 0], [0.045, 0.005, 0]] +for i, mesh in enumerate([ogs_sim_res, feflow_model.mesh]): + sample = mesh.sample_over_line(*pts) + label = ["OGS", "FEFLOW"][i] + ot.plot.line( + sample, concentration, ax=ax, color="kr"[i], label=label, ls="-:"[i] + ) fig.tight_layout() # %% # 5. Concentration difference plotted on the mesh. -ogs_sim_res["concentration_difference"] = ( - feflow_model.mesh["single_species_P_CONC"] - ogs_sim_res["single_species"] -) -concentration_difference = ot.variables.Scalar( - data_name="concentration_difference", - output_name="concentration", - data_unit="mg/l", - output_unit="mg/l", -) -bounds = [0.038, 0.045, 0, 0.01, 0, 0] -ot.plot.contourf( - ogs_sim_res.clip_box(bounds, invert=False), - concentration_difference, -) +diff = ot.meshlib.difference(feflow_model.mesh, ogs_sim_res, concentration) +diff_clipped = diff.clip_box([0.038, 0.045, 0, 0.01, 0, 0], invert=False) +fig = ot.plot.contourf(diff_clipped, concentration.difference, fontsize=20) # %% # 5.1 Concentration difference plotted along a line. -fig, ax = plt.subplots(1, 1, figsize=(7, 5)) -ogs_sim_res.plot_linesample( - "dist", - concentration_difference, - profile_points=profile, - ax=ax, - resolution=1000, - grid="both", - fontsize=18, - linewidth=2, - color="green", - label="Difference FEFLOW-OGS", +diff_sample = diff.sample_over_line(*pts) +fig = ot.plot.line( + diff_sample, concentration.difference, label="Difference FEFLOW-OGS" ) -ax.legend(loc="best", fontsize=16) fig.tight_layout() diff --git a/docs/examples/howto_conversions/plot_E_feflowlib_H_simulation.py b/docs/examples/howto_conversions/plot_E_feflowlib_H_simulation.py index 56121c5368856b2e57a778e4613d364e450317fd..f7cbace22fc0826df18a9b217cd8608951c0d339 100644 --- a/docs/examples/howto_conversions/plot_E_feflowlib_H_simulation.py +++ b/docs/examples/howto_conversions/plot_E_feflowlib_H_simulation.py @@ -99,7 +99,8 @@ slices = np.reshape( ) fig = ot.plot.contourf(slices, diff_head) for ax, slice in zip(fig.axes, np.ravel(slices), strict=False): - ax.set_title(f"z = {slice.center[2]:.1f} {ms.spatial_output_unit}") + ax.set_title(f"z = {slice.center[2]:.1f} m", fontsize=32) + # %% # 6.2 Slices are taken along the y-axis. slices = np.reshape( @@ -107,4 +108,4 @@ slices = np.reshape( ) fig = ot.plot.contourf(slices, diff_head) for ax, slice in zip(fig.axes, np.ravel(slices), strict=False): - ax.set_title(f"y = {slice.center[1]:.1f} {ms.spatial_output_unit}") + ax.set_title(f"y = {slice.center[1]:.1f} m", fontsize=32) diff --git a/docs/examples/howto_plot/plot_animation.py b/docs/examples/howto_plot/plot_animation.py index 62c4e6d68be1185a384ec7fcaa7be917587f7fb2..8f8811f94ec559d0891d6a4248f9f671b795ee4a 100644 --- a/docs/examples/howto_plot/plot_animation.py +++ b/docs/examples/howto_plot/plot_animation.py @@ -34,14 +34,14 @@ mesh_series = examples.load_meshseries_CT_2D_XDMF() # %% timevalues = np.linspace( - mesh_series.timevalues()[0], mesh_series.timevalues()[-1], num=25 + mesh_series.timevalues[0], mesh_series.timevalues[-1], num=25 ) # %% [markdown] # Now, let's animate the saturation solution. A timescale at the top # indicates existing timesteps and the position of the current timevalue. # Note that rendering many frames in conjunction with large meshes might take -# a really long time. We can pass a `plot_func` which can apply custom +# a really long time. We can pass a ``plot_func`` which can apply custom # formatting and / or plotting. To modify the domain, we can use the transform # method of MeshSeries. diff --git a/docs/examples/howto_plot/plot_contourf_2d.py b/docs/examples/howto_plot/plot_contourf_2d.py index a990e866f819e531b1dd4beb04cbe9220cee7415..97e29f83867875047a1fc43a19138286b09b6889 100644 --- a/docs/examples/howto_plot/plot_contourf_2d.py +++ b/docs/examples/howto_plot/plot_contourf_2d.py @@ -17,7 +17,7 @@ import ogstools as ot from ogstools import examples ot.plot.setup.material_names = {i + 1: f"Layer {i+1}" for i in range(26)} -mesh = examples.load_meshseries_THM_2D_PVD().mesh(1) +mesh = examples.load_meshseries_THM_2D_PVD().scale(spatial=("m", "km")).mesh(1) # %% [markdown] # To read your own data as a mesh series you can do: @@ -33,7 +33,7 @@ fig = mesh.plot_contourf(ot.variables.material_id) # %% [markdown] # Now, let's plot the temperature field (point_data) at the first timestep. -# The default temperature variable from the `variables` reads the temperature +# The default temperature variable from the ``variables`` reads the temperature # data as Kelvin and converts them to degrees Celsius. # %% diff --git a/docs/examples/howto_plot/plot_observation_points.py b/docs/examples/howto_plot/plot_observation_points.py index 80931ef3cd0b5fc47b6aa9486356d5e2e99409b0..d5074ef16bc2ae589e8fb2cbef858dfb692473ef 100644 --- a/docs/examples/howto_plot/plot_observation_points.py +++ b/docs/examples/howto_plot/plot_observation_points.py @@ -17,7 +17,6 @@ Here we use a component transport example from the ogs benchmark gallery # sphinx_gallery_start_ignore # sphinx_gallery_thumbnail_number = 2 -# fmt:off # sphinx_gallery_end_ignore @@ -27,7 +26,7 @@ import numpy as np import ogstools as ot from ogstools import examples -mesh_series = examples.load_meshseries_CT_2D_XDMF() +mesh_series = examples.load_meshseries_CT_2D_XDMF().scale(time=("s", "a")) si = ot.variables.saturation # %% [markdown] @@ -46,28 +45,26 @@ si = ot.variables.saturation # Let's define 4 observation points and plot them on the mesh. # %% -points = np.asarray( - [[x, 0, 60] for x in [0, 40, 80, 120]] - + [[x, 0, 40] for x in [0, 40, 80, 120]] -) +points = np.asarray([[x, 0, z] for z in [60, 40] for x in [0, 40, 80, 120]]) fig = mesh_series.mesh(0).plot_contourf(si) fig.axes[0].scatter(points[:, 0], points[:, 2], s=50, fc="none", ec="r", lw=3) -for i, point in enumerate(points): - fig.axes[0].annotate(str(i), (point[0], point[2] - 5), va="top") +for i, pt in enumerate(points): + fig.axes[0].annotate(str(i), (pt[0], pt[2] - 5), va="top", fontsize=32) # %% [markdown] # And now probe the points and the values over time: # %%, -labels = [f"{i}: {label}" for i, label in enumerate(ot.plot.utils.justified_labels(points))] -fig = mesh_series.plot_probe( - points=points[:4], variable=si, time_unit="a", labels=labels[:4] -) +labels = [ + f"{i}: {label}" + for i, label in enumerate(ot.plot.utils.justified_labels(points)) +] +fig = mesh_series.plot_probe(points=points[:4], variable=si, labels=labels[:4]) # %% [markdown] # You can also pass create your own matplotlib figure and pass the axes object. # Additionally, you can pass any keyword arguments which are known by # matplotlibs plot function to further customize the curves. -# In this case `marker` and `linewidth` are not part of the API of `plot_probe` +# In this case ``marker`` and ``linewidth`` are not part of the API of `plot_probe` # but get processed correctly anyway. # If you want to have more freedom with the data you can just do the probing, # adapt to your needs and then do the plotting yourself: @@ -75,16 +72,15 @@ fig = mesh_series.plot_probe( # %% fig, axs = plt.subplots(nrows=2, figsize=[10, 5]) mesh_series.plot_probe( - points[:4], si, time_unit="a", ax=axs[0], colors=["k"], - labels=labels[:4], marker=".") + points[:4], si, ax=axs[0], colors=["k"], labels=labels[:4], marker="." +) mesh_series.plot_probe( - points[4:], si, time_unit="a", ax=axs[1], linestyles=["-"], - labels=labels[4:], linewidth=1, + points[4:], si, ax=axs[1], linestyles=["-"], labels=labels[4:], linewidth=1 ) # add the mean of the observation point timeseries values = si.transform(mesh_series.probe(points, data_name=si.data_name)) mean_values = np.mean(values.reshape((-1, 2, 4)), axis=-1) -ts = mesh_series.timevalues("a") +ts = mesh_series.timevalues for index in range(2): fig.axes[index].plot(ts, mean_values[:, index], "rk"[index], label="mean") fig.axes[index].legend() diff --git a/docs/examples/howto_plot/plot_shared_axes.py b/docs/examples/howto_plot/plot_shared_axes.py index 9832cf340a95f5f516a71ffb0b2748256e4430b4..aa7f326d52f2d519d698d130871fc23d25393529 100644 --- a/docs/examples/howto_plot/plot_shared_axes.py +++ b/docs/examples/howto_plot/plot_shared_axes.py @@ -15,7 +15,7 @@ import matplotlib.pyplot as plt import ogstools as ot from ogstools import examples -meshseries = examples.load_meshseries_THM_2D_PVD() +meshseries = examples.load_meshseries_THM_2D_PVD().scale(spatial=("m", "km")) mesh_0 = meshseries.mesh(0) mesh_1 = meshseries.mesh(1) variable = ot.variables.temperature diff --git a/docs/examples/howto_plot/plot_timeslice.py b/docs/examples/howto_plot/plot_timeslice.py index 33f3c211d4847e65389a106a7b4fea588ccab041..905b36433449a1f4028e3de05152aceda9523f34 100644 --- a/docs/examples/howto_plot/plot_timeslice.py +++ b/docs/examples/howto_plot/plot_timeslice.py @@ -14,43 +14,63 @@ To see this benchmark results over all timesteps have a look at """ # %% [markdown] -# Let's load the data and create 3 different lines to sample over: -# vertical, horizontal and diagonal. +# Let's load the data which we want to investigate. + +# %% +import matplotlib.pyplot as plt import numpy as np +import pyvista as pv import ogstools as ot from ogstools import examples -mesh_series = examples.load_meshseries_CT_2D_XDMF() +mesh_series = examples.load_meshseries_CT_2D_XDMF().scale(time=("s", "a")) +y = mesh_series[0].center[1] # flat y coordinate of this 2D mesh is not 0 si = ot.variables.saturation -points_vert = np.linspace([25, 0, -75], [25, 0, 75], num=100) -points_hori = np.linspace([0, 0, 60], [150, 0, 60], num=100) -points_diag = np.linspace([25, 0, 75], [100, 0, 0], num=100) + +# %% [markdown] +# Now we setup two sampling lines. + +# %% +pts_vert = np.linspace([25, y, -75], [25, y, 75], num=100) +pts_diag = np.linspace([25, y, 75], [100, y, 0], num=100) fig = mesh_series.mesh(-1).plot_contourf(si, vmin=0) -fig.axes[0].plot(points_vert[[0, -1], 0], points_vert[[0, -1], 2], "-k2") -fig.axes[0].plot(points_hori[[0, -1], 0], points_hori[[0, -1], 2], "--k2") -fig.axes[0].plot(points_diag[[0, -1], 0], points_diag[[0, -1], 2], "-.k2") +fig.axes[0].plot(pts_vert[:, 0], pts_vert[:, 2], "-k", linewidth=3) +fig.axes[0].plot(pts_diag[:, 0], pts_diag[:, 2], "-.k", linewidth=3) # %% [markdown] -# The function plot_time_slice automatically detects if the line lies on a -# cardinal direction and labels the y-axes with the changing spatial dimension. -fig = mesh_series.plot_time_slice(si, points_vert, time_unit="a") +# Here, we first show a regular line sample plot for the vertical sampling line +# for each timestep. +# %% +fig, ax = plt.subplots(figsize=[15, 8]) +for mesh, timevalue in zip(mesh_series, mesh_series.timevalues, strict=True): + sample = pv.PolyData(pts_vert).sample(mesh) + color = str(0.8 * timevalue / mesh_series.timevalues[-1]) + label = f"{timevalue:.1f} a" + fig = ot.plot.line( + sample, "z", si, ax=ax, label=label, color=color, fontsize=20 + ) # %% [markdown] -# By default the plot is smoothened with interpolation. When deactivated, we -# see the raw sampled data. Be sure to adjust the number of sampling points if -# the MeshSeries contains a lot of timesteps. -fig = mesh_series.plot_time_slice( - si, points_vert, time_unit="a", interpolate=False -) +# As the above kind of plot is getting cluttered for lots of timesteps we +# provide a function to create a filled contour plot over the transient data. +# The function :meth:`~ogstools.meshlib.mesh_series.MeshSeries.plot_time_slice` +# automatically detects if the line lies on a cardinal direction and labels the +# y-axes with the changing spatial dimension. + +# %% +fig = mesh_series.plot_time_slice(si, pts_vert) # %% [markdown] -# The horizontal sampling line gets also labeled appropriately. -fig = mesh_series.plot_time_slice(si, points_hori, time_unit="a") +# By default the plot is smoothened with interpolation. When deactivated, we +# see the edges of the raw sampled data. When using the interpolation, be sure +# to adjust the number of sampling points if the MeshSeries contains a lot of +# small timesteps. +fig = mesh_series.plot_time_slice(si, pts_vert, interpolate=False) # %% [markdown] # If the line doesn't point in a cardinal direction the distance along the # line is used for the y-axis by default. You can however, specify if you want # to use spatial dimension via the argument "y_axis". This may be useful when # plotting data of an edge / boundary of the mesh. -fig = mesh_series.plot_time_slice(si, points_diag, time_unit="a") +fig = mesh_series.plot_time_slice(si, pts_diag) diff --git a/docs/examples/howto_plot/plot_with_custom_fig_ax.py b/docs/examples/howto_plot/plot_with_custom_fig_ax.py index 15706060e1d78b908f6d4d43d509cbef96820b61..c705007ca15ef1a54c8a448f94122aa22a9640db 100644 --- a/docs/examples/howto_plot/plot_with_custom_fig_ax.py +++ b/docs/examples/howto_plot/plot_with_custom_fig_ax.py @@ -15,7 +15,7 @@ import matplotlib.pyplot as plt import ogstools as ot from ogstools import examples -meshseries = examples.load_meshseries_THM_2D_PVD() +meshseries = examples.load_meshseries_THM_2D_PVD().scale(spatial=("m", "km")) ot.plot.setup.combined_colorbar = False diff --git a/docs/examples/howto_postprocessing/plot_aggregate.py b/docs/examples/howto_postprocessing/plot_aggregate.py index ecc07b54610be2eaf8edc62c20f8e9264a258d92..21d527664de532feb92bb02c41f19927795fd76b 100644 --- a/docs/examples/howto_postprocessing/plot_aggregate.py +++ b/docs/examples/howto_postprocessing/plot_aggregate.py @@ -19,7 +19,7 @@ import numpy as np import ogstools as ot from ogstools import examples -mesh_series = examples.load_meshseries_CT_2D_XDMF() +mesh_series = examples.load_meshseries_CT_2D_XDMF().scale(time=("s", "a")) saturation = ot.variables.saturation # %% [markdown] @@ -54,7 +54,7 @@ fig = mesh.plot_contourf(saturation) # %% mesh = mesh_series.time_of_max(saturation) -fig = mesh.plot_contourf(ot.variables.Scalar("max_Saturation_time", "s", "a")) +fig = mesh.plot_contourf(ot.variables.Scalar("max_Saturation_time", "a", "a")) # %% [markdown] # Likewise we can calculate and visualize the variance of the saturation: @@ -76,4 +76,4 @@ fig = mesh.plot_contourf(saturation) # of e.g. the max or mean value of a variable in the entire domain. # %% -fig = mesh_series.plot_domain_aggregate(saturation, np.mean, time_unit="a") +fig = mesh_series.plot_domain_aggregate(saturation, np.mean) diff --git a/docs/examples/howto_postprocessing/plot_calculate_diff.py b/docs/examples/howto_postprocessing/plot_calculate_diff.py index 56f2019db1cbeb0822f5a7740d0e4711142b7f32..60ca9e6a6f2f6b4ab4743b2b0eacb0650e9f0caf 100644 --- a/docs/examples/howto_postprocessing/plot_calculate_diff.py +++ b/docs/examples/howto_postprocessing/plot_calculate_diff.py @@ -69,7 +69,7 @@ def custom_mesh(lengths: int, element_order: int, quads: bool) -> ot.Mesh: # the variable of interest, the difference will work fine. # %% -mesh_series = examples.load_meshseries_THM_2D_PVD() +mesh_series = examples.load_meshseries_THM_2D_PVD().scale(spatial=("m", "km")) mesh1 = mesh_series.mesh(0) mesh2 = mesh_series.mesh(-1) diff --git a/docs/examples/howto_postprocessing/plot_convergence_study_nuclear_decay.py b/docs/examples/howto_postprocessing/plot_convergence_study_nuclear_decay.py index c87a310dfed6ead5446ecbacb6642458682b9a0e..9c149e5ca886e91b25d910b3c32ba1fba9335881 100644 --- a/docs/examples/howto_postprocessing/plot_convergence_study_nuclear_decay.py +++ b/docs/examples/howto_postprocessing/plot_convergence_study_nuclear_decay.py @@ -88,7 +88,7 @@ fig, (ax1, ax2) = plt.subplots(figsize=(8, 8), nrows=2, sharex=True) ax2.plot(time, heat, lw=2, label="reference", color="k") for sim_result, dt in zip(sim_results, time_step_sizes, strict=False): - mesh_series = ot.MeshSeries(sim_result) + mesh_series = ot.MeshSeries(sim_result).scale(time=("s", "yrs")) results = {"heat_flux": [], "temperature": []} for ts in mesh_series.timesteps: mesh = mesh_series.mesh(ts) @@ -96,7 +96,7 @@ for sim_result, dt in zip(sim_results, time_step_sizes, strict=False): max_T = ot.variables.temperature.transform(results["temperature"]) # times 2 due to symmetry, area of repo, to kW results["heat_flux"] += [np.max(mesh.point_data["heat_flux"][:, 0])] - tv = np.asarray(mesh_series.timevalues("a")) + tv = np.asarray(mesh_series.timevalues) ax1.plot(tv, max_T, lw=1.5, label=f"{dt=}") edges = np.append(0, tv) mean_t = 0.5 * (edges[1:] + edges[:-1]) diff --git a/docs/examples/howto_postprocessing/plot_sample_mesh_line.py b/docs/examples/howto_postprocessing/plot_sample_mesh_line.py index 21a59f950f5da737b4bb017d53f99275dd22b9f3..ea4ff449ed1734c5e628abe1f6459511703014bb 100644 --- a/docs/examples/howto_postprocessing/plot_sample_mesh_line.py +++ b/docs/examples/howto_postprocessing/plot_sample_mesh_line.py @@ -1,209 +1,144 @@ """ -Extract a 1D profile from 2D and plot it -======================================== +***************************** +Plot data of a sampling lines +***************************** -.. sectionauthor:: Feliks Kiszkurno (Helmholtz Centre for Environmental Research GmbH - UFZ) +.. sectionauthor:: Florian Zill (Helmholtz Centre for Environmental Research GmbH - UFZ) + +This example provides clean coding recipes for plotting data of meshes over +sampling lines. We also present different ways to setup the sampling lines. +For plotting we us the function :py:func:`ogstools.plot.line`. """ # %% + +# sphinx_gallery_start_ignore + +# sphinx_gallery_thumbnail_number = -1 + +# sphinx_gallery_end_ignore + +from itertools import pairwise + import matplotlib.pyplot as plt import numpy as np +import pyvista as pv import ogstools as ot from ogstools import examples -# %% [markdown] -# Single fracture -# ------------------ -# Define a profile line by providing a list of points in x, y, z coordinates -# and load an example data set: - -# %% -mesh = examples.load_meshseries_HT_2D_XDMF().mesh(-1) - -profile_HT = np.array([[4, 2, 0], [4, 18, 0]]) - -# %% -mesh_sp, mesh_kp = ot.meshlib.sample_polyline( - mesh, ["pressure", "temperature"], profile_HT -) +ot.plot.setup.show_region_bounds = False +mesh = examples.load_mesh_mechanics_2D() # %% [markdown] -# It has returned a pandas DataFrame containing all information about the -# profile and a numpy array with the position of the "knot-points". -# Let's investigate the DataFrame first: +# Simple case: straight line +# ========================== +# We use the ``pyvista`` function ``sample_over_line`` and use two points to define +# the line and get a Mesh with the sampled data. Let's plot the Mesh and the +# line together. # %% -mesh_sp.head(10) +sample = mesh.sample_over_line([25, -460, 0], [100, -800, 0]) +fig = ot.plot.contourf(mesh, ot.variables.temperature) +fig = ot.plot.line( + mesh=sample, y_var="y", x_var="x", ax=fig.axes[0], linestyle="--" +) # %% [markdown] -# We can see the spatial coordinates of points on the profile ("x", "y", "z" -# - columns), distances from the beginning of the profile ("dist") and within -# current segment ("dist_in_segment"). Note, that since we defined our profile -# on only two points, there is only one segment, hence in this special case -# columns dist and dist_in_segment are identical. At the end of the DataFrame -# we can can find two columns with the variables that we are interested in: -# "temperature" and "pressure". Each occupies one column, as those are scalar -# values. Using columns "dist", "pressure" and "temperature" we can easily -# plot the data: +# Now we plot the temperature data. The spatial coordinate for the x-axis is +# automatically detected here by not passing ``x_var`` explicitly. # %% -fig, ax = plt.subplots(1, 1, figsize=(7, 5)) -ax = mesh.plot_linesample( - x="dist", - variable="pressure", - profile_points=profile_HT, - ax=ax, - fontsize=15, -) -ax_twinx = ax.twinx() -ax_twinx = mesh.plot_linesample( - x="dist", - variable="temperature", - profile_points=profile_HT, - ax=ax_twinx, - fontsize=15, -) -ot.plot.utils.color_twin_axes( - [ax, ax_twinx], - [ot.variables.pressure.color, ot.variables.temperature.color], -) +fig = ot.plot.line(sample, ot.variables.temperature) fig.tight_layout() # %% [markdown] -# What happens when we are interested in a vector variable? -# We can see it in the following example using the Darcy velocity: +# Simple case: circular arc +# ========================= +# With 3 points we can define an arc over which to sample the data. +# Having the arc directly on a boundary might result in some gaps in the +# sampled data, thus we extend the arc by a small margin. # %% -mesh_sp, mesh_kp = ot.meshlib.sample_polyline( - mesh, "darcy_velocity", profile_HT +sample = mesh.sample_over_circular_arc( + pointa=[100 - 1e-4, -650, 0], + pointb=[200 + 1e-4, -650, 0], + center=[150, -650, 0], ) - -# %% -mesh_sp.head(5) +fig, axs = plt.subplots(ncols=2, figsize=[26, 10]) +ot.plot.contourf(mesh, ot.variables.displacement["x"], fig=fig, ax=axs[1]) +ot.plot.line(sample, "y", "x", axs[1], linewidth="8", color="red") +ot.plot.line(sample, ot.variables.displacement["x"], ax=axs[0]) +fig.tight_layout() # %% [markdown] -# Now we have two columns for the variable. The Darcy velocity is a vector, -# therefore "sample_over_polyline" has split it into two columns and appended -# the variable name with increasing integer. Note, that this suffix has no -# physical meaning and only indicates order. It is up to user to interpret it -# in a meaningful way. By the -# `OpenGeoSys conventions <https://www.opengeosys.org/docs/userguide/basics/conventions/#a-namesymmetric-tensorsa--symmetric-tensors-and-kelvin-mapping>`_, -# "darcy_velocity_0" will be in the x-direction and "darcy_velocity_1" in -# y-direction. - +# Other methods to setup the sampling line +# ======================================== +# The following section shows different methods of creating sampling lines. # %% [markdown] -# Elder benchmark -# ------------------ -# In this example we will use a Variable object from the ogstools to -# sample the data. This allows "sample_over_polyline" to automatically -# convert from the "data_unit" to the "output_unit": +# Linear spaced points +# -------------------- +# This basically does the same as the ``pyvista`` function `sample_over_line`. # %% -profile_CT = np.array([[47.0, 1.17, 72.0], [-4.5, 1.17, -59.0]]) -mesh = examples.load_meshseries_CT_2D_XDMF().mesh(11) - -# %% -mesh_sp, mesh_kp = ot.meshlib.sample_polyline( - mesh, ot.variables.saturation, profile_CT -) +pts = np.linspace([50, -460, 0], [50, -800, 0], 100) +sample_1 = pv.PolyData(pts).sample(mesh) # %% [markdown] -# As before we can see the profile parameters and propertiy values in a -# DataFrame: +# Mutilsegmented line from list of points +# --------------------------------------- +# The following code allows you to have a line connecting multiple observation +# points. # %% -mesh_sp.head(5) - +obs_pts = np.asarray([[150, -460, 0], [50, -650, 0], [150, -800, 0]]) +pts = np.vstack([np.linspace(pt1, pt2, 50) for pt1, pt2 in pairwise(obs_pts)]) +sample_2 = pv.PolyData(pts).sample(mesh) # %% [markdown] -# This time we will prepare more complicated plot showing both the mesh and -# the profile. +# Spline from list of points +# -------------------------- +# You can also create smooth sampling lines by using a fitting function. +# The following creates a second order polynomial fit for the x-coordinates +# in dependence of the y-coordinates. # %% -fig, ax = mesh.plot_linesample_contourf( - ot.variables.saturation, profile_CT, resolution=100 -) +pts = np.asarray([[200, -460, 0], [250, -650, 0], [200, -800, 0]]) +fit = np.poly1d(np.polyfit(pts[:, 1], pts[:, 0], 2)) +y = np.linspace(-460, -800, 100) +pts = np.transpose([fit(y), y, y * 0]) +sample_3 = pv.PolyData(pts).sample(mesh) # %% [markdown] -# THM -# ------ -# It is also possible to obtain more than one variable at the same time using -# more complex profiles. They can be constructed by providing more than -# 2 points. With those points: +# Use existing geometry +# --------------------- +# Another way to setup the sampling line is to extract points from the domain +# mesh. Here, we use the ``clip`` function from ``pyvista`` and some boolean logic, +# to extract a vertical line through the center, which follows the boundary of +# the hole. We need to sort the points however, to have them adjacent. # %% -profile_THM = np.array( - [ - [-1000.0, -175.0, 6700.0], # Point A - [-600.0, -600.0, 6700.0], # Point B - [100.0, -300.0, 6700.0], # Point C - [3500, -900.0, 6700.0], # Point D - ] -) -# %% [markdown] -# the profile will run as follows: -# -# .. math:: -# -# \text{AB} \rightarrow \text{BC} \rightarrow \text{CD} -# -# Point B will at the same time be the last point in the first segment AB -# and first one in second segment BC, however in the returned array, -# it will occur only once. -# For this example we will use a different dataset: - -# %% -mesh = examples.load_meshseries_THM_2D_PVD().mesh(-1) - -# %% -ms_THM_sp, dist_at_knot = ot.meshlib.sample_polyline( - mesh, - [ot.variables.pressure, ot.variables.temperature], - profile_THM, - resolution=100, -) +edges = mesh.clip("x").extract_feature_edges() +is_top_bot = np.isin(edges.points[:, 1], [-800, -460]) +is_left = edges.points[:, 0] == 0 +pts = edges.points[np.invert(is_top_bot | is_left)] +sample_4 = pv.PolyData(pts[np.argsort(pts[:, 1])]).sample(mesh) # %% [markdown] -# Again, we can investigate the returned DataFrame, but this time we will -# have a look at its beginning: +# Now we plot all samples for comparison. # %% -ms_THM_sp.head(5) - -# %% -# and end: - -# %% -ms_THM_sp.tail(10) - -# %% [markdown] -# Note, that unlike in the first example, here the columns "dist" and -# "dist_in_segment" are not identical, as this time profile consists of -# multiple segments. The following figure illustrates the difference: -plt.rcdefaults() -ax: plt.Axes -fig, ax = plt.subplots(1, 1, figsize=(7, 3)) -ax.plot(ms_THM_sp["dist"], label="dist") -ax.plot(ms_THM_sp["dist_in_segment"], label="dist_in_segment") -ax.set_xlabel("Point ID / -") -ax.set_ylabel("Distance / m") -ax.legend() +fig, axs = plt.subplots(ncols=2, figsize=[26, 10]) +u_x = ot.variables.displacement["x"] +ot.plot.contourf(mesh, u_x, fig=fig, ax=axs[1]) +for i, sample in enumerate([sample_1, sample_2, sample_3, sample_4]): + c = f"C{i}" # cycle through default color cycle + ot.plot.line(sample, "y", "x", ax=axs[1], linestyle="--", color=c) + ot.plot.line(sample, "y", u_x, ax=axs[0], label=f"sample {i + 1}", color=c) fig.tight_layout() # %% [markdown] -# The orange line returns to 0 twice. It is because of how the overlap of nodal -# points between segments is handled. A nodal point always belongs to the -# segment it starts: point B is included in segment BC but not AB and point -# C in CD but not in in BC. The following figure shows the profile on the mesh: - -# %% -# plt.rcdefaults() -fig, ax = mesh.plot_linesample_contourf( - [ot.variables.pressure, ot.variables.temperature], - profile_THM, - resolution=100, -) -# %% +# If you want to sample data over multiple timesteps in a MeshSeries, have a +# look at :ref:`sphx_glr_auto_examples_howto_plot_plot_timeslice.py`. diff --git a/docs/examples/howto_preprocessing/plot_remeshing.py b/docs/examples/howto_preprocessing/plot_remeshing.py index f86f2fcc97f340ae7463dccb333806535ce065e2..d02a6d104b62768be40938f15aa0a93f5c799530 100644 --- a/docs/examples/howto_preprocessing/plot_remeshing.py +++ b/docs/examples/howto_preprocessing/plot_remeshing.py @@ -4,7 +4,7 @@ Remeshing with triangle elements .. sectionauthor:: Florian Zill (Helmholtz Centre for Environmental Research GmbH - UFZ) -This short example showcases the function `remesh_with_tri` which allows us to +This short example showcases the function ``remesh_with_tri`` which allows us to take an existing mesh and re-discretize it with triangle elements. This is useful for models, where the underlying meshing script is not available or hard to adapt. diff --git a/docs/examples/howto_prjfile/plot_manipulation.py b/docs/examples/howto_prjfile/plot_manipulation.py index 618ab9cfa6c3945be12f12bdf79494f4e5907bce..5894312e3c0c17ccefb757d358a54331dccc300a 100644 --- a/docs/examples/howto_prjfile/plot_manipulation.py +++ b/docs/examples/howto_prjfile/plot_manipulation.py @@ -28,7 +28,7 @@ for youngs_modulus in youngs_moduli: prj.run_model(args=f"-m {EXAMPLES_DIR}/prj/ -o {model_dir}") # %% -# Instead of the `replace_parameter` method, the more general `replace_text` method +# Instead of the ``replace_parameter`` method, the more general ``replace_text`` method # can also be used to replace the young modulus in this example: prj.replace_text(youngs_modulus, xpath="./parameters/parameter[name='E']/value") diff --git a/docs/examples/howto_quickstart/plot_meshseries.py b/docs/examples/howto_quickstart/plot_meshseries.py index 2ac27722f53c0862fe90d3054cb7e9cbf567eae6..0693715ece0f3170a6afd4473cdcd4e43d62fd0a 100644 --- a/docs/examples/howto_quickstart/plot_meshseries.py +++ b/docs/examples/howto_quickstart/plot_meshseries.py @@ -31,11 +31,16 @@ ms # %% [markdown] # Accessing time values # ===================== -# Time values can be unit transformed. By default they are output in seconds. +# Time values (and spatial coordinates) can be unit transformed via +# :meth:`~ogstools.meshlib.mesh_series.MeshSeries.scale`. Either pass a tuple +# to convert from the first to the second unit or pass a scaling factor. # %% -print(f"First 3 time values are: {ms.timevalues()[:3]} s.") -print(f"Last time value is: {ms.timevalues(time_unit='h')[-1]} h.") +print(f"First 3 time values are: {ms.timevalues[:3]} s.") +ms = ms.scale(time=("s", "h")) +print(f"Last time value is: {ms.timevalues[-1]} h.") +ms = ms.scale(time=3600.0) +print(f"Last time value is: {ms.timevalues[-1]} s.") # %% [markdown] # Accessing meshes @@ -62,7 +67,7 @@ mesh_ts10.plot(show_edges=True) # number of points/cells, and the last dimension is the number of components of # the variable. # -# By default, `values` would read the entire dataset. If only a subset of the +# By default, ``values`` would read the entire dataset. If only a subset of the # MeshSeries should be read you can select the relevant timesteps by indexing / # slicing the MeshSeries directly. This selection will also be adhered to if you # read individual meshes. @@ -77,7 +82,7 @@ print("Every second timestep:", np.shape(ms[::2].values("temperature"))) print("Last two steps:", np.shape(ms[-2:].values("darcy_velocity"))) # %% [markdown] -# To select points or cells you can use the `extract` method to specify the +# To select points or cells you can use the ``extract`` method to specify the # corresponding ids. # %% @@ -86,7 +91,7 @@ print("Data on extracted points:", np.shape(temp_at_points)) print("Temperatures at last timestep:", temp_at_points[-1]) # %% [markdown] -# You can also use pyvista dataset filters to `transform` the domain for the +# You can also use pyvista dataset filters to ``transform`` the domain for the # entire MeshSeries. # %% diff --git a/docs/examples/howto_simulation/plot_100_logparser_intro.py b/docs/examples/howto_simulation/plot_100_logparser_intro.py index dd9f44bd0529b4eee2a0ed37ba9d0f5aa69a2f72..27a59ad975bbd3658b28ab4d93ba7408ede12788 100644 --- a/docs/examples/howto_simulation/plot_100_logparser_intro.py +++ b/docs/examples/howto_simulation/plot_100_logparser_intro.py @@ -39,7 +39,7 @@ df_ts_it # %% [markdown] # The log file # ------------- -# Running `ogs` in the command line outputs the logs into the console output. With +# Running ``ogs`` in the command line outputs the logs into the console output. With # `tee # <https://en.wikipedia.org/wiki/Tee_(command)>`_ in Linux and Mac # and @@ -64,7 +64,7 @@ df_ts_it # :py:mod:`ogstools.logparser.fill_ogs_context` are explained in # :ref:`sphx_glr_auto_examples_howto_simulation_plot_102_logparser_advanced.py`. # All predefined analyses need the result of fill_ogs_context. -# Here `const_viscosity_thermal_convection_log` is string representing the +# Here ``const_viscosity_thermal_convection_log`` is string representing the # location of the ogs log file. print(log_const_viscosity_thermal_convection) # %% @@ -81,7 +81,7 @@ df_log = fill_ogs_context(df_records) # # Here we are interested in every time step of the simulation and how many # iterations have been needed. -# The predefined analyses only work with logs from `ogs` run with level `info` or finer (`debug`), like `ogs -l info` or `ogs - l debug`. +# The predefined analyses only work with logs from ``ogs`` run with level ``info`` or finer (`debug`), like `ogs -l info` or `ogs - l debug`. # (see # `OGS Developer Guide - log and debug output # <https://www.opengeosys.org/docs/devguide/advanced/log-and-debug-output>`_ @@ -96,6 +96,6 @@ df_ts_it # Pandas to plot # -------------- # You can directly use -# `plot` ` +# ``plot`` # <https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.plot.html>`_ from pandas. df_ts_it.plot(grid=True) diff --git a/docs/examples/howto_simulation/plot_102_logparser_advanced.py b/docs/examples/howto_simulation/plot_102_logparser_advanced.py index d3f9fd51debb6e858d8b17c9110485590cb17112..d503241fd790f9758a6b7d0d292c92892af4f208 100644 --- a/docs/examples/howto_simulation/plot_102_logparser_advanced.py +++ b/docs/examples/howto_simulation/plot_102_logparser_advanced.py @@ -72,7 +72,7 @@ df_ts[["output_time", "assembly_time"]].boxplot() # For this example we are only interested in the number of iterations per time step. # Because the parsing process is expensive, it is useful to store the records to a file. # According to :py:mod:`ogstools.logparser.parse_file` -# via parameter `regexes` a list of reduced or custom regexes can be provided. +# via parameter ``regexes`` a list of reduced or custom regexes can be provided. # # Save and load records # ~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/releases/ogstools-0.x.md b/docs/releases/ogstools-0.x.md index e02780d64e578c4362002518dad52aa04012b479..2c300788dc9982e8e411ecf2feeda7618af536e4 100644 --- a/docs/releases/ogstools-0.x.md +++ b/docs/releases/ogstools-0.x.md @@ -11,6 +11,12 @@ This is not released yet! - CLI tool msh2vtu is not affected by this - parameter keep_ids was removed (in our OGS world there is no reason to keep the gmsh data names and the wrong data types in the meshes, which would happen if k was used) - parameter log_level was changed to log (True or False) +- removed: + - MeshSeries.spatial_data_unit/spatial_output_unit/time_unit (see + MeshSeries.scale()) + - plot.linesample/linesample_contourf + - meshlib.data_processing.interp_points/distance_in_profile/sample_polyline + (see update line sample example) ## Features @@ -19,6 +25,11 @@ This is not released yet! - MeshSeries get extract() method to select points or cells via ids - MeshSeries can be sliced to get new MeshSeries with the selected subset of timesteps - difference() between two meshes is now possible even with different topologies +- MeshSeries gets scale() method to scale spatially or temporally +- variables.get_preset will now return a Variable corresponding to the spatial + coordinates if given "x", "y" or "z" +- plot module gets line() function as a general purpose 1D plotting function +- plot.setup get spatial_unit and time_unit which are used for labeling ## Infrastructure diff --git a/ogstools/definitions.py b/ogstools/definitions.py index f9a8eb327b7185696b841fffb454d47f4fc28353..79f07278272c82fed9da880bec7cfb4696b8f614 100644 --- a/ogstools/definitions.py +++ b/ogstools/definitions.py @@ -8,5 +8,3 @@ import pathlib ROOT_DIR = pathlib.Path(__file__).parent.resolve() EXAMPLES_DIR = ROOT_DIR / "examples" - -SPATIAL_UNITS_KEY = "data_length_unit" diff --git a/ogstools/examples/__init__.py b/ogstools/examples/__init__.py index a9f8aed71066af172157cfa326675bfd10fe99e2..ec7402e328b93dbb337d068a6d35af9947723394 100644 --- a/ogstools/examples/__init__.py +++ b/ogstools/examples/__init__.py @@ -21,26 +21,22 @@ _shp_dir = EXAMPLES_DIR / "shapefiles" def load_meshseries_THM_2D_PVD(): - return MeshSeries( - str(_meshseries_dir / "2D.pvd"), time_unit="s", spatial_output_unit="km" - ) + return MeshSeries(str(_meshseries_dir / "2D.pvd")) def load_meshseries_CT_2D_XDMF(): - return MeshSeries(str(_meshseries_dir / "elder.xdmf"), time_unit="s") + return MeshSeries(str(_meshseries_dir / "elder.xdmf")) def load_meshseries_HT_2D_XDMF(): return MeshSeries( - str(_meshseries_dir / "2D_single_fracture_HT_2D_single_fracture.xdmf"), - time_unit="s", + str(_meshseries_dir / "2D_single_fracture_HT_2D_single_fracture.xdmf") ) def load_meshseries_HT_2D_PVD(): return MeshSeries( - str(_meshseries_dir / "2D_single_fracture_HT_2D_single_fracture.pvd"), - time_unit="s", + str(_meshseries_dir / "2D_single_fracture_HT_2D_single_fracture.pvd") ) @@ -50,22 +46,18 @@ def load_meshseries_HT_2D_VTU(): _meshseries_dir / "2D_single_fracture_HT_2D_single_fracture" / "2D_single_fracture_HT_2D_single_fracture_0_96.vtu" - ), - time_unit="s", + ) ) def load_meshseries_HT_2D_paraview_XMF(): return MeshSeries( - str(_meshseries_dir / "2D_single_fracture_HT_2D_single_fracture.xmf"), - time_unit="s", + str(_meshseries_dir / "2D_single_fracture_HT_2D_single_fracture.xmf") ) def load_mesh_mechanics_2D(): - return Mesh.read( - _meshseries_dir / "mechanics_example.vtu", spatial_output_unit="km" - ) + return Mesh.read(_meshseries_dir / "mechanics_example.vtu") msh_geolayers_2d = _msh_dir / "geolayers_2d.msh" diff --git a/ogstools/meshlib/__init__.py b/ogstools/meshlib/__init__.py index 4bf7bff4412bb4b2e4b6e3854b5b72b12d731ff4..d711d19f8d0ea10aa3d099c7986a60e28d28ef86 100644 --- a/ogstools/meshlib/__init__.py +++ b/ogstools/meshlib/__init__.py @@ -7,15 +7,7 @@ from .boundary import Boundary, Layer, LocationFrame, Raster from .boundary_set import LayerSet from .boundary_subset import Gaussian2D, Surface -from .data_processing import ( - difference, - difference_matrix, - difference_pairwise, - distance_in_profile, - distance_in_segments, - interp_points, - sample_polyline, -) +from .data_processing import difference, difference_matrix, difference_pairwise from .gmsh_converter import meshes_from_gmsh from .gmsh_meshing import cuboid, rect from .ip_mesh import to_ip_mesh, to_ip_point_cloud @@ -43,13 +35,9 @@ __all__ = [ "difference", "difference_matrix", "difference_pairwise", - "distance_in_profile", - "distance_in_segments", - "interp_points", "meshes_from_gmsh", "read_shape", "rect", - "sample_polyline", "to_ip_mesh", "to_ip_point_cloud", "to_region_prism", diff --git a/ogstools/meshlib/data_processing.py b/ogstools/meshlib/data_processing.py index 46c9a7d94133c229f1566a6e0036cbc50e009290..f8cd96b1045869fcaaf3ac265cc90baf06ae59c7 100644 --- a/ogstools/meshlib/data_processing.py +++ b/ogstools/meshlib/data_processing.py @@ -4,15 +4,14 @@ # http://www.opengeosys.org/project/license # -from itertools import pairwise, product +from itertools import product from typing import TypeVar import numpy as np -import pandas as pd import pyvista as pv from typeguard import typechecked -from ogstools.variables import Variable, get_preset +from ogstools.variables import Variable Mesh = TypeVar("Mesh", bound=pv.UnstructuredGrid) @@ -139,155 +138,3 @@ def difference_matrix( difference(m1, m2, variable) for m1, m2 in product(meshes_1, meshes_2) ] return np.asarray(diff_meshes).reshape((len(meshes_1), len(meshes_2))) - - -@typechecked -def interp_points(points: np.ndarray, resolution: int = 100) -> np.ndarray: - """ - Provides lists of points on every segment at a line profile between \ - arbitrary number of points pairs. - - :param points: Numpy array of N points to sample between.\ - Has to be of shape (N, 3). - - :param resolution: Resolution of the sampled profile. Total number of \ - points within all profile segments. - - :returns: Numpy array of shape (N, 3), without duplicated nodal points. - """ - profile = np.zeros([0, 3]) - distances = np.linalg.norm(np.diff(points, axis=0), axis=1) - - npoints_per_segment = np.ceil( - distances / np.sum(distances) * resolution - ).astype(int) - - for pt1, pt2, n_points in zip( - points[:-1], points[1:], npoints_per_segment, strict=True - ): - new_seg_points = np.linspace(pt1, pt2, n_points, False) - profile = np.vstack([profile, new_seg_points]) - - return np.vstack([profile, points[-1]]) - - -@typechecked -def distance_in_segments( - profile_nodes: np.ndarray, profile: np.ndarray -) -> np.ndarray: - """ - Calculate the distance within segments of a polyline profile. - - :param profile_nodes: 2D array of N points (profile nodes) of shape (N, 3) - :param profile: output from interp_points function. 2D array of N points \ - (profile nodes) of shape (N, 3) - - :returns: 1D array of distances in each segment to its starting point \ - of shape (N, 3), where N is the number of points in profile - """ - point_index = [ - np.argmin(np.sum(np.abs(profile - pt), axis=1)) for pt in profile_nodes - ] - if not (point_index[0] == 0 and point_index[-1] == profile.shape[0] - 1): - err_msg = "Something went wrong with generating profile_points!" - raise ValueError(err_msg) - dist_in_segment = np.zeros([profile.shape[0]]) - for pt_id1, pt_id2 in pairwise(point_index): - dist_current_segment = profile[pt_id1:pt_id2] - profile[pt_id1] - dist_current_segment = np.linalg.norm(dist_current_segment, axis=1) - dist_in_segment[pt_id1:pt_id2] = dist_current_segment - - # Handle last point - dist_in_segment[-1] = np.linalg.norm(profile[-1] - profile_nodes[-2]) - - return dist_in_segment - - -@typechecked -def distance_in_profile(points: np.ndarray) -> np.ndarray: - """ - :param points: 2D array of N points (profile nodes) of shape (N, 3) - - :returns: 1D array of distances of each point to the beginning of the \ - profile (first row in points), shape of (N,) - """ - return np.concatenate( - ([0], np.cumsum(np.linalg.norm(np.diff(points, axis=0), axis=1))) - ) - - -def sample_polyline( - mesh: pv.UnstructuredGrid, - variables: str | Variable | list[str] | list[Variable], - profile_nodes: np.ndarray, - resolution: int | None = 100, -) -> tuple[pd.DataFrame, np.ndarray]: - """ - Sample one or more variables along a polyline. - Profiles created by user can be passed as profile_nodes parameter. In this - case user should also set resolution to None in order to avoid further - interpolation between the points. - - :param mesh: Mesh from which variables will be sampled. - :param variables: Name or list of names of variables to sample. - :param profile_nodes: 2D array of N points (profile nodes) of shape (N, 3) - :param resolution: Total number of sampling points. - - :returns: tuple containing DataFrame with results of the profile sampling - and Numpy array of distances from the beginning of the profile - at points defined in profile_points. - """ - _variables = [variables] if not isinstance(variables, list) else variables - variables = [get_preset(var, mesh) for var in _variables] - - if resolution is None: - # Only cumulative distance alongside the profile will be returned - profile_points = profile_nodes - assert isinstance(resolution, int) - profile_points = interp_points(profile_nodes, resolution=resolution) - sampled_data_dist_in_segment = distance_in_segments( - profile_nodes, profile_points - ) - dist_at_nodes = distance_in_profile(profile_nodes) - - sampled_data_distance = distance_in_profile(profile_points) - - line = pv.PolyData(profile_points) - sampled_data = line.sample(mesh) - - output_data = {["x", "y", "z"][i]: profile_points[:, i] for i in [0, 1, 2]} - - # TODO: data should be written in output_name otherwise different - # variables with the same data_name will override each other - for variable_current in variables: - # TODO: workaround for Issue 59 - if variable_current.data_name in sampled_data.point_data: - variable_name = variable_current.data_name - elif variable_current.output_name in sampled_data.point_data: - variable_name = variable_current.output_name - else: - err_msg = "Cannot match variable name to variables available\ - in mesh!" - raise KeyError(err_msg) - sampled_data_variable = sampled_data[variable_name] - if isinstance(variable_current, Variable): - sampled_data_variable = variable_current.transform( - data=sampled_data_variable - ) - if variable_name not in output_data: - if len(sampled_data_variable.shape) > 1: - # Vector variables - for variable_id in range(sampled_data_variable.shape[1]): - variable_key = f"{variable_name}_{variable_id}" - output_data[variable_key] = sampled_data_variable[ - :, variable_id - ] - else: - # Scalar variables - output_data[variable_name] = sampled_data_variable - - output_data["dist"] = sampled_data_distance - if isinstance(resolution, int): - output_data["dist_in_segment"] = sampled_data_dist_in_segment - - return pd.DataFrame.from_dict(output_data), dist_at_nodes diff --git a/ogstools/meshlib/ip_mesh.py b/ogstools/meshlib/ip_mesh.py index 2b4111d2bb256a88fd08ff9cc27ec2b89652650b..12a9ca64f8536a4022fa2cb121705e595637de93 100644 --- a/ogstools/meshlib/ip_mesh.py +++ b/ogstools/meshlib/ip_mesh.py @@ -6,8 +6,6 @@ import numpy as np import ogs import pyvista as pv -from ogstools.definitions import SPATIAL_UNITS_KEY - Mesh = TypeVar("Mesh", bound=pv.UnstructuredGrid) @@ -192,7 +190,6 @@ def to_ip_point_cloud(mesh: Mesh) -> pv.UnstructuredGrid: "Convert integration point data to a pyvista point cloud." # ipDataToPointCloud can't handle this bad_keys = [ - SPATIAL_UNITS_KEY, "material_state_variable_ElasticStrain_ip", "free_energy_density_ip", ] @@ -223,7 +220,6 @@ def to_ip_mesh(mesh: Mesh) -> pv.UnstructuredGrid: _mesh = mesh.extract_cells_by_type(cell_type) new_meshes += [tessellate(_mesh, cell_type, integration_order)] new_mesh = new_meshes[0] - new_mesh.field_data[SPATIAL_UNITS_KEY] = mesh.field_data[SPATIAL_UNITS_KEY] for _mesh in new_meshes[1:]: new_mesh = new_mesh.merge(_mesh) new_mesh = new_mesh.clean() diff --git a/ogstools/meshlib/mesh.py b/ogstools/meshlib/mesh.py index f42e9d000654b3e8709b3f1b8a53126546378eae..8a617c65407aaebc958c68dd6e918cc3a2413155 100644 --- a/ogstools/meshlib/mesh.py +++ b/ogstools/meshlib/mesh.py @@ -16,8 +16,6 @@ import pyvista as pv import ogstools.meshlib as ml from ogstools import plot from ogstools._internal import copy_method_signature -from ogstools.definitions import SPATIAL_UNITS_KEY -from ogstools.plot import lineplots from . import data_processing, geo, ip_mesh, shape_meshing @@ -54,14 +52,6 @@ class Mesh(pv.UnstructuredGrid): def plot_streamlines(self, *args: Any, **kwargs: Any) -> Any: return plot.streamlines(self, *args, **kwargs) - @copy_method_signature(lineplots.linesample) - def plot_linesample(self, *args: Any, **kwargs: Any) -> Any: - return lineplots.linesample(self, *args, **kwargs) - - @copy_method_signature(lineplots.linesample_contourf) - def plot_linesample_contourf(self, *args: Any, **kwargs: Any) -> Any: - return lineplots.linesample_contourf(self, *args, **kwargs) - def to_ip_mesh(self) -> Mesh: return Mesh(ip_mesh.to_ip_mesh(self)) @@ -76,16 +66,12 @@ class Mesh(pv.UnstructuredGrid): def __init__( self, pv_mesh: pv.UnstructuredGrid | None = None, - spatial_unit: str = "m", - spatial_output_unit: str = "m", **kwargs: dict, ): """ Initialize a Mesh object :param pv_mesh: Underlying pyvista mesh. - :param data_length_unit: Length unit of the mesh data. - :param output_length_unit: Length unit in plots. """ if not pv_mesh: # for copy constructor @@ -93,23 +79,13 @@ class Mesh(pv.UnstructuredGrid): super().__init__(**kwargs) else: super().__init__(pv_mesh, **kwargs) - self.field_data[SPATIAL_UNITS_KEY] = np.asarray( - [ord(char) for char in f"{spatial_unit},{spatial_output_unit}"] - ) @classmethod - def read( - cls, - filepath: str | Path, - spatial_unit: str = "m", - spatial_output_unit: str = "m", - ) -> Mesh: + def read(cls, filepath: str | Path) -> Mesh: """ Initialize a Mesh object :param filepath: Path to the mesh or shapefile file. - :param data_length_unit: Spatial data unit of the mesh. - :param output_length_unit: Spatial output unit of the mesh. :returns: A Mesh object """ @@ -118,9 +94,6 @@ class Mesh(pv.UnstructuredGrid): else: mesh = cls(pv.read(filepath)) - mesh.field_data[SPATIAL_UNITS_KEY] = np.asarray( - [ord(char) for char in f"{spatial_unit},{spatial_output_unit}"] - ) return mesh @classmethod diff --git a/ogstools/meshlib/mesh_series.py b/ogstools/meshlib/mesh_series.py index 068cf6a1e73e1dfe08e54be9dcc1b5500a1d2929..9c2c78721b63ee3e97d19c868dc1addcbea533a3 100644 --- a/ogstools/meshlib/mesh_series.py +++ b/ogstools/meshlib/mesh_series.py @@ -39,29 +39,19 @@ class MeshSeries: A wrapper around pyvista and meshio for reading of pvd and xdmf timeseries. """ - def __init__( - self, - filepath: str | Path, - time_unit: str = "s", - spatial_unit: str = "m", - spatial_output_unit: str = "m", - ) -> None: + def __init__(self, filepath: str | Path) -> None: """ Initialize a MeshSeries object :param filepath: Path to the PVD or XDMF file. - :param time_unit: Data unit of the timevalues. - :param data_length_unit: Length unit of the mesh data. - :param output_length_unit: Length unit in plots. :returns: A MeshSeries object """ if isinstance(filepath, Path): filepath = str(filepath) self.filepath = filepath - self.time_unit = time_unit - self.spatial_unit = spatial_unit - self.spatial_output_unit = spatial_output_unit + self._spatial_factor = 1.0 + self._time_factor = 1.0 self._mesh_cache: dict[float, Mesh] = {} self._mesh_func_opt: Callable[[Mesh], Mesh] | None = None self._data_type = filepath.split(".")[-1] @@ -171,9 +161,8 @@ class MeshSeries: return ( f"MeshSeries:\n" f"filepath: {self.filepath}\n" - f"spatial_unit: {self.spatial_unit}\n" f"data_type: {self._data_type}\n" - f"timevalues: {self._timevalues[0]}{self.time_unit} to {self._timevalues[-1]}{self.time_unit} in {len(self._timevalues)} steps\n" + f"timevalues: {self.timevalues[0]} to {self.timevalues[-1]} in {len(self.timevalues)} steps\n" f"reader: {reader}\n" f"rawdata_file: {self.rawdata_file()}\n" ) @@ -205,19 +194,16 @@ class MeshSeries: def closest_timestep(self, timevalue: float) -> int: """Return the corresponding timestep from a timevalue.""" - return int(np.argmin(np.abs(self._timevalues - timevalue))) + return int(np.argmin(np.abs(self.timevalues - timevalue))) def closest_timevalue(self, timevalue: float) -> float: """Return the closest timevalue to a timevalue.""" - return self._timevalues[self.closest_timestep(timevalue)] + return self.timevalues[self.closest_timestep(timevalue)] def ip_tesselated(self) -> MeshSeries: "Create a new MeshSeries from integration point tessellation." ip_ms = MeshSeries( - Path(self.filepath).parent / "ip_meshseries.synthetic", - self.time_unit, - self.spatial_unit, - self.spatial_output_unit, + Path(self.filepath).parent / "ip_meshseries.synthetic" ) ip_mesh = self.mesh(0).to_ip_mesh() ip_pt_cloud = self.mesh(0).to_ip_point_cloud() @@ -229,18 +215,20 @@ class MeshSeries: } ip_mesh.cell_data.update(ip_data) ip_ms._mesh_cache[ - self.timevalues()[ts] + self.timevalues[ts] ] = ip_mesh.copy() # pylint: disable=protected-access ip_ms._timevalues = self._timevalues # pylint: disable=protected-access return ip_ms def mesh(self, timestep: int, lazy_eval: bool = True) -> Mesh: """Returns the mesh at the given timestep.""" - timevalue = self.timevalues()[timestep] - if not np.any(timevalue_match := (self._timevalues == timevalue)): + timevalue = self.timevalues[timestep] + if not np.any(self.timevalues == timevalue): msg = f"Value {timevalue} not found in the array." raise ValueError(msg) - data_timestep = np.argmax(timevalue_match) + data_timestep = np.argmax( + self._timevalues * self._time_factor == timevalue + ) if timevalue in self._mesh_cache: mesh = self._mesh_cache[timevalue] else: @@ -254,11 +242,7 @@ class MeshSeries: case _: msg = f"Unexpected datatype {self._data_type}." raise TypeError(msg) - mesh = Mesh( - self.mesh_func(pv_mesh), - self.spatial_unit, - self.spatial_output_unit, - ) + mesh = Mesh(self.mesh_func(pv_mesh)) if lazy_eval: self._mesh_cache[timevalue] = mesh return mesh @@ -278,7 +262,7 @@ class MeshSeries: def read_interp(self, timevalue: float, lazy_eval: bool = True) -> Mesh: """Return the temporal interpolated mesh for a given timevalue.""" - t_vals = self._timevalues + t_vals = self.timevalues ts1 = int(t_vals.searchsorted(timevalue, "right") - 1) ts2 = min(ts1 + 1, len(t_vals) - 1) if np.isclose(timevalue, t_vals[ts1]): @@ -296,21 +280,18 @@ class MeshSeries: ) return mesh - def timevalues(self, time_unit: str | None = None) -> np.ndarray: - "Return the timevalues, optionally converted to another time unit." + @property + def timevalues(self) -> np.ndarray: + "Return the timevalues." vals = self._timevalues for index in self._time_indices: vals = vals[index] - return ( - u_reg.Quantity(vals, self.time_unit) - .to(self.time_unit if time_unit is None else time_unit) - .magnitude - ) + return vals * self._time_factor @property def timesteps(self) -> list: """Return the timesteps of the timeseries data.""" - return np.arange(len(self.timevalues()), dtype=int) + return np.arange(len(self.timevalues), dtype=int) def _xdmf_values(self, variable_name: str) -> np.ndarray: dataitems = self._xdmf_reader.data_items[variable_name] @@ -337,7 +318,7 @@ class MeshSeries: ms_copy = self.copy(deep=True) ms_copy._mesh_func_opt = None # pylint: disable=protected-access ms_copy.clear_cache() - raw_meshes = [ms_copy.mesh(0)] * len(result) + raw_meshes = list(ms_copy) for mesh, data in zip(raw_meshes, result, strict=True): mesh[variable_name] = data meshes = list(map(self.mesh_func, raw_meshes)) @@ -369,7 +350,7 @@ class MeshSeries: if ( self._data_type == "xdmf" and variable_name in self._xdmf_reader.data_items - and not all(tv in self._mesh_cache for tv in self.timevalues()) + and not all(tv in self._mesh_cache for tv in self.timevalues) ): result = self._xdmf_values(variable_name) else: @@ -409,7 +390,7 @@ class MeshSeries: mesh.clear_point_data() mesh.clear_cell_data() output_name = f"{prefix}_{variable.output_name}_time" - mesh[output_name] = self._timevalues[ + mesh[output_name] = self.timevalues[ np_func(self.values(variable), axis=0) ] return mesh @@ -439,7 +420,6 @@ class MeshSeries: self, variable: Variable | str, func: Callable, - time_unit: str | None = "s", ax: plt.Axes | None = None, **kwargs: Any, ) -> plt.Figure | None: @@ -449,7 +429,6 @@ class MeshSeries: :param variable: The mesh variable to be aggregated. :param func: The aggregation function to apply. E.g. np.min, np.max, np.mean, np.median, np.sum, np.std, np.var - :param time_unit: Output unit of the timevalues. :param ax: matplotlib axis to use for plotting :param kwargs: Keyword args passed to matplotlib's plot function. @@ -457,9 +436,8 @@ class MeshSeries: """ variable = get_preset(variable, self.mesh(0)) values = self.aggregate_over_domain(variable.magnitude, func) - time_unit = time_unit if time_unit is not None else self.time_unit - x_values = self.timevalues(time_unit) - x_label = f"time t / {time_unit}" + x_values = self.timevalues + x_label = f"time t / {plot.setup.time_unit}" if ax is None: fig, ax = plt.subplots() else: @@ -522,7 +500,6 @@ class MeshSeries: variable: Variable | str, variable_abscissa: Variable | str | None = None, labels: list[str] | None = None, - time_unit: str | None = "s", interp_method: Literal["nearest", "linear"] = "linear", colors: list | None = None, linestyles: list | None = None, @@ -536,7 +513,6 @@ class MeshSeries: :param points: The points to sample at. :param variable: The variable to be sampled. :param labels: The labels for each observation point. - :param time_unit: Output unit of the timevalues. :param interp_method: Choose the interpolation method, defaults to `linear` for xdmf MeshSeries and `probefilter` for pvd MeshSeries. @@ -553,11 +529,9 @@ class MeshSeries: ) if values.shape[0] == 1: values = values.ravel() - Q_ = u_reg.Quantity - time_unit_conversion = Q_(Q_(self.time_unit), time_unit).magnitude if variable_abscissa is None: - x_values = time_unit_conversion * self._timevalues - x_label = f"time / {time_unit}" if time_unit else "time" + x_values = self.timevalues + x_label = f"time / {plot.setup.time_unit}" else: variable_abscissa = get_preset(variable_abscissa, self.mesh(0)) x_values = variable_abscissa.magnitude.transform( @@ -666,7 +640,6 @@ class MeshSeries: points: np.ndarray, y_axis: Literal["x", "y", "z", "dist", "auto"] = "auto", interpolate: bool = True, - time_unit: str = "s", time_logscale: bool = False, fig: plt.Figure | None = None, ax: plt.Axes | None = None, @@ -681,7 +654,6 @@ class MeshSeries: is changing, this axis is taken, otherwise the distance along the line is taken. :param interpolate: Smoothen the result be interpolation. - :param time_unit: Time unit displayed on the x-axis. :param time_logscale: Should log-scaling be applied to the time-axis? :param fig: matplotlib figure to use for plotting. :param ax: matplotlib axis to use for plotting. @@ -707,9 +679,7 @@ class MeshSeries: msg = "Please provide fig and ax together or not at all." raise ValueError(msg) - time = Variable("", self.time_unit, time_unit).transform( - self._timevalues - ) + time = self.timevalues if time_logscale: time = np.log10(time, where=time != 0) time[0] = time[1] - (time[2] - time[1]) @@ -738,8 +708,6 @@ class MeshSeries: else: y = np.linalg.norm(points - points[0], axis=1) ylabel = "distance" - spatial = plot.shared.spatial_quantity(self.mesh(0)) - y = spatial.transform(y) if interpolate: grid_interp = RegularGridInterpolator( @@ -764,8 +732,10 @@ class MeshSeries: ax.pcolormesh(time, y, values.T, cmap=cmap, norm=norm) fontsize = kwargs.get("fontsize", plot.setup.fontsize) - ax.set_ylabel(ylabel + " / " + spatial.output_unit, fontsize=fontsize) - xlabel = "time / " + time_unit + ax.set_ylabel( + ylabel + " / " + plot.setup.spatial_unit, fontsize=fontsize + ) + xlabel = "time / " + plot.setup.time_unit if time_logscale: xlabel = "log10( " + xlabel + " )" ax.set_xlabel(xlabel, fontsize=fontsize) @@ -779,8 +749,12 @@ class MeshSeries: def mesh_func(self) -> Callable[[Mesh], Mesh]: """Returns stored transformation function or identity if not given.""" if self._mesh_func_opt is None: - return lambda mesh: mesh - return self._mesh_func_opt + return lambda mesh: mesh.scale(self._spatial_factor) + return lambda mesh: Mesh( + self._mesh_func_opt(mesh).scale( # type: ignore[misc] + self._spatial_factor + ) + ) def transform( self, mesh_func: Callable[[Mesh], Mesh] = lambda mesh: mesh @@ -796,14 +770,47 @@ class MeshSeries: ms_copy = self.copy(deep=True) # pylint: disable=protected-access for cache_timevalue, cache_mesh in self._mesh_cache.items(): - ms_copy._mesh_cache[cache_timevalue] = Mesh( - mesh_func(cache_mesh), - ms_copy.spatial_unit, - ms_copy.spatial_output_unit, - ) + ms_copy._mesh_cache[cache_timevalue] = Mesh(mesh_func(cache_mesh)) ms_copy._mesh_func_opt = lambda mesh: mesh_func(self.mesh_func(mesh)) return ms_copy + def scale( + self, + spatial: float | tuple[str, str] = 1.0, + time: float | tuple[str, str] = 1.0, + ) -> MeshSeries: + """Scale the spatial coordinates and timevalues. + + Useful to convert to other units, e.g. "m" to "km" or "s" to "a". + If given as tuple of strings, the latter units will also be set in + ot.plot.setup.spatial_unt and ot.plot.setup.time_unit for plotting. + + :param spatial: Float factor or a tuple of str (from_unit, to_unit). + :param time: Float factor or a tuple of str (from_unit, to_unit). + """ + Qty = u_reg.Quantity + if isinstance(spatial, float): + spatial_factor = spatial + else: + spatial_factor = Qty(Qty(spatial[0]), spatial[1]).magnitude + plot.setup.spatial_unit = spatial[1] + if isinstance(time, float): + time_factor = time + else: + time_factor = Qty(Qty(time[0]), time[1]).magnitude + plot.setup.time_unit = time[1] + self._spatial_factor *= spatial_factor + self._time_factor *= time_factor + + scaled_cache = { + timevalue * time_factor: Mesh(mesh.scale(spatial_factor)) + for timevalue, mesh in self._mesh_cache.items() + } + self.clear_cache() + self._mesh_cache = scaled_cache + + return self + @typechecked def extract( self, diff --git a/ogstools/plot/__init__.py b/ogstools/plot/__init__.py index 298a30a6fb7fe90d4f188821202cebcab3fe54eb..bda4b50a8669bda5b10171729c263fc4dd3716ac 100644 --- a/ogstools/plot/__init__.py +++ b/ogstools/plot/__init__.py @@ -7,18 +7,18 @@ # Author: Florian Zill (Helmholtz Centre for Environmental Research GmbH - UFZ) """Plotting utilities for simple access.""" +from .shared import setup # noqa: I001 from . import utils from .contourplots import contourf, subplot -from .features import outline, shape_on_top +from .features import shape_on_top from .levels import compute_levels -from .shared import setup +from .lineplots import line from .vectorplots import quiver, streamlines __all__ = [ "compute_levels", "contourf", - "outline", - "plot_time_slice", + "line", "quiver", "setup", "shape_on_top", diff --git a/ogstools/plot/contourplots.py b/ogstools/plot/contourplots.py index 1e40f508a554b09a13c7c96a178151ad3862083f..048b8b2904df483e146e5b601718fe46773c5995 100644 --- a/ogstools/plot/contourplots.py +++ b/ogstools/plot/contourplots.py @@ -23,7 +23,7 @@ from ogstools.variables import Variable, Vector, get_preset from . import features from .levels import compute_levels, median_exponent -from .shared import setup, spatial_quantity +from .shared import setup from .vectorplots import streamlines # TODO: define default data_name for regions in setup @@ -167,8 +167,7 @@ def subplot( # faces contains a padding indicating number of points per face which gets # removed with this reshaping and slicing to get the array of tri's - spatial = spatial_quantity(surf_tri) - x, y = spatial.transform(surf_tri.points.T[[x_id, y_id]]) + x, y = surf_tri.points.T[[x_id, y_id]] tri = surf_tri.faces.reshape((-1, 4))[:, 1:] values = variable.magnitude.transform(surf_tri) # Passing the data and not the mesh here purposely to ensure correct shape @@ -239,7 +238,7 @@ def subplot( if not show: continue index = np.unravel_index(func(values), values.shape)[0] - x_pos, y_pos = spatial.transform(mesh.points[index, [x_id, y_id]]) + x_pos, y_pos = mesh.points[index, [x_id, y_id]] value = values[mesh.find_closest_point(mesh.points[index])] color = utils.contrast_color(cmap(norm(value))) ax.plot( @@ -273,7 +272,7 @@ def subplot( mticker.FixedLocator(list(ax.get_xticks())) ) secax.set_xticklabels(sec_labels) - secax.set_xlabel(f'{"xyz"[projection]} / {spatial.output_unit}') + secax.set_xlabel(f'{"xyz"[projection]} / {setup.spatial_unit}') utils.update_font_sizes(secax, fontsize) @@ -371,9 +370,8 @@ def draw_plot( # One mesh is sufficient, it should be the same for all of them x_id, y_id, _, _ = utils.get_projection(np_meshes[0, 0]) - spatial_unit = spatial_quantity(np_meshes[0, 0]).output_unit utils.label_spatial_axes( - np_axs, "xyz"[x_id], "xyz"[y_id], spatial_unit=spatial_unit + np_axs, "xyz"[x_id], "xyz"[y_id], spatial_unit=setup.spatial_unit ) # make extra space for the upper limit of the colorbar if setup.layout == "tight" and fig is not None: @@ -443,13 +441,5 @@ def contourf( utils.update_font_sizes( fig.axes, fontsize=kwargs.get("fontsize", setup.fontsize) ) - min_tick_length = setup.tick_length * 2.0 / 3.5 # mpl default - for ax in fig.axes: - ax.tick_params( - "both", which="major", pad=setup.tick_pad, length=setup.tick_length - ) - ax.tick_params( - "both", which="minor", pad=setup.tick_pad, length=min_tick_length - ) return fig diff --git a/ogstools/plot/features.py b/ogstools/plot/features.py index c4702005c609ff1a51018b57726507819f8aa967..17ff236e55f9e812abf31ea5021aa00c3f30f23b 100644 --- a/ogstools/plot/features.py +++ b/ogstools/plot/features.py @@ -14,7 +14,7 @@ import pandas as pd import pyvista as pv from matplotlib.collections import PolyCollection -from .shared import setup, spatial_quantity +from .shared import setup def layer_boundaries( @@ -30,9 +30,9 @@ def layer_boundaries( for reg_id in np.unique(segments.cell_data["RegionId"]): segment = segments.threshold((reg_id, reg_id), "RegionId") edges = segment.extract_surface().strip(True, 10000) - x_b, y_b = spatial_quantity(mesh).transform( - edges.points[edges.lines % edges.n_points].T[[x_id, y_id]] - ) + x_b, y_b = edges.points[edges.lines % edges.n_points].T[ + [x_id, y_id] + ] ax.plot(x_b, y_b, "-k", lw=setup.linewidth) @@ -49,9 +49,7 @@ def element_edges( for cp, ct in zip(cell_points, cell_types, strict=False) if ct == cell_type ] - verts = spatial_quantity(lin_mesh).transform( - np.delete(cell_pts, projection, -1) - ) + verts = np.delete(cell_pts, projection, -1) lw = 0.5 * setup.linewidth pc = PolyCollection(verts.tolist(), fc="None", ec="black", lw=lw) ax.add_collection(pc) @@ -75,23 +73,5 @@ def shape_on_top( contour_vals = [ y + scaling * contour(x) for y, x in zip(y_vals, x_vals, strict=False) ] - spatial = spatial_quantity(surf).transform - ax.set_ylim(top=float(spatial(np.max(contour_vals)))) - ax.fill_between( - spatial(x_vals), - spatial(y_vals), - spatial(contour_vals), - facecolor="lightgrey", - ) - - -def outline( - ax: plt.Axes, mesh: pv.DataSet, style: str, lw: int, projection: int = 2 -) -> None: - "Plot the outline of a mesh on a matplotlib ax object." - contour = mesh.extract_surface().strip(join=True) - x_id, y_id = np.delete([0, 1, 2], projection) - x, y = spatial_quantity(mesh).transform( - contour.points[contour.lines[1:]].T[[x_id, y_id]] - ) - ax.plot(x, y, style, lw=lw) + ax.set_ylim(top=float(np.max(contour_vals))) + ax.fill_between(x_vals, y_vals, contour_vals, facecolor="lightgrey") diff --git a/ogstools/plot/lineplots.py b/ogstools/plot/lineplots.py index 2a1efef53036fb8698ca65ec13db10cc63fc1743..b39cb4b97c9af7547bb33631c4651630a42c89ac 100644 --- a/ogstools/plot/lineplots.py +++ b/ogstools/plot/lineplots.py @@ -1,167 +1,79 @@ -from string import ascii_uppercase -from typing import Any, Literal +from typing import Any import matplotlib.pyplot as plt import numpy as np import pyvista as pv -from ogstools.meshlib.data_processing import sample_polyline -from ogstools.plot import contourf, setup, utils -from ogstools.plot.shared import spatial_quantity +from ogstools.plot import setup, utils from ogstools.variables import Variable, get_preset -# TODO: ability to swap x and y? -def linesample( +def line( mesh: pv.UnstructuredGrid, - x: str, # TODO renamed it to "along" maybe - variable: str | Variable, - profile_points: np.ndarray, - ax: plt.Axes, - resolution: int | None = 100, - grid: Literal["major", "both", None] = None, + y_var: str | Variable, + x_var: str | Variable | None = None, + ax: plt.Axes | None = None, **kwargs: Any, -) -> plt.Axes: - """ - Plot selected variables obtained from sample_over_polyline function, - this function calls to it internally. Values provided in param x and y - refer to columns of the DataFrame returned by it. - - :param mesh: mesh to sample from. - :param x: Value to be used on x-axis of the plot - :param variable: Values to be used on y-axis of the plot - :param profile_points: Points defining the profile (and its segments) - :param ax: User-created array of Matplotlib axis object - :param resolution: Resolution of the sampled profile. Total number of - points within all profile segments. - :param resolution: Resolution of the sampled profile. Total number of - points within all profile segments. - :param grid: Which gridlines should be drawn? - :param kwargs: Optional keyword arguments passed to matplotlib.pyplot.plot - to customize plot options like a line label (for auto legends), linewidth, - antialiasing, marker face color. - - :returns: Matplotlib Axes object +) -> plt.Figure | None: + """Plot some data of a (1D) mesh. + + You can pass "x", "y" or "z" to either of x_var or y_var to specify which + spatial dimension should be used for the corresponding axis. You can also + pass two data variables for a phase plot. + + :param mesh: The mesh which contains the data to plot + :param y_var: The variable to use for the y-axis + :param x_var: The variable to use for the x-axis, if None automatic + detection of spatial axis is tried. + :param ax: The matplotlib axis to use for plotting, if None creates a + new figure. + :Keyword Arguments: + - figsize: figure size + - color: color of the line + - linewidth: width of the line + - linestyle: style of the line + - label: label in the legend + - grid: if True, show grid + - all other kwargs get passed to matplotlib's plot function """ - variable = get_preset(variable, mesh) - mesh_sp, _ = sample_polyline(mesh, variable, profile_points, resolution) + figsize = kwargs.pop("figsize", [16, 10]) + ax_ = plt.subplots(figsize=figsize)[1] if ax is None else ax - assert isinstance(ax, plt.Axes) + if x_var is None: + non_flat_axis = np.argwhere( + np.invert(np.all(np.isclose(mesh.points, mesh.points[0]), axis=0)) + ).ravel() + x_var = "xyz"[non_flat_axis[0]] - spatial_qty = spatial_quantity(mesh) - kwargs.setdefault("label", variable.data_name) - kwargs.setdefault("color", variable.color) - kwargs.setdefault("linestyle", variable.linestyle) - if "ls" in kwargs: - kwargs.pop("linestyle") + x_var = get_preset(x_var, mesh).magnitude + y_var = get_preset(y_var, mesh).magnitude - utils.update_font_sizes(axes=ax, fontsize=kwargs.pop("fontsize", 20)) - ax.plot( - spatial_qty.transform(mesh_sp[x]), - mesh_sp[variable.data_name], - **kwargs, - ) - ax.set_xlabel("Profile distance / " + spatial_qty.output_unit) - ax.set_ylabel(variable.get_label(setup.label_split)) + kwargs.setdefault("color", y_var.color) + pure_spatial = y_var.data_name in "xyz" and x_var.data_name in "xyz" + lw_scale = 5 if pure_spatial else 3 + kwargs.setdefault("linewidth", setup.linewidth * lw_scale) + fontsize = kwargs.pop("fontsize", setup.fontsize) + show_grid = kwargs.pop("grid", True) and not pure_spatial - if grid in ["both", "major"]: - ax.grid(which="major", color="lightgrey", linestyle="-") - if grid == "major": - ax.minorticks_off() - if grid == "both": - ax.grid(which="minor", color="0.95", linestyle="--") - ax.minorticks_on() + ax_.plot(x_var.transform(mesh), y_var.transform(mesh), **kwargs) - return ax + if "label" in kwargs: + ax_.legend(fontsize=fontsize) + if ax_.get_xlabel() == "": + ax_.set_xlabel(x_var.get_label(setup.label_split)) + if ax_.get_ylabel() == "": + ax_.set_ylabel(y_var.get_label(setup.label_split)) -def linesample_contourf( - mesh: pv.UnstructuredGrid, - variables: str | list | Variable, - profile_points: np.ndarray, - resolution: int | None = None, - plot_nodal_pts: bool | None = True, - nodal_pts_labels: str | list | None = None, -) -> tuple[plt.Figure, plt.Axes]: - """ - Default plot for the data obtained from sampling along a profile on a mesh. - - :param mesh: mesh to plot and sample from. - :param variables: Variables to be read from the mesh - :param profile_points: Points defining the profile (and its segments) - :param resolution: Resolution of the sampled profile. Total number of - points within all profile segments. - :param plot_nodal_pts: Plot and annotate all nodal points in profile - :param nodal_pts_labels: Labels for nodal points (only use if - plot_nodal_points is set to True) - - :returns: Tuple containing Matplotlib Figure and Axis objects - """ - # TODO: Add support for plotting only geometry at top subplot and - # lineplot with twinx in the bottom one - if not isinstance(variables, list): - variables = [variables] - - _, dist_at_knot = sample_polyline( - mesh, variables, profile_points, resolution=resolution - ) - - fig, ax = plt.subplots( - 2, len(variables), figsize=(len(variables) * 13, 12), squeeze=False - ) - spatial_qty = spatial_quantity(mesh) - x_id, y_id, _, _ = utils.get_projection(mesh) - for index, variable in enumerate(variables): - contourf(mesh, variable, fig=fig, ax=ax[0, index]) - linesample( - mesh, - x="dist", - variable=variable, - profile_points=profile_points, - ax=ax[1, index], - resolution=resolution, - grid="both", - ) - - if plot_nodal_pts: - if nodal_pts_labels is None: - nodal_pts_labels = list( - ascii_uppercase[0 : len(profile_points)] - ) - ax[0][index].plot( - spatial_qty.transform(profile_points[:, x_id]), - spatial_qty.transform(profile_points[:, y_id]), - "-*", - linewidth=2, - markersize=7, - color="orange", - ) - for nodal_pt_id, nodal_pt in enumerate(dist_at_knot): - xy = profile_points[nodal_pt_id, [x_id, y_id]] - text_xy = utils.padded(ax[0][index], *spatial_qty.transform(xy)) - ax[0][index].text( - *text_xy, - nodal_pts_labels[nodal_pt_id], - color="orange", - fontsize=setup.fontsize, - ha="center", - va="center", - ) - ax[1][index].axvline( - spatial_qty.transform(nodal_pt), - linestyle="--", - color="orange", - linewidth=2, - ) - ax_twiny = ax[1][index].twiny() - ax_twiny.set_xlim(ax[1][index].get_xlim()) - ax_twiny.set_xticks( - spatial_qty.transform(dist_at_knot), - nodal_pts_labels, - color="orange", - ) - utils.update_font_sizes(fig.axes) - fig.tight_layout() - - return fig, ax + utils.update_font_sizes(axes=ax_, fontsize=fontsize) + + if show_grid: + ax_.grid(which="major", color="lightgrey", linestyle="-") + ax_.grid(which="minor", color="0.95", linestyle="--") + ax_.minorticks_on() + + if ax is not None: + return ax.figure + + return ax_.figure diff --git a/ogstools/plot/plot_setup.py b/ogstools/plot/plot_setup.py index 5371cc10916859d0890343764aeba4474a74c1ac..662ff5d94b7250bd5151592a88bb9356a3505d5d 100644 --- a/ogstools/plot/plot_setup.py +++ b/ogstools/plot/plot_setup.py @@ -64,6 +64,10 @@ class PlotSetup: "Thickness of lines." label_split: int | None "Split Variable labels if they exceed this value." + spatial_unit: str + "Unit of the spatial dimension." + time_unit: str + "Unit of the time dimension." @classmethod def from_dict(cls: type["PlotSetup"], obj: dict) -> "PlotSetup": @@ -89,9 +93,20 @@ class PlotSetup: fontsize=obj["fontsize"], linewidth=obj["linewidth"], label_split=obj["label_split"], + spatial_unit=obj["spatial_unit"], + time_unit=obj["time_unit"], ) def reset(self) -> None: """Reset the plot setup to default values.""" for k, v in self.from_dict(setup_dict).__dict__.items(): self.__dict__[k] = v + + def set_units( + self, spatial: str | None = None, time: str | None = None + ) -> None: + "Convenience function to update spatial and time unit at once." + if spatial is not None: + self.spatial_unit = spatial + if time is not None: + self.time_unit = time diff --git a/ogstools/plot/plot_setup_defaults.py b/ogstools/plot/plot_setup_defaults.py index a794986f537a153c90e3e5356b06ea84120f7f30..3c7300604a65001c868caf9bd690b7dc5f7b0ee3 100644 --- a/ogstools/plot/plot_setup_defaults.py +++ b/ogstools/plot/plot_setup_defaults.py @@ -35,4 +35,6 @@ setup_dict = { "fontsize": 32, "linewidth": 1, "label_split": 37, + "spatial_unit": "m", + "time_unit": "s", } diff --git a/ogstools/plot/shared.py b/ogstools/plot/shared.py index 1aa98b5a016eef8963c95c04b3ba82d78524d7ab..b2befeb46a235a044cd4b21c6c03e58b3ea8d906 100644 --- a/ogstools/plot/shared.py +++ b/ogstools/plot/shared.py @@ -1,22 +1,4 @@ -from typing import cast - -import pyvista as pv - -from ogstools.definitions import SPATIAL_UNITS_KEY -from ogstools.variables import Scalar - from .plot_setup import PlotSetup from .plot_setup_defaults import setup_dict setup = PlotSetup.from_dict(setup_dict) - - -def spatial_quantity(mesh: pv.UnstructuredGrid | pv.DataSet) -> Scalar: - "Return a Scalar able to transform the spatial units of the mesh." - - units = cast( - list[int], - mesh.field_data.get(SPATIAL_UNITS_KEY, [ord(c) for c in "m,m"]), - ) - data_unit, output_unit = "".join(chr(unit) for unit in units).split(",") - return Scalar("", data_unit, output_unit, "") diff --git a/ogstools/plot/utils.py b/ogstools/plot/utils.py index d1fec6e87fee55f441554fc42dd35b56fdccabbc..33d5c8a7712f0bfea10fadd25787ac72508f7606 100644 --- a/ogstools/plot/utils.py +++ b/ogstools/plot/utils.py @@ -89,12 +89,18 @@ def update_font_sizes( """ Update font sizes of labels and texts. + This also scales the ticks accordingly. + :param ax: matplotlib axes which should be updated :param fontsize: font size for the labels and ticks """ if fontsize is None: fontsize = setup.fontsize ax: plt.Axes + scale = fontsize / setup.fontsize + tick_pad = scale * setup.tick_pad + tick_len = scale * setup.tick_length + min_tick_len = tick_len * 2.0 / 3.5 # matplotlib default for ax in np.ravel(np.asarray(axes)): tick_labels = ax.get_xticklabels() + ax.get_yticklabels() labels = [ax.title, ax.xaxis.label, ax.yaxis.label] @@ -102,6 +108,8 @@ def update_font_sizes( ax.tick_params(axis="both", which="both", labelsize=fontsize) for item in tick_labels + labels + [offset_text]: item.set_fontsize(fontsize) + ax.tick_params("both", which="major", pad=tick_pad, length=tick_len) + ax.tick_params("both", which="minor", pad=tick_pad, length=min_tick_len) return diff --git a/ogstools/plot/vectorplots.py b/ogstools/plot/vectorplots.py index 46381290583954d7485cb3d41fca7f53c54edc04..45e2974cc6874f4f5e83c7cf15df9cbb2577d116 100644 --- a/ogstools/plot/vectorplots.py +++ b/ogstools/plot/vectorplots.py @@ -6,7 +6,7 @@ import pyvista as pv from ogstools.variables import Vector -from .shared import setup, spatial_quantity +from .shared import setup def _vectorfield( @@ -59,7 +59,7 @@ def _vectorfield( val[mask == 0, :] = 0 val_norm = np.linalg.norm(np.nan_to_num(val), axis=-1) lw = 2.5 * val_norm / max(1e-16, np.max(val_norm)) * setup.linewidth - i_grid, j_grid = spatial_quantity(mesh).transform(np.meshgrid(i_pts, j_pts)) + i_grid, j_grid = np.meshgrid(i_pts, j_pts) return (i_grid, j_grid, val[..., 0], val[..., 1], lw) diff --git a/ogstools/studies/convergence/convergence.py b/ogstools/studies/convergence/convergence.py index bf9386d741d3dd12707c5b756438649158aae27e..a477b7ced9af537648b32d0850a942d2722f1858 100644 --- a/ogstools/studies/convergence/convergence.py +++ b/ogstools/studies/convergence/convergence.py @@ -282,14 +282,14 @@ def convergence_metrics_evolution( :returns: A pandas Dataframe containing all metrics. """ - all_timevalues = [ms.timevalues() for ms in mesh_series] + all_timevalues = [ms.timevalues for ms in mesh_series] common_timevalues = sorted( set(all_timevalues[0]).intersection(*all_timevalues[1:]) ) p_metrics_per_t = np.empty((0, 9)) - timestep_sizes = [np.mean(np.diff(ms.timevalues())) for ms in mesh_series] + timestep_sizes = [np.mean(np.diff(ms.timevalues)) for ms in mesh_series] for timevalue in tqdm(common_timevalues): meshes = [ms.mesh(ms.closest_timestep(timevalue)) for ms in mesh_series] reference = richardson_extrapolation( diff --git a/ogstools/studies/templates/convergence_study.py b/ogstools/studies/templates/convergence_study.py index fab3f86b8ef49c958302a00dca555e493b1ee5be..ba1ca0243b3552c94247d8e3f6bcb8ed4b994a4c 100644 --- a/ogstools/studies/templates/convergence_study.py +++ b/ogstools/studies/templates/convergence_study.py @@ -37,7 +37,7 @@ plot.setup.combined_colorbar = False # extrapolation. mesh_series = [meshlib.MeshSeries(mesh_path) for mesh_path in mesh_paths] -timestep_sizes = [np.mean(np.diff(ms.timevalues())) for ms in mesh_series] +timestep_sizes = [np.mean(np.diff(ms.timevalues)) for ms in mesh_series] meshes = [ms.mesh(ms.closest_timestep(timevalue)) for ms in mesh_series] topology: pv.UnstructuredGrid = meshes[-3] variable = variables.get_preset(variable_name, meshes[0]) diff --git a/ogstools/variables/__init__.py b/ogstools/variables/__init__.py index 23d2d46734d579c1a995d978370d7c1f3c16a696..86891af180796e720e9058ce4520d2025eb50d75 100644 --- a/ogstools/variables/__init__.py +++ b/ogstools/variables/__init__.py @@ -26,20 +26,10 @@ T_MASK = "temperature_active" H_MASK = "pressure_active" M_MASK = "displacement_active" -# Default style to be used in plotting functions -# For now for Scalars only -# TODO: expand to Matrix and Vector -line_styles = [ - (0, ()), # solid - (0, (1, 1)), # dotted - (0, (5, 5)), # dashed - (0, (3, 5, 1, 5)), # dash dotted - (0, (3, 5, 1, 5, 1, 5)), # dash dot dotted - (0, (3, 10, 1, 10)), # loosely dash dotted -] -group_color_thermal = "tab:red" -group_color_hydraulic = "tab:blue" -group_color_mechanical = "black" # green would be bad for colorblindess +# Default colors to be used in plotting functions +COLOR_THERMAL = "tab:red" +COLOR_HYDRO = "tab:blue" +COLOR_MECH = "black" # green would be bad for colorblindess # ====== general ====== material_id = Scalar(data_name="MaterialIDs", categoric=True, cmap="tab20") @@ -53,14 +43,10 @@ temperature = Scalar( mask=T_MASK, cmap=temperature_cmap, bilinear_cmap=True, - color=group_color_thermal, - linestyle=line_styles[0], + color=COLOR_THERMAL, ) heatflowrate = Scalar( - data_name="HeatFlowRate", - mask=T_MASK, - color=group_color_thermal, - linestyle=line_styles[1], + data_name="HeatFlowRate", mask=T_MASK, color=COLOR_THERMAL ) # ====== hydraulic ====== @@ -72,8 +58,7 @@ pressure = Scalar( symbol="p", mask=H_MASK, cmap="Blues", - color=group_color_hydraulic, - linestyle=line_styles[0], + color=COLOR_HYDRO, ) hydraulic_head = Scalar( data_name="pressure", @@ -83,8 +68,7 @@ hydraulic_head = Scalar( symbol="h", mask=H_MASK, cmap="Blues", - color=group_color_hydraulic, - linestyle=line_styles[1], + color=COLOR_HYDRO, ) velocity = Vector( data_name="velocity", @@ -93,6 +77,8 @@ velocity = Vector( output_name="darcy_velocity", symbol="v", mask=H_MASK, + cmap="Blues", + color=COLOR_HYDRO, ) massflowrate = Scalar(data_name="MassFlowRate", mask=H_MASK) @@ -104,6 +90,7 @@ displacement = Vector( symbol="u", mask=M_MASK, cmap="PRGn", + color=COLOR_MECH, bilinear_cmap=True, ) strain = Matrix( @@ -112,6 +99,7 @@ strain = Matrix( output_unit="percent", output_name="strain", symbol=r"\varepsilon", + color=COLOR_MECH, mask=M_MASK, ) stress = Matrix( @@ -120,6 +108,7 @@ stress = Matrix( output_unit="MPa", output_name="stress", symbol=r"\sigma", + color=COLOR_MECH, mask=M_MASK, ) effective_pressure = Scalar( @@ -128,10 +117,9 @@ effective_pressure = Scalar( output_unit="MPa", output_name="effective_pressure", symbol=r"\pi", - mask=M_MASK, func=tensor_math.effective_pressure, - color=group_color_mechanical, - linestyle=line_styles[0], + mask=M_MASK, + color=COLOR_MECH, ) dilatancy_critescu_tot = Scalar( data_name="sigma", @@ -139,18 +127,16 @@ dilatancy_critescu_tot = Scalar( output_unit="", output_name="dilatancy_criterion", symbol=r"F_\mathrm{dil}", - mask=M_MASK, func=mesh_dependent.dilatancy_critescu, + mask=M_MASK, + color=COLOR_MECH, mesh_dependent=True, cmap=integrity_cmap, bilinear_cmap=True, - color=group_color_mechanical, - linestyle=line_styles[1], ) dilatancy_critescu_eff = dilatancy_critescu_tot.replace( output_name="effective_dilatancy_criterion", func=partial(mesh_dependent.dilatancy_critescu, effective=True), - linestyle=line_styles[2], ) dilatancy_alkan = Scalar( @@ -159,18 +145,16 @@ dilatancy_alkan = Scalar( output_unit="MPa", output_name="dilatancy_criterion", symbol=r"F_\mathrm{dil}", - mask=M_MASK, func=mesh_dependent.dilatancy_alkan, + mask=M_MASK, + color=COLOR_MECH, mesh_dependent=True, cmap=integrity_cmap, bilinear_cmap=True, - color=group_color_mechanical, - linestyle=line_styles[3], ) dilatancy_alkan_eff = dilatancy_alkan.replace( output_name="effective_dilatancy_criterion", func=partial(mesh_dependent.dilatancy_alkan, effective=True), - linestyle=line_styles[4], ) fluid_pressure_crit = Scalar( @@ -179,13 +163,12 @@ fluid_pressure_crit = Scalar( output_unit="MPa", output_name="fluid_pressure_criterion", symbol="F_p", - mask=M_MASK, func=mesh_dependent.fluid_pressure_criterion, + mask=M_MASK, + color=COLOR_MECH, mesh_dependent=True, cmap=integrity_cmap, bilinear_cmap=True, - color=group_color_mechanical, - linestyle=line_styles[5], ) nodal_forces = Vector(data_name="NodalForces", mask=M_MASK) @@ -201,6 +184,25 @@ saturation = Scalar( all_variables = [v for v in locals().values() if isinstance(v, Variable)] +def _spatial_preset(axis: str) -> Scalar: + # pylint: disable=import-outside-toplevel + # Importing here dynamically to avoid circular import + # If we want to avoid this, we'd have to move plot.setup to someplace + # outside of plot + from ogstools.plot import setup # noq: I001 + + # pylint: enable=import-outside-toplevel + + return Scalar( + axis, + setup.spatial_unit, # type:ignore[attr-defined] + setup.spatial_unit, # type:ignore[attr-defined] + mesh_dependent=True, + func=mesh_dependent.get_pts("xyz".index(axis)), + color="k", + ) + + def get_preset(variable: Variable | str, mesh: pv.UnstructuredGrid) -> Variable: """ Returns a Variable preset or creates one with correct type. @@ -219,6 +221,8 @@ def get_preset(variable: Variable | str, mesh: pv.UnstructuredGrid) -> Variable: error_msg = ( f"Data not found in mesh. Available data names are {data_keys}. " ) + if isinstance(variable, str) and variable in ["x", "y", "z"]: + return _spatial_preset(variable) if isinstance(variable, Variable): if variable.data_name in data_keys: diff --git a/ogstools/variables/mesh_dependent.py b/ogstools/variables/mesh_dependent.py index cf9e1fb30713698466e15b5db2148b8969f92496..47f6b2124ea48a05c71a4f9a4f308bb62147ff9c 100644 --- a/ogstools/variables/mesh_dependent.py +++ b/ogstools/variables/mesh_dependent.py @@ -7,6 +7,8 @@ "Functions related to stress analysis which can be only applied to a mesh." +from collections.abc import Callable + import numpy as np import pyvista as pv from pint.facets.plain import PlainQuantity @@ -16,6 +18,19 @@ from .unit_registry import u_reg from .variable import Variable +def get_pts( + index: int, +) -> Callable[[pv.UnstructuredGrid, Variable], np.ndarray]: + "Returns the coordinates of all points with the given index" + + def get_pts_coordinate( + mesh: pv.UnstructuredGrid, _: Variable + ) -> np.ndarray: + return mesh.points[:, index] + + return get_pts_coordinate + + def fluid_pressure_criterion( mesh: pv.UnstructuredGrid, variable: Variable ) -> PlainQuantity: diff --git a/ogstools/variables/variable.py b/ogstools/variables/variable.py index e8ae2a6f66af576cdbe8ab63883be9b53b85c98e..b2b4edf0d67a72e5760218777c81a9eb3b0281d8 100644 --- a/ogstools/variables/variable.py +++ b/ogstools/variables/variable.py @@ -56,8 +56,6 @@ class Variable: """Does this variable only have categoric values?""" color: str | None = None """Default color for plotting""" - linestyle: tuple | None = None - """Default linestyle for plotting""" def __post_init__(self) -> None: if not self.output_name: @@ -99,7 +97,6 @@ class Variable: bilinear_cmap=new_variable.bilinear_cmap, categoric=new_variable.categoric, color=new_variable.color, - linestyle=new_variable.linestyle, ).replace(**changes) def transform( diff --git a/pyproject.toml b/pyproject.toml index b4add095e87ecd999150025678a60d3191b8bf82..9810f0ea7856e858262a38cfb57b0993e31c1031 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -269,6 +269,7 @@ disallow_untyped_defs = true disallow_incomplete_defs = true warn_unreachable = true enable_incomplete_feature = ["Unpack"] +show_absolute_path = true [tool.ruff] target-version = "py310" diff --git a/tests/test_meshlib.py b/tests/test_meshlib.py index 9fde1e87085d5064043a26bce28c8905019a7660..8208ed9f89a700762d163f6e5109adfb67e61800 100644 --- a/tests/test_meshlib.py +++ b/tests/test_meshlib.py @@ -165,7 +165,7 @@ class TestUtils: assert not np.any(np.isnan(ms.timesteps)) assert not np.any(np.isnan(ms.values("temperature"))) - assert ms.timevalues()[ + assert ms.timevalues[ ms.closest_timestep(1.0) ] == ms.closest_timevalue(1.0) @@ -229,14 +229,14 @@ class TestUtils: def test_plot_domain_aggregate(self): "Test aggregation of meshseries." mesh_series = examples.load_meshseries_THM_2D_PVD() - mesh_series.plot_domain_aggregate("temperature", np.mean, "a") + mesh_series.plot_domain_aggregate("temperature", np.mean) def test_time_slice(self): mesh_series = examples.load_meshseries_HT_2D_XDMF() points = np.linspace([2, 2, 0], [4, 18, 0], num=100) mesh_series.plot_time_slice("temperature", points, levels=[78, 79, 80]) mesh_series.plot_time_slice( - "temperature", points, y_axis="y", interpolate=False, time_unit="h", + "temperature", points, y_axis="y", interpolate=False, time_logscale=True, cb_loc="left", dpi=50, fontsize=10 ) # fmt: skip @@ -350,99 +350,6 @@ class TestUtils: mesh["depth"] = mesh.depth() assert np.all(mesh["depth"] < -mesh.points[..., -1]) - def test_interp_points(self): - profile = np.array( - [ - [-1000, -175, 6700], - [-600, -600, 6700], - [100, -300, 6700], - ] - ) - profile_points = ot.meshlib.interp_points(profile, resolution=100) - assert isinstance(profile_points, np.ndarray) - # Check first point - assert (profile_points[0, :] == profile[0, :]).all() - # Check last point - assert (profile_points[-1, :] == profile[2, :]).all() - # Check if middle point is present in the profile at expected index - assert (profile_points == profile[1, :]).any() - - def test_distance_in_segments(self): - profile = np.array([[0, 0, 0], [1, 0, 0], [2, 0, 0]]) - profile_points = ot.meshlib.interp_points(profile, resolution=9) - dist_in_seg = ot.meshlib.distance_in_segments(profile, profile_points) - assert len(np.where(dist_in_seg == 0)[0]) == 2 - assert len(np.where(dist_in_seg == 1)[0]) == 1 - - def test_distance_in_profile(self): - profile = np.array([[0, 0, 0], [1, 0, 0], [2, 0, 0]]) - profile_points = ot.meshlib.interp_points(profile, resolution=9) - dist_in_seg = ot.meshlib.distance_in_profile(profile_points) - # Check if distance is increasing - assert np.all(np.diff(dist_in_seg) > 0) - # Check if distances at the beginning and end of profile are correct - assert dist_in_seg[0] == 0 - assert dist_in_seg[-1] == 2 - - def test_sample_over_polyline_single_segment(self): - ms = examples.load_meshseries_HT_2D_XDMF() - profile = np.array([[4, 2, 0], [4, 18, 0]]) - ms_sp, _ = ot.meshlib.sample_polyline( - ms.mesh(-1), - ["pressure", "temperature"], - profile, - ) - assert not np.any(np.isnan(ms_sp["pressure"])) - assert (ms_sp["pressure"].to_numpy() > 0).all() - assert not np.any(np.isnan(ms_sp["temperature"])) - assert ( - ms_sp["temperature"].to_numpy() - >= np.zeros_like(ms_sp["temperature"].to_numpy()) - ).all() - assert (ms_sp["dist"] == ms_sp["dist_in_segment"]).all() - - def test_sample_over_polyline_multi_segment(self): - ms = examples.load_meshseries_THM_2D_PVD() - profile = np.array( - [ - [-1000, -175, 6700], - [-600, -600, 6700], - [100, -300, 6700], - [910, -590, 6700], - ] - ) - ms_sp, _ = ot.meshlib.sample_polyline( - ms.mesh(1), - ot.variables.temperature, - profile, - resolution=10, - ) - data = ms_sp[ot.variables.temperature.data_name].to_numpy() - assert not np.any(np.isnan(data)) - assert (np.abs(data) > np.zeros_like(data)).all() - # output should be in Celsius - assert (data >= np.ones_like(data) * -273.15).all() - assert (ms_sp["dist"] != ms_sp["dist_in_segment"]).any() - - def test_sample_over_polyline_single_segment_vec_prop(self): - ms = examples.load_meshseries_HT_2D_XDMF() - profile = np.array([[4, 2, 0], [4, 18, 0]]) - ms_sp, _ = ot.meshlib.sample_polyline( - ms.mesh(-1), - "darcy_velocity", - profile, - ) - assert not np.any(np.isnan(ms_sp["darcy_velocity_0"])) - assert ( - np.abs(ms_sp["darcy_velocity_0"].to_numpy()) - > np.zeros_like(ms_sp["darcy_velocity_0"].to_numpy()) - ).all() - assert not np.any(np.isnan(ms_sp["darcy_velocity_1"])) - assert ( - np.abs(ms_sp["darcy_velocity_1"].to_numpy()) - > np.zeros_like(ms_sp["darcy_velocity_1"].to_numpy()) - ).all() - @pytest.mark.parametrize( ("elem_order", "quads", "intpt_order", "mixed"), [ @@ -542,7 +449,7 @@ class TestUtils: def test_slice(self): ms = examples.load_meshseries_HT_2D_XDMF() ms_sliced = ms[1::2] - assert len(ms.timevalues()) >= 2 * len(ms_sliced.timevalues()) + assert len(ms.timevalues) >= 2 * len(ms_sliced.timevalues) def test_transform(self): ms = examples.load_meshseries_THM_2D_PVD() diff --git a/tests/test_plot.py b/tests/test_plot.py index 06ee952d7c9fc29e47c1c244afbd2f3262f62beb..6ffa05fe7c0e875323d6fc7143a8ba7db7e8e2ac 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -201,7 +201,7 @@ class TestPlotting: def test_animation(self): """Test creation of animation.""" meshseries = examples.load_meshseries_THM_2D_PVD() - timevalues = np.linspace(0, meshseries.timevalues()[-1], num=3) + timevalues = np.linspace(0, meshseries.timevalues[-1], num=3) anim = meshseries.animate( ot.variables.temperature, timevalues, @@ -214,7 +214,7 @@ class TestPlotting: def test_save_animation(self): """Test saving of an animation.""" meshseries = examples.load_meshseries_THM_2D_PVD() - timevalues = np.linspace(0, meshseries.timevalues()[-1], num=3) + timevalues = np.linspace(0, meshseries.timevalues[-1], num=3) anim = meshseries.animate(ot.variables.temperature, timevalues) if not utils.save_animation(anim, mkstemp()[1], 5): pytest.skip("Saving animation failed.") @@ -263,33 +263,11 @@ class TestPlotting: def test_lineplot(self): """Test creation of a linesplot from sampled profile data""" mesh = examples.load_meshseries_HT_2D_XDMF().mesh(-1) - profile_HT = np.array([[4, 2, 0], [4, 18, 0]]) - fig, ax = plt.subplots(1, 1, figsize=(5, 5)) - ax = mesh.plot_linesample( - x="dist", - variable="pressure", - profile_points=profile_HT, - ax=ax, - fontsize=15, - ) - ax_twinx = ax.twinx() - ax_twinx = mesh.plot_linesample( - x="dist", - variable="temperature", - profile_points=profile_HT, - ax=ax_twinx, - fontsize=15, - ) - plt.close() - - def test_plot_profile(self): - """Test creation of a profile plot from sampled profile data""" - ms_CT = examples.load_meshseries_CT_2D_XDMF() - profile_CT = np.array([[47.0, 1.17, 72.0], [-4.5, 1.17, -59.0]]) - fig, ax = ms_CT.mesh(11).plot_linesample_contourf( - ot.variables.saturation, - profile_CT, - resolution=100, - plot_nodal_pts=True, - ) - plt.close() + profile_HT = mesh.sample_over_line([4, 2, 0], [4, 18, 0]) + ot.plot.setup.set_units(spatial="km", time="a") + fig = ot.plot.line(profile_HT, "pressure") + ot.plot.line(profile_HT, ot.variables.pressure, "x", ax=fig.axes[0]) + fig = ot.plot.line( + profile_HT, "y", "x", figsize=[5, 5], color="g", linewidth=1, + ls="--", label="test", grid=True, + ) # fmt: skip