From b419ade77d5e98cf03ef1da5c019a6e60b5a1965 Mon Sep 17 00:00:00 2001 From: FZill <florian.zill@ufz.de> Date: Thu, 21 Mar 2024 12:45:26 +0100 Subject: [PATCH] [all] many type fixes also ignore jupyter notebooks in main folder --- .gitignore | 5 +- ogstools/meshlib/boundary.py | 10 +-- ogstools/meshlib/boundary_set.py | 18 ++--- ogstools/meshlib/boundary_subset.py | 10 ++- ogstools/meshlib/gmsh_meshing.py | 8 +-- ogstools/meshlib/mesh_series.py | 7 +- ogstools/meshlib/region/region.py | 7 +- ogstools/meshlib/xdmf_reader.py | 11 +-- ogstools/meshplotlib/animation.py | 16 +++-- ogstools/meshplotlib/core.py | 69 +++++++++++-------- ogstools/meshplotlib/plot_features.py | 15 ++-- ogstools/meshplotlib/plot_setup.py | 14 ++-- ogstools/meshplotlib/utils.py | 5 +- ogstools/msh2vtu/__init__.py | 18 ++--- ogstools/propertylib/property.py | 14 ++-- ogstools/studies/convergence/convergence.py | 6 +- .../examples/steady_state_diffusion.py | 6 +- 17 files changed, 134 insertions(+), 105 deletions(-) diff --git a/.gitignore b/.gitignore index f0aa73237..3e4a20641 100644 --- a/.gitignore +++ b/.gitignore @@ -162,6 +162,9 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +#.idea/S /docs/auto_examples/ + +# ignore jupyter notebooks in main folder (useful for testing, annoying for git) +/**.ipynb diff --git a/ogstools/meshlib/boundary.py b/ogstools/meshlib/boundary.py index d298cc7e6..030e3fb95 100644 --- a/ogstools/meshlib/boundary.py +++ b/ogstools/meshlib/boundary.py @@ -21,7 +21,7 @@ class Boundary(ABC): """ @abstractmethod - def dim(self): + def dim(self) -> int: """ Get the dimension of the boundary. @@ -29,7 +29,7 @@ class Boundary(ABC): int: The dimension of the boundary. For example, the dimension of a boundary of a cube (3D) is 2. """ - return + return 0 @dataclass(frozen=True) @@ -47,7 +47,7 @@ class Layer(Boundary): of the layer in the GIS system. """ - def __post_init__(self): + def __post_init__(self) -> None: if not self.material_id: object.__setattr__(self, "material_id", self.bottom._material_id) @@ -89,7 +89,7 @@ class Layer(Boundary): rasters.append(bottom_raster) return rasters - def dim(self): + def dim(self) -> int: return 3 @@ -100,7 +100,7 @@ class LocationFrame: ymin: float ymax: float - def as_gml(self, filename: Path): + def as_gml(self, filename: Path) -> None: """ Generate GML representation of the location frame. diff --git a/ogstools/meshlib/boundary_set.py b/ogstools/meshlib/boundary_set.py index 17dd54b92..22de2ef82 100644 --- a/ogstools/meshlib/boundary_set.py +++ b/ogstools/meshlib/boundary_set.py @@ -20,12 +20,12 @@ class BoundarySet(ABC): """ @abstractmethod - def bounds(self): - return + def bounds(self) -> list: + return [] @abstractmethod - def filenames(self): - return + def filenames(self) -> list[Path]: + return [] class LayerSet(BoundarySet): @@ -49,16 +49,16 @@ class LayerSet(BoundarySet): raise ValueError(msg) self.layers = layers - def bounds(self): + def bounds(self) -> list: return list(self.layers[0].top.mesh.bounds) - def filenames(self): + def filenames(self) -> list[Path]: layer_filenames = [layer.bottom.filename for layer in self.layers] layer_filenames.insert(0, self.layers[0].top.filename) # file interface return layer_filenames @classmethod - def from_pandas(cls, df: pd.DataFrame): + def from_pandas(cls, df: pd.DataFrame) -> "LayerSet": """Create a LayerSet from a Pandas DataFrame.""" Row = namedtuple("Row", ["material_id", "mesh", "resolution"]) surfaces = [ @@ -79,7 +79,7 @@ class LayerSet(BoundarySet): ] return cls(layers=base_layer) - def create_raster(self, resolution): + def create_raster(self, resolution: float) -> tuple[Path, Path]: """ Create raster representations for the LayerSet. @@ -103,7 +103,7 @@ class LayerSet(BoundarySet): file.write("\n".join(str(item) for item in raster_set)) return raster_vtu, rastered_layers_txt - def create_rasters(self, resolution: int) -> list[Path]: + def create_rasters(self, resolution: float) -> list[Path]: """ For each surface a (temporary) raster file with given resolution is created. """ diff --git a/ogstools/meshlib/boundary_subset.py b/ogstools/meshlib/boundary_subset.py index d297b4fc7..ef74fc93e 100644 --- a/ogstools/meshlib/boundary_subset.py +++ b/ogstools/meshlib/boundary_subset.py @@ -21,9 +21,7 @@ class Surface: return self._material_id @typechecked - def __init__( - self, input: Union[Path, pv.UnstructuredGrid], material_id: int - ): + def __init__(self, input: Union[Path, pv.DataObject], material_id: int): """Initialize a surface mesh. Either from pyvista or from a file.""" self._material_id = material_id @@ -33,7 +31,7 @@ class Surface: msg = f"{self.filename} does not exist." raise ValueError(msg) self.mesh = pv.get_reader(self.filename).read() - elif isinstance(input, pv.UnstructuredGrid): + elif isinstance(input, pv.DataObject): self.mesh = input self.filename = Path(tempfile.mkstemp(".vtu", "surface")[1]) pv.save_meshio(self.filename, self.mesh, file_format="vtu") @@ -42,7 +40,7 @@ class Surface: np.ones(self.mesh.n_cells) * self.material_id ).astype(np.int32) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return self.__dict__ == other.__dict__ def create_raster_file(self, resolution: float) -> Path: @@ -76,7 +74,7 @@ def Gaussian2D( spread: float, height_offset: float, n: int, -): +) -> pv.DataSet: """ Generate a 2D Gaussian-like surface using the provided parameters. diff --git a/ogstools/meshlib/gmsh_meshing.py b/ogstools/meshlib/gmsh_meshing.py index c58cb076d..c985a5000 100644 --- a/ogstools/meshlib/gmsh_meshing.py +++ b/ogstools/meshlib/gmsh_meshing.py @@ -9,7 +9,7 @@ def _geo_square( lengths: Union[float, list[float]], n_edge_cells: Union[int, list[int]], structured: bool, -): +) -> None: _lengths = lengths if isinstance(lengths, list) else [lengths] * 2 _n = n_edge_cells if isinstance(n_edge_cells, list) else [n_edge_cells] * 2 geo.addPoint(0, 0, 0, tag=1) @@ -41,7 +41,7 @@ def rect( structured_grid: bool = True, order: int = 1, out_name: Path = Path("unit_square.msh"), -): +) -> None: gmsh.initialize() gmsh.option.set_number("General.Verbosity", 0) gmsh.model.add("unit_square") @@ -75,7 +75,7 @@ def cuboid( structured_grid: bool = True, order: int = 1, out_name: Path = Path("unit_cube.msh"), -): +) -> None: gmsh.initialize() gmsh.option.set_number("General.Verbosity", 0) gmsh.model.add("unit_cube") @@ -118,7 +118,7 @@ def bhe_mesh( bhe_depth: float = 20, order: int = 1, out_name: Path = Path("bhe_mesh.msh"), -): +) -> None: gmsh.initialize() model, geo = (gmsh.model, gmsh.model.geo) model.add(Path(out_name).stem) diff --git a/ogstools/meshlib/mesh_series.py b/ogstools/meshlib/mesh_series.py index 7021fed53..c7dd4a5d9 100644 --- a/ogstools/meshlib/mesh_series.py +++ b/ogstools/meshlib/mesh_series.py @@ -7,6 +7,7 @@ import meshio import numpy as np import pyvista as pv import vtuIO +from h5py import File from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator from tqdm.auto import tqdm @@ -55,7 +56,7 @@ class MeshSeries: raise TypeError(msg) @property - def hdf5(self): + def hdf5(self) -> File: # We assume there is only one h5 file return next(iter(self._xdmf_reader.hdf5_files.values())) @@ -225,7 +226,7 @@ class MeshSeries: data_name: str, interp_method: Optional[Literal["nearest", "probefilter"]] = None, interp_backend: Optional[Literal["vtk", "scipy"]] = None, - ): + ) -> np.ndarray: obs_pts_dict = {f"pt{j}": point for j, point in enumerate(points)} dim = self.read(0).get_cell(0).dimension pvd_path = self.filepath @@ -242,7 +243,7 @@ class MeshSeries: points: np.ndarray, data_name: str, interp_method: Optional[Literal["nearest", "linear"]] = None, - ): + ) -> np.ndarray: values = self.hdf5["meshes"][self.hdf5_bulk_name][data_name][:] geom = self.hdf5["meshes"][self.hdf5_bulk_name]["geometry"][0] values = np.swapaxes(values, 0, 1) diff --git a/ogstools/meshlib/region/region.py b/ogstools/meshlib/region/region.py index dd308bd62..ac6e4b8dc 100644 --- a/ogstools/meshlib/region/region.py +++ b/ogstools/meshlib/region/region.py @@ -7,7 +7,7 @@ import numpy as np import pyvista as pv from ogs import cli -from ..boundary_set import LayerSet +from ..boundary_set import Layer, LayerSet class RegionSet: @@ -26,7 +26,7 @@ class RegionSet: self.filename = Path(tempfile.mkstemp(".vtu", "region_set")[1]) self.mesh = input - def box_boundaries(self): + def box_boundaries(self) -> tuple[pv.UnstructuredGrid, ...]: """ Retrieve the boundaries of the mesh in local coordinate system (u, v, w). @@ -46,6 +46,7 @@ class RegionSet: mesh = ... u_min, u_max, v_min, v_max, w_min, w_max = mesh.box_boundaries() """ + assert isinstance(self.mesh, pv.UnstructuredGrid) surface = self.mesh.extract_surface() u_max = to_boundary(surface, lambda normals: normals[:, 0] > 0.5) u_min = to_boundary(surface, lambda normals: normals[:, 0] < -0.5) @@ -152,7 +153,7 @@ def to_region_prism(layer_set: LayerSet, resolution: float) -> RegionSet: def layer_to_simplified_mesh( - layer, resolution: float, rank: int, bounds: list[float] + layer: Layer, resolution: float, rank: int, bounds: list[float] ) -> pv.UnstructuredGrid: """Convert a geological layer to a simplified mesh. diff --git a/ogstools/meshlib/xdmf_reader.py b/ogstools/meshlib/xdmf_reader.py index 6f1b49b57..c82d49041 100644 --- a/ogstools/meshlib/xdmf_reader.py +++ b/ogstools/meshlib/xdmf_reader.py @@ -15,6 +15,8 @@ to be read like:: """ +from xml.etree.ElementTree import Element + import meshio import numpy as np from meshio.xdmf.time_series import ( @@ -25,10 +27,10 @@ from meshio.xdmf.time_series import ( class XDMFReader(meshio.xdmf.TimeSeriesReader): - def __init__(self, filename): + def __init__(self, filename: str): super().__init__(filename) - def read_data(self, k: int): + def read_data(self, k: int) -> tuple[float, dict, dict, dict]: point_data = {} cell_data_raw = {} other_data = {} @@ -67,8 +69,8 @@ class XDMFReader(meshio.xdmf.TimeSeriesReader): return t, point_data, cell_data, other_data - def _read_data_item(self, data_item): - dims = [int(d) for d in data_item.get("Dimensions").split()] + def _read_data_item(self, data_item: Element) -> np.ndarray: + dims = [int(d) for d in data_item.get("Dimensions", "").split()] # Actually, `NumberType` is XDMF2 and `DataType` XDMF3, but many files out there # use both keys interchangeably. @@ -92,6 +94,7 @@ class XDMFReader(meshio.xdmf.TimeSeriesReader): data_format = data_item.attrib["Format"] + assert isinstance(data_item.text, str) if data_format == "XML": return np.fromstring( data_item.text, diff --git a/ogstools/meshplotlib/animation.py b/ogstools/meshplotlib/animation.py index d21069933..e05db2f6a 100644 --- a/ogstools/meshplotlib/animation.py +++ b/ogstools/meshplotlib/animation.py @@ -44,7 +44,7 @@ def animate( index = np.argmin(np.abs(np.asarray(ts) - i)) fig.axes[-1].remove() # remove colorbar - for ax in np.ravel(fig.axes): + for ax in np.ravel(np.asarray(fig.axes)): ax.clear() if titles is not None: setup.title_center = titles[index] @@ -54,18 +54,20 @@ def animate( mesh = mesh_series.read_interp(i, True) with warnings.catch_warnings(): warnings.simplefilter("ignore") - fig = _draw_plot(mesh, property, fig=fig) + fig = _draw_plot( + mesh, property, fig=fig + ) # type: ignore[assignment] _func = partial(animate_func, fig=fig) return FuncAnimation( - fig, - _func, + fig, # type: ignore[arg-type] + _func, # type: ignore[arg-type] frames=tqdm(ts), blit=False, interval=50, repeat=False, - init_func=init, + init_func=init, # type: ignore[arg-type] ) @@ -83,7 +85,7 @@ def save_animation(anim: FuncAnimation, filename: str, fps: int) -> bool: "-vf pad=ceil(iw/2)*2:ceil(ih/2)*2" ).split(" ") - writer = None + writer: Optional[Union[FFMpegWriter, ImageMagickWriter]] = None if FFMpegWriter.isAvailable(): writer = FFMpegWriter(fps=fps, codec="libx265", extra_args=codec_args) filename += ".mp4" @@ -91,7 +93,7 @@ def save_animation(anim: FuncAnimation, filename: str, fps: int) -> bool: print("\nffmpeg not available. It is recommended for saving animation.") filename += ".gif" if ImageMagickWriter.isAvailable(): - writer = "imagemagick" + writer = ImageMagickWriter() else: print( "ImageMagick also not available. Falling back to" diff --git a/ogstools/meshplotlib/core.py b/ogstools/meshplotlib/core.py index ca139bbfa..8fe6ac039 100644 --- a/ogstools/meshplotlib/core.py +++ b/ogstools/meshplotlib/core.py @@ -2,7 +2,7 @@ import warnings from math import nextafter -from typing import Literal, Optional, Union +from typing import Any, Literal, Optional, Union import numpy as np import pyvista as pv @@ -28,13 +28,13 @@ from .utils import get_style_cycler # TODO: define default data_name for regions in setup -def _q_zero_line(mesh_property: Property, levels: np.ndarray): +def _q_zero_line(mesh_property: Property, levels: np.ndarray) -> bool: return mesh_property.bilinear_cmap or ( mesh_property.data_name == "temperature" and levels[0] < 0 < levels[-1] ) -def get_level_boundaries(levels: np.ndarray): +def get_level_boundaries(levels: np.ndarray) -> np.ndarray: return np.array( [ levels[0] - 0.5 * (levels[1] - levels[0]), @@ -57,6 +57,7 @@ def get_cmap_norm( continuous_cmap = colormaps[mesh_property.cmap] else: continuous_cmap = mesh_property.cmap + conti_norm: Union[mcolors.TwoSlopeNorm, mcolors.Normalize] if mesh_property.bilinear_cmap: if vmin <= 0.0 <= vmax: vcenter = 0.0 @@ -258,9 +259,13 @@ def subplot( cmap, norm = get_cmap_norm(levels, mesh_property) if mesh_property.data_name in mesh.point_data: - ax.tricontourf(x, y, tri, values, levels=levels, cmap=cmap, norm=norm) + ax.tricontourf( # type: ignore[call-overload] + x, y, tri, values, levels=levels, cmap=cmap, norm=norm + ) if _q_zero_line(mesh_property, levels): - ax.tricontour(x, y, tri, values, levels=[0], colors="w") + ax.tricontour( # type: ignore[call-overload] + x, y, tri, values, levels=[0], colors="w" + ) else: ax.tripcolor(x, y, tri, facecolors=values, cmap=cmap, norm=norm) if mesh_property.is_mask(): @@ -299,7 +304,9 @@ def subplot( sec_labels += [""] # TODO: use a function to make this short secax = ax.secondary_xaxis("top") - secax.xaxis.set_major_locator(mticker.FixedLocator(ax.get_xticks())) + secax.xaxis.set_major_locator( + mticker.FixedLocator(list(ax.get_xticks())) + ) secax.set_xticklabels(sec_labels) secax.set_xlabel(f'{"xyz"[projection]} / {setup.length.output_unit}') @@ -419,21 +426,21 @@ def _draw_plot( meshes: Union[list[pv.UnstructuredGrid], np.ndarray, pv.UnstructuredGrid], mesh_property: Property, fig: Optional[mfigure.Figure] = None, - ax: Optional[plt.Axes] = None, -) -> mfigure.Figure: + axes: Optional[plt.Axes] = None, +) -> Optional[mfigure.Figure]: """ Plot the property field of meshes on existing figure. :param meshes: Singular mesh of 2D numpy array of meshes :param property: the property field to be visualized on all meshes - :param fig: Matplotlib Figure object to use for plotting (optional) - :param ax: Matplotlib Axis object to use for plotting (optional) + :param fig: Matplotlib figure to use for plotting (optional) + :param axes: Matplotlib Axes to use for plotting (optional) """ shape = _get_rows_cols(meshes) np_meshes = np.reshape(meshes, shape) - if fig is not None and ax is not None: - np_axs = np.reshape(np.array(ax), shape) - elif fig is not None and ax is None: + if fig is not None and axes is not None: + np_axs = np.reshape(np.array(axes), shape) + elif fig is not None and axes is None: # Only Fig is given # Multiple meshes should be accepted warnings.warn( @@ -441,14 +448,14 @@ def _draw_plot( Warning, stacklevel=4, ) - np_axs = np.reshape(fig.axes, shape) - elif fig is None and ax is not None: + np_axs = np.reshape(np.asarray(fig.axes), shape) + elif fig is None and axes is not None: # Only ax is given # Only one mesh should be accepted if shape != (1, 1): msg = "You have provided only one Axis object but multiple meshes. Provide only one mesh per Axis object, or provide Figure object instead." raise ValueError(msg) - np_axs = np.reshape(np.array(ax), (1, 1)) + np_axs = np.reshape(np.array(axes), (1, 1)) else: msg = "Neither Figure nor Axis object was provided." raise TypeError(msg) @@ -459,14 +466,14 @@ def _draw_plot( _levels = ( combined_levels if setup.combined_colorbar - else get_combined_levels([np_meshes[i, j]], mesh_property) + else get_combined_levels(np_meshes[i, j, None], mesh_property) ) subplot(np_meshes[i, j], mesh_property, np_axs[i, j], _levels) x_id, y_id = get_projection( np_meshes[0, 0] ) # One mesh is sufficient, it should be the same for all of them - np_axs = label_spatial_axes(np_axs, np.array([x_id, y_id])) + label_spatial_axes(np_axs, "xyz"[x_id], "xyz"[y_id]) np_axs[0, 0].set_title(setup.title_center, loc="center", y=1.02) np_axs[0, 0].set_title(setup.title_left, loc="left", y=1.02) np_axs[0, 0].set_title(setup.title_right, loc="right", y=1.02) @@ -481,7 +488,7 @@ def _draw_plot( stacklevel=4, ) else: - cb_axs = np.ravel(fig.axes).tolist() + cb_axs = np.ravel(np.asarray(fig.axes)).tolist() add_colorbars( fig, cb_axs, mesh_property, combined_levels, pad=0.05 / shape[1] ) @@ -497,7 +504,7 @@ def _draw_plot( for i in range(shape[0]): for j in range(shape[1]): _levels = get_combined_levels( - [np_meshes[i, j]], mesh_property + np_meshes[i, j, None], mesh_property ) add_colorbars(fig, np_axs[i, j], mesh_property, _levels) return fig @@ -532,10 +539,14 @@ def update_font_sizes( subax_xlim = subax.get_xlim() subax_ylim = subax.get_ylim() subax.set_xticks( - subax.get_xticks(), subax.get_xticklabels(), fontsize=fontsize + subax.get_xticks(), + [label.get_text() for label in subax.get_xticklabels()], + fontsize=fontsize, ) subax.set_yticks( - subax.get_yticks(), subax.get_yticklabels(), fontsize=fontsize + subax.get_yticks(), + [label.get_text() for label in subax.get_yticklabels()], + fontsize=fontsize, ) subax.set_xlim(subax_xlim) subax.set_ylim(subax_ylim) @@ -549,7 +560,7 @@ def plot( mesh_property: Union[Property, str], fig: Optional[mfigure.Figure] = None, ax: Optional[plt.Axes] = None, -) -> mfigure.Figure: +) -> Optional[mfigure.Figure]: """ Plot the property field of meshes with default settings. @@ -576,18 +587,20 @@ def plot( n_axs = shape[0] * shape[1] if ax is None and fig is None: _fig, _ax = _fig_init(rows=shape[0], cols=shape[1], aspect=fig_aspect) - fig = _draw_plot(meshes, mesh_property, fig=_fig, ax=_ax) + fig = _draw_plot(meshes, mesh_property, fig=_fig, axes=_ax) + assert isinstance(fig, plt.Figure) for ax, aspect in zip(fig.axes[: n_axs + 1], ax_aspects): ax.set_aspect(1.0 / aspect) elif ax is not None and fig is None: - _draw_plot(meshes, mesh_property, ax=ax) + _draw_plot(meshes, mesh_property, axes=ax) ax.set_aspect(1.0 / ax_aspects[0]) elif ax is None and fig is not None: fig = _draw_plot(meshes, mesh_property, fig=fig) + assert isinstance(fig, plt.Figure) for ax, aspect in zip(fig.axes[: n_axs + 1], ax_aspects): ax.set_aspect(1.0 / aspect) elif ax is not None and fig is not None: - _draw_plot(meshes, mesh_property, fig=fig, ax=ax) + _draw_plot(meshes, mesh_property, fig=fig, axes=ax) for ax, aspect in zip(fig.axes[: n_axs + 1], ax_aspects): ax.set_aspect(1.0 / aspect) return fig @@ -606,8 +619,8 @@ def plot_probe( linestyles: Optional[list] = None, ax: Optional[plt.Axes] = None, fill_between: bool = False, - **kwargs, -) -> mfigure.Figure: + **kwargs: Any, +) -> Optional[mfigure.Figure]: """ Plot the transient property on the observation points in the MeshSeries. diff --git a/ogstools/meshplotlib/plot_features.py b/ogstools/meshplotlib/plot_features.py index 5b358b30c..9a5ee911c 100644 --- a/ogstools/meshplotlib/plot_features.py +++ b/ogstools/meshplotlib/plot_features.py @@ -74,7 +74,9 @@ def plot_element_edges(ax: plt.Axes, surf: pv.DataSet, projection: int) -> None: ] verts = setup.length.transform(np.delete(cell_pts, projection, -1)) lw = 0.5 * setup.rcParams_scaled["lines.linewidth"] - pc = PolyCollection(verts, fc="None", ec="black", lw=lw) + pc = PolyCollection( + verts, fc="None", ec="black", lw=lw # type: ignore[arg-type] + ) ax.add_collection(pc) @@ -131,9 +133,9 @@ def plot_streamlines( lw = 2.5 * val_norm / max(1e-16, np.max(val_norm)) lw *= setup.rcParams_scaled["lines.linewidth"] x_g, y_g = setup.length.transform(np.meshgrid(x, y)) - plot_args = [x_g, y_g, val[..., 0], val[..., 1]] if plot_type == "streamlines": - ax.streamplot(*plot_args, color="k", linewidth=lw, density=1.5) + ax.streamplot(x_g, y_g, val[..., 0], val[..., 1], + color="k", linewidth=lw, density=1.5) # fmt: skip else: line_args = ( dict( # noqa: C408 @@ -142,7 +144,8 @@ def plot_streamlines( if plot_type == "lines" else {} ) - ax.quiver(*plot_args, **line_args, scale=1 / 0.03) + scale = 1.0 / 0.03 + ax.quiver(x_g, y_g, val[..., 0], val[..., 1], **line_args, scale=scale) def plot_on_top( @@ -161,7 +164,7 @@ def plot_on_top( x_vals = df_pts.groupby("x")["x"].agg(np.mean).to_numpy() y_vals = df_pts.groupby("x")["y"].agg(np.max).to_numpy() contour_vals = [y + scaling * contour(x) for y, x in zip(y_vals, x_vals)] - ax.set_ylim(top=setup.length.transform(np.max(contour_vals))) + ax.set_ylim(top=float(setup.length.transform(np.max(contour_vals)))) ax.fill_between( setup.length.transform(x_vals), setup.length.transform(y_vals), @@ -172,7 +175,7 @@ def plot_on_top( def plot_contour( ax: plt.Axes, mesh: pv.DataSet, style: str, lw: int, projection: int = 2 -): +) -> None: contour = mesh.extract_surface().strip(join=True) x_id, y_id = np.delete([0, 1, 2], projection) x, y = 1e-3 * contour.points[contour.lines[1:]].T[[x_id, y_id]] diff --git a/ogstools/meshplotlib/plot_setup.py b/ogstools/meshplotlib/plot_setup.py index d36e7f327..6c933c6b7 100644 --- a/ogstools/meshplotlib/plot_setup.py +++ b/ogstools/meshplotlib/plot_setup.py @@ -1,7 +1,7 @@ """Plot configuration setup.""" from dataclasses import dataclass -from typing import Union +from typing import Optional, Union from ogstools.propertylib.property import Scalar @@ -23,9 +23,9 @@ class PlotSetup: "The resolution (dots per inch) for the figure." fig_scale: float "A scaling factor for the figure." - min_ax_aspect: float + min_ax_aspect: Optional[float] "Minimum aspect ratio of subplots." - max_ax_aspect: float + max_ax_aspect: Optional[float] "Maximum aspect ratio of subplots." invert_colorbar: bool "A boolean indicating whether to invert the colorbar." @@ -38,11 +38,11 @@ class PlotSetup: num_levels: int """The aimed number of levels / bins of the colorbar. See :obj:`ogstools.meshplotlib.levels`""" - num_streamline_interp_pts: int + num_streamline_interp_pts: Optional[int] "The number of interpolation points for streamlines." - p_max: float + p_max: Optional[float] "The fixed upper limit for the current scale." - p_min: float + p_min: Optional[float] "The fixed lower limit for the current scale." rcParams: dict """Matplotlib runtime configuration. See @@ -78,7 +78,7 @@ class PlotSetup: return params @classmethod - def from_dict(cls: type["PlotSetup"], obj: dict): + def from_dict(cls: type["PlotSetup"], obj: dict) -> "PlotSetup": """Create a PlotSetup instance from a dictionary.""" return cls( fig_scale=obj["fig_scale"], diff --git a/ogstools/meshplotlib/utils.py b/ogstools/meshplotlib/utils.py index e31ce760c..8cdc448c9 100644 --- a/ogstools/meshplotlib/utils.py +++ b/ogstools/meshplotlib/utils.py @@ -2,12 +2,13 @@ from typing import Optional import matplotlib.pyplot as plt import numpy as np +from cycler import Cycler def justified_labels(points: np.ndarray) -> list[str]: "Formats an array of points to a list of aligned str." - def fmt(val: float): + def fmt(val: float) -> str: return f"{val:.2f}".rstrip("0").rstrip(".") col_lens = np.max( @@ -24,7 +25,7 @@ def get_style_cycler( min_number_of_styles: int, colors: Optional[Optional[list]] = None, linestyles: Optional[list] = None, -) -> plt.cycler: +) -> Cycler: if colors is None: colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] if linestyles is None: diff --git a/ogstools/msh2vtu/__init__.py b/ogstools/msh2vtu/__init__.py index ac2a24d6e..251124ed3 100644 --- a/ogstools/msh2vtu/__init__.py +++ b/ogstools/msh2vtu/__init__.py @@ -1,7 +1,7 @@ # Author: Dominik Kern (TU Bergakademie Freiberg) import logging from pathlib import Path -from typing import Union +from typing import Any, Union import meshio import numpy as np @@ -18,7 +18,7 @@ __all__ = [ ] -def my_remove_orphaned_nodes(my_mesh: meshio.Mesh): +def my_remove_orphaned_nodes(my_mesh: meshio.Mesh) -> None: """Auxiliary function to remove points not belonging to any cell""" # find connected points and derive mapping from all points to them @@ -51,11 +51,9 @@ def my_remove_orphaned_nodes(my_mesh: meshio.Mesh): # cell data are not affected by point changes my_mesh.cells = output_cell_blocks # update cells - return - # print info for mesh: statistics and data field names -def print_info(mesh: meshio.Mesh): +def print_info(mesh: meshio.Mesh) -> None: N, D = mesh.points.shape logging.info("%d points in %d dimensions", N, D) cell_info = "cells: " @@ -70,13 +68,15 @@ def print_info(mesh: meshio.Mesh): # function to create node connectivity list, i.e. store for each node (point) to # which element (cell) it belongs -def find_cells_at_nodes(cells, node_count, cell_start_index): +def find_cells_at_nodes( + cells: Any, node_count: int, cell_start_index: int +) -> list[set]: # depending on the numbering of mixed meshes in OGS one may think of an # object-oriented way to add elements (of different type) to node # connectivity # initialize list of sets - node_connectivity = [set() for _ in range(node_count)] + node_connectivity: list[set] = [set() for _ in range(node_count)] cell_index = cell_start_index for cell in cells: for node in cell: @@ -95,7 +95,7 @@ def find_cells_at_nodes(cells, node_count, cell_start_index): # function to find out to which domain elements a boundary element belongs def find_connected_domain_cells( - boundary_cells_values, domain_cells_at_node + boundary_cells_values: Any, domain_cells_at_node: list[set[int]] ) -> tuple[np.ndarray, np.ndarray]: warned_gt1 = False # to avoid flood of warnings warned_lt1 = False # to avoid flood of warnings @@ -160,7 +160,7 @@ def msh2vtu( keep_ids: bool = False, ascii: bool = False, log_level: Union[int, str] = "DEBUG", -): +) -> int: """ Convert a gmsh mesh (.msh) to an unstructured grid file (.vtu). diff --git a/ogstools/propertylib/property.py b/ogstools/propertylib/property.py index 26a2381ba..0ee376e00 100644 --- a/ogstools/propertylib/property.py +++ b/ogstools/propertylib/property.py @@ -7,7 +7,7 @@ via pint. from collections.abc import Sequence from dataclasses import dataclass, replace -from typing import Callable, Union +from typing import Any, Callable, Union import numpy as np import pyvista as pv @@ -46,15 +46,15 @@ class Property: categoric: bool = False """Does this property only have categoric values?""" - def __post_init__(self): + def __post_init__(self) -> None: if not self.output_name: self.output_name = self.data_name @property - def type_name(self): + def type_name(self) -> str: return type(self).__name__ - def replace(self, **changes): + def replace(self: "Property", **changes: Any) -> "Property": """ Create a new Property object with modified attributes. @@ -68,7 +68,9 @@ class Property: return replace(self, **changes) @classmethod - def from_property(cls, new_property: "Property", **changes): + def from_property( # type: ignore[no-untyped-def] + cls, new_property: "Property", **changes: Any + ): "Create a new Property object with modified attributes." return cls( data_name=new_property.data_name, @@ -151,7 +153,7 @@ class Property: """ return self.data_name == self.mask - def get_mask(self): + def get_mask(self) -> "Property": """ :returns: A property representing this properties mask. """ diff --git a/ogstools/studies/convergence/convergence.py b/ogstools/studies/convergence/convergence.py index 9cc39ac75..c5282ecfa 100644 --- a/ogstools/studies/convergence/convergence.py +++ b/ogstools/studies/convergence/convergence.py @@ -148,7 +148,7 @@ def convergence_metrics( :returns: A pandas Dataframe containing all metrics. """ - def _data(m: pv.UnstructuredGrid): + def _data(m: pv.UnstructuredGrid) -> np.ndarray: return mesh_property.magnitude.transform( m.point_data[mesh_property.data_name] ) @@ -304,7 +304,9 @@ def convergence_metrics_evolution( .to(units[1]) .magnitude ) - p_metrics_per_t = np.concatenate(([time_vals], p_metrics_per_t.T)).T + p_metrics_per_t = np.concatenate( + (np.asarray([time_vals]), p_metrics_per_t.T) + ).T columns = ["timevalue"] + [ f"{t} ({x})" for t in ["abs. error", "rel. error", "p"] diff --git a/ogstools/studies/convergence/examples/steady_state_diffusion.py b/ogstools/studies/convergence/examples/steady_state_diffusion.py index 88be208ea..aba217301 100644 --- a/ogstools/studies/convergence/examples/steady_state_diffusion.py +++ b/ogstools/studies/convergence/examples/steady_state_diffusion.py @@ -6,15 +6,15 @@ import numpy as np import pyvista as pv -def _c_k(k): +def _c_k(k: float) -> float: return 0.5 * (2 * k - 1) * np.pi -def _a_k(k): +def _a_k(k: float) -> float: return 2 / (_c_k(k) ** 2 * np.cosh(_c_k(k))) -def _h(points): +def _h(points: np.ndarray) -> np.ndarray: result = np.ones(len(points)) for k in np.arange(1, 100): c_k_val = _c_k(k) -- GitLab