From b419ade77d5e98cf03ef1da5c019a6e60b5a1965 Mon Sep 17 00:00:00 2001
From: FZill <florian.zill@ufz.de>
Date: Thu, 21 Mar 2024 12:45:26 +0100
Subject: [PATCH] [all] many type fixes

also ignore jupyter notebooks in main folder
---
 .gitignore                                    |  5 +-
 ogstools/meshlib/boundary.py                  | 10 +--
 ogstools/meshlib/boundary_set.py              | 18 ++---
 ogstools/meshlib/boundary_subset.py           | 10 ++-
 ogstools/meshlib/gmsh_meshing.py              |  8 +--
 ogstools/meshlib/mesh_series.py               |  7 +-
 ogstools/meshlib/region/region.py             |  7 +-
 ogstools/meshlib/xdmf_reader.py               | 11 +--
 ogstools/meshplotlib/animation.py             | 16 +++--
 ogstools/meshplotlib/core.py                  | 69 +++++++++++--------
 ogstools/meshplotlib/plot_features.py         | 15 ++--
 ogstools/meshplotlib/plot_setup.py            | 14 ++--
 ogstools/meshplotlib/utils.py                 |  5 +-
 ogstools/msh2vtu/__init__.py                  | 18 ++---
 ogstools/propertylib/property.py              | 14 ++--
 ogstools/studies/convergence/convergence.py   |  6 +-
 .../examples/steady_state_diffusion.py        |  6 +-
 17 files changed, 134 insertions(+), 105 deletions(-)

diff --git a/.gitignore b/.gitignore
index f0aa73237..3e4a20641 100644
--- a/.gitignore
+++ b/.gitignore
@@ -162,6 +162,9 @@ cython_debug/
 #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 #  and can be added to the global gitignore or merged into this file.  For a more nuclear
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+#.idea/S
 
 /docs/auto_examples/
+
+# ignore jupyter notebooks in main folder (useful for testing, annoying for git)
+/**.ipynb
diff --git a/ogstools/meshlib/boundary.py b/ogstools/meshlib/boundary.py
index d298cc7e6..030e3fb95 100644
--- a/ogstools/meshlib/boundary.py
+++ b/ogstools/meshlib/boundary.py
@@ -21,7 +21,7 @@ class Boundary(ABC):
     """
 
     @abstractmethod
-    def dim(self):
+    def dim(self) -> int:
         """
         Get the dimension of the boundary.
 
@@ -29,7 +29,7 @@ class Boundary(ABC):
             int: The dimension of the boundary. For example, the dimension of a boundary
             of a cube (3D) is 2.
         """
-        return
+        return 0
 
 
 @dataclass(frozen=True)
@@ -47,7 +47,7 @@ class Layer(Boundary):
     of the layer in the GIS system.
     """
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         if not self.material_id:
             object.__setattr__(self, "material_id", self.bottom._material_id)
 
@@ -89,7 +89,7 @@ class Layer(Boundary):
         rasters.append(bottom_raster)
         return rasters
 
-    def dim(self):
+    def dim(self) -> int:
         return 3
 
 
@@ -100,7 +100,7 @@ class LocationFrame:
     ymin: float
     ymax: float
 
-    def as_gml(self, filename: Path):
+    def as_gml(self, filename: Path) -> None:
         """
         Generate GML representation of the location frame.
 
diff --git a/ogstools/meshlib/boundary_set.py b/ogstools/meshlib/boundary_set.py
index 17dd54b92..22de2ef82 100644
--- a/ogstools/meshlib/boundary_set.py
+++ b/ogstools/meshlib/boundary_set.py
@@ -20,12 +20,12 @@ class BoundarySet(ABC):
     """
 
     @abstractmethod
-    def bounds(self):
-        return
+    def bounds(self) -> list:
+        return []
 
     @abstractmethod
-    def filenames(self):
-        return
+    def filenames(self) -> list[Path]:
+        return []
 
 
 class LayerSet(BoundarySet):
@@ -49,16 +49,16 @@ class LayerSet(BoundarySet):
                 raise ValueError(msg)
         self.layers = layers
 
-    def bounds(self):
+    def bounds(self) -> list:
         return list(self.layers[0].top.mesh.bounds)
 
-    def filenames(self):
+    def filenames(self) -> list[Path]:
         layer_filenames = [layer.bottom.filename for layer in self.layers]
         layer_filenames.insert(0, self.layers[0].top.filename)  # file interface
         return layer_filenames
 
     @classmethod
-    def from_pandas(cls, df: pd.DataFrame):
+    def from_pandas(cls, df: pd.DataFrame) -> "LayerSet":
         """Create a LayerSet from a Pandas DataFrame."""
         Row = namedtuple("Row", ["material_id", "mesh", "resolution"])
         surfaces = [
@@ -79,7 +79,7 @@ class LayerSet(BoundarySet):
         ]
         return cls(layers=base_layer)
 
-    def create_raster(self, resolution):
+    def create_raster(self, resolution: float) -> tuple[Path, Path]:
         """
         Create raster representations for the LayerSet.
 
@@ -103,7 +103,7 @@ class LayerSet(BoundarySet):
             file.write("\n".join(str(item) for item in raster_set))
         return raster_vtu, rastered_layers_txt
 
-    def create_rasters(self, resolution: int) -> list[Path]:
+    def create_rasters(self, resolution: float) -> list[Path]:
         """
         For each surface a (temporary) raster file with given resolution is created.
         """
diff --git a/ogstools/meshlib/boundary_subset.py b/ogstools/meshlib/boundary_subset.py
index d297b4fc7..ef74fc93e 100644
--- a/ogstools/meshlib/boundary_subset.py
+++ b/ogstools/meshlib/boundary_subset.py
@@ -21,9 +21,7 @@ class Surface:
         return self._material_id
 
     @typechecked
-    def __init__(
-        self, input: Union[Path, pv.UnstructuredGrid], material_id: int
-    ):
+    def __init__(self, input: Union[Path, pv.DataObject], material_id: int):
         """Initialize a surface mesh. Either from pyvista or from a file."""
         self._material_id = material_id
 
@@ -33,7 +31,7 @@ class Surface:
                 msg = f"{self.filename} does not exist."
                 raise ValueError(msg)
             self.mesh = pv.get_reader(self.filename).read()
-        elif isinstance(input, pv.UnstructuredGrid):
+        elif isinstance(input, pv.DataObject):
             self.mesh = input
             self.filename = Path(tempfile.mkstemp(".vtu", "surface")[1])
             pv.save_meshio(self.filename, self.mesh, file_format="vtu")
@@ -42,7 +40,7 @@ class Surface:
             np.ones(self.mesh.n_cells) * self.material_id
         ).astype(np.int32)
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         return self.__dict__ == other.__dict__
 
     def create_raster_file(self, resolution: float) -> Path:
@@ -76,7 +74,7 @@ def Gaussian2D(
     spread: float,
     height_offset: float,
     n: int,
-):
+) -> pv.DataSet:
     """
     Generate a 2D Gaussian-like surface using the provided parameters.
 
diff --git a/ogstools/meshlib/gmsh_meshing.py b/ogstools/meshlib/gmsh_meshing.py
index c58cb076d..c985a5000 100644
--- a/ogstools/meshlib/gmsh_meshing.py
+++ b/ogstools/meshlib/gmsh_meshing.py
@@ -9,7 +9,7 @@ def _geo_square(
     lengths: Union[float, list[float]],
     n_edge_cells: Union[int, list[int]],
     structured: bool,
-):
+) -> None:
     _lengths = lengths if isinstance(lengths, list) else [lengths] * 2
     _n = n_edge_cells if isinstance(n_edge_cells, list) else [n_edge_cells] * 2
     geo.addPoint(0, 0, 0, tag=1)
@@ -41,7 +41,7 @@ def rect(
     structured_grid: bool = True,
     order: int = 1,
     out_name: Path = Path("unit_square.msh"),
-):
+) -> None:
     gmsh.initialize()
     gmsh.option.set_number("General.Verbosity", 0)
     gmsh.model.add("unit_square")
@@ -75,7 +75,7 @@ def cuboid(
     structured_grid: bool = True,
     order: int = 1,
     out_name: Path = Path("unit_cube.msh"),
-):
+) -> None:
     gmsh.initialize()
     gmsh.option.set_number("General.Verbosity", 0)
     gmsh.model.add("unit_cube")
@@ -118,7 +118,7 @@ def bhe_mesh(
     bhe_depth: float = 20,
     order: int = 1,
     out_name: Path = Path("bhe_mesh.msh"),
-):
+) -> None:
     gmsh.initialize()
     model, geo = (gmsh.model, gmsh.model.geo)
     model.add(Path(out_name).stem)
diff --git a/ogstools/meshlib/mesh_series.py b/ogstools/meshlib/mesh_series.py
index 7021fed53..c7dd4a5d9 100644
--- a/ogstools/meshlib/mesh_series.py
+++ b/ogstools/meshlib/mesh_series.py
@@ -7,6 +7,7 @@ import meshio
 import numpy as np
 import pyvista as pv
 import vtuIO
+from h5py import File
 from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator
 from tqdm.auto import tqdm
 
@@ -55,7 +56,7 @@ class MeshSeries:
             raise TypeError(msg)
 
     @property
-    def hdf5(self):
+    def hdf5(self) -> File:
         # We assume there is only one h5 file
         return next(iter(self._xdmf_reader.hdf5_files.values()))
 
@@ -225,7 +226,7 @@ class MeshSeries:
         data_name: str,
         interp_method: Optional[Literal["nearest", "probefilter"]] = None,
         interp_backend: Optional[Literal["vtk", "scipy"]] = None,
-    ):
+    ) -> np.ndarray:
         obs_pts_dict = {f"pt{j}": point for j, point in enumerate(points)}
         dim = self.read(0).get_cell(0).dimension
         pvd_path = self.filepath
@@ -242,7 +243,7 @@ class MeshSeries:
         points: np.ndarray,
         data_name: str,
         interp_method: Optional[Literal["nearest", "linear"]] = None,
-    ):
+    ) -> np.ndarray:
         values = self.hdf5["meshes"][self.hdf5_bulk_name][data_name][:]
         geom = self.hdf5["meshes"][self.hdf5_bulk_name]["geometry"][0]
         values = np.swapaxes(values, 0, 1)
diff --git a/ogstools/meshlib/region/region.py b/ogstools/meshlib/region/region.py
index dd308bd62..ac6e4b8dc 100644
--- a/ogstools/meshlib/region/region.py
+++ b/ogstools/meshlib/region/region.py
@@ -7,7 +7,7 @@ import numpy as np
 import pyvista as pv
 from ogs import cli
 
-from ..boundary_set import LayerSet
+from ..boundary_set import Layer, LayerSet
 
 
 class RegionSet:
@@ -26,7 +26,7 @@ class RegionSet:
             self.filename = Path(tempfile.mkstemp(".vtu", "region_set")[1])
             self.mesh = input
 
-    def box_boundaries(self):
+    def box_boundaries(self) -> tuple[pv.UnstructuredGrid, ...]:
         """
         Retrieve the boundaries of the mesh in local coordinate system (u, v, w).
 
@@ -46,6 +46,7 @@ class RegionSet:
             mesh = ...
             u_min, u_max, v_min, v_max, w_min, w_max = mesh.box_boundaries()
         """
+        assert isinstance(self.mesh, pv.UnstructuredGrid)
         surface = self.mesh.extract_surface()
         u_max = to_boundary(surface, lambda normals: normals[:, 0] > 0.5)
         u_min = to_boundary(surface, lambda normals: normals[:, 0] < -0.5)
@@ -152,7 +153,7 @@ def to_region_prism(layer_set: LayerSet, resolution: float) -> RegionSet:
 
 
 def layer_to_simplified_mesh(
-    layer, resolution: float, rank: int, bounds: list[float]
+    layer: Layer, resolution: float, rank: int, bounds: list[float]
 ) -> pv.UnstructuredGrid:
     """Convert a geological layer to a simplified mesh.
 
diff --git a/ogstools/meshlib/xdmf_reader.py b/ogstools/meshlib/xdmf_reader.py
index 6f1b49b57..c82d49041 100644
--- a/ogstools/meshlib/xdmf_reader.py
+++ b/ogstools/meshlib/xdmf_reader.py
@@ -15,6 +15,8 @@ to be read like::
 
 """
 
+from xml.etree.ElementTree import Element
+
 import meshio
 import numpy as np
 from meshio.xdmf.time_series import (
@@ -25,10 +27,10 @@ from meshio.xdmf.time_series import (
 
 
 class XDMFReader(meshio.xdmf.TimeSeriesReader):
-    def __init__(self, filename):
+    def __init__(self, filename: str):
         super().__init__(filename)
 
-    def read_data(self, k: int):
+    def read_data(self, k: int) -> tuple[float, dict, dict, dict]:
         point_data = {}
         cell_data_raw = {}
         other_data = {}
@@ -67,8 +69,8 @@ class XDMFReader(meshio.xdmf.TimeSeriesReader):
 
         return t, point_data, cell_data, other_data
 
-    def _read_data_item(self, data_item):
-        dims = [int(d) for d in data_item.get("Dimensions").split()]
+    def _read_data_item(self, data_item: Element) -> np.ndarray:
+        dims = [int(d) for d in data_item.get("Dimensions", "").split()]
 
         # Actually, `NumberType` is XDMF2 and `DataType` XDMF3, but many files out there
         # use both keys interchangeably.
@@ -92,6 +94,7 @@ class XDMFReader(meshio.xdmf.TimeSeriesReader):
 
         data_format = data_item.attrib["Format"]
 
+        assert isinstance(data_item.text, str)
         if data_format == "XML":
             return np.fromstring(
                 data_item.text,
diff --git a/ogstools/meshplotlib/animation.py b/ogstools/meshplotlib/animation.py
index d21069933..e05db2f6a 100644
--- a/ogstools/meshplotlib/animation.py
+++ b/ogstools/meshplotlib/animation.py
@@ -44,7 +44,7 @@ def animate(
         index = np.argmin(np.abs(np.asarray(ts) - i))
 
         fig.axes[-1].remove()  # remove colorbar
-        for ax in np.ravel(fig.axes):
+        for ax in np.ravel(np.asarray(fig.axes)):
             ax.clear()
         if titles is not None:
             setup.title_center = titles[index]
@@ -54,18 +54,20 @@ def animate(
             mesh = mesh_series.read_interp(i, True)
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
-            fig = _draw_plot(mesh, property, fig=fig)
+            fig = _draw_plot(
+                mesh, property, fig=fig
+            )  # type: ignore[assignment]
 
     _func = partial(animate_func, fig=fig)
 
     return FuncAnimation(
-        fig,
-        _func,
+        fig,  # type: ignore[arg-type]
+        _func,  # type: ignore[arg-type]
         frames=tqdm(ts),
         blit=False,
         interval=50,
         repeat=False,
-        init_func=init,
+        init_func=init,  # type: ignore[arg-type]
     )
 
 
@@ -83,7 +85,7 @@ def save_animation(anim: FuncAnimation, filename: str, fps: int) -> bool:
         "-vf pad=ceil(iw/2)*2:ceil(ih/2)*2"
     ).split(" ")
 
-    writer = None
+    writer: Optional[Union[FFMpegWriter, ImageMagickWriter]] = None
     if FFMpegWriter.isAvailable():
         writer = FFMpegWriter(fps=fps, codec="libx265", extra_args=codec_args)
         filename += ".mp4"
@@ -91,7 +93,7 @@ def save_animation(anim: FuncAnimation, filename: str, fps: int) -> bool:
         print("\nffmpeg not available. It is recommended for saving animation.")
         filename += ".gif"
         if ImageMagickWriter.isAvailable():
-            writer = "imagemagick"
+            writer = ImageMagickWriter()
         else:
             print(
                 "ImageMagick also not available. Falling back to"
diff --git a/ogstools/meshplotlib/core.py b/ogstools/meshplotlib/core.py
index ca139bbfa..8fe6ac039 100644
--- a/ogstools/meshplotlib/core.py
+++ b/ogstools/meshplotlib/core.py
@@ -2,7 +2,7 @@
 
 import warnings
 from math import nextafter
-from typing import Literal, Optional, Union
+from typing import Any, Literal, Optional, Union
 
 import numpy as np
 import pyvista as pv
@@ -28,13 +28,13 @@ from .utils import get_style_cycler
 # TODO: define default data_name for regions in setup
 
 
-def _q_zero_line(mesh_property: Property, levels: np.ndarray):
+def _q_zero_line(mesh_property: Property, levels: np.ndarray) -> bool:
     return mesh_property.bilinear_cmap or (
         mesh_property.data_name == "temperature" and levels[0] < 0 < levels[-1]
     )
 
 
-def get_level_boundaries(levels: np.ndarray):
+def get_level_boundaries(levels: np.ndarray) -> np.ndarray:
     return np.array(
         [
             levels[0] - 0.5 * (levels[1] - levels[0]),
@@ -57,6 +57,7 @@ def get_cmap_norm(
         continuous_cmap = colormaps[mesh_property.cmap]
     else:
         continuous_cmap = mesh_property.cmap
+    conti_norm: Union[mcolors.TwoSlopeNorm, mcolors.Normalize]
     if mesh_property.bilinear_cmap:
         if vmin <= 0.0 <= vmax:
             vcenter = 0.0
@@ -258,9 +259,13 @@ def subplot(
     cmap, norm = get_cmap_norm(levels, mesh_property)
 
     if mesh_property.data_name in mesh.point_data:
-        ax.tricontourf(x, y, tri, values, levels=levels, cmap=cmap, norm=norm)
+        ax.tricontourf(  # type: ignore[call-overload]
+            x, y, tri, values, levels=levels, cmap=cmap, norm=norm
+        )
         if _q_zero_line(mesh_property, levels):
-            ax.tricontour(x, y, tri, values, levels=[0], colors="w")
+            ax.tricontour(  # type: ignore[call-overload]
+                x, y, tri, values, levels=[0], colors="w"
+            )
     else:
         ax.tripcolor(x, y, tri, facecolors=values, cmap=cmap, norm=norm)
         if mesh_property.is_mask():
@@ -299,7 +304,9 @@ def subplot(
                 sec_labels += [""]
         # TODO: use a function to make this short
         secax = ax.secondary_xaxis("top")
-        secax.xaxis.set_major_locator(mticker.FixedLocator(ax.get_xticks()))
+        secax.xaxis.set_major_locator(
+            mticker.FixedLocator(list(ax.get_xticks()))
+        )
         secax.set_xticklabels(sec_labels)
         secax.set_xlabel(f'{"xyz"[projection]} / {setup.length.output_unit}')
 
@@ -419,21 +426,21 @@ def _draw_plot(
     meshes: Union[list[pv.UnstructuredGrid], np.ndarray, pv.UnstructuredGrid],
     mesh_property: Property,
     fig: Optional[mfigure.Figure] = None,
-    ax: Optional[plt.Axes] = None,
-) -> mfigure.Figure:
+    axes: Optional[plt.Axes] = None,
+) -> Optional[mfigure.Figure]:
     """
     Plot the property field of meshes on existing figure.
 
     :param meshes: Singular mesh of 2D numpy array of meshes
     :param property: the property field to be visualized on all meshes
-    :param fig: Matplotlib Figure object to use for plotting (optional)
-    :param ax: Matplotlib Axis object to use for plotting (optional)
+    :param fig: Matplotlib figure to use for plotting (optional)
+    :param axes: Matplotlib Axes to use for plotting (optional)
     """
     shape = _get_rows_cols(meshes)
     np_meshes = np.reshape(meshes, shape)
-    if fig is not None and ax is not None:
-        np_axs = np.reshape(np.array(ax), shape)
-    elif fig is not None and ax is None:
+    if fig is not None and axes is not None:
+        np_axs = np.reshape(np.array(axes), shape)
+    elif fig is not None and axes is None:
         # Only Fig is given
         # Multiple meshes should be accepted
         warnings.warn(
@@ -441,14 +448,14 @@ def _draw_plot(
             Warning,
             stacklevel=4,
         )
-        np_axs = np.reshape(fig.axes, shape)
-    elif fig is None and ax is not None:
+        np_axs = np.reshape(np.asarray(fig.axes), shape)
+    elif fig is None and axes is not None:
         # Only ax is given
         # Only one mesh should be accepted
         if shape != (1, 1):
             msg = "You have provided only one Axis object but multiple meshes. Provide only one mesh per Axis object, or provide Figure object instead."
             raise ValueError(msg)
-        np_axs = np.reshape(np.array(ax), (1, 1))
+        np_axs = np.reshape(np.array(axes), (1, 1))
     else:
         msg = "Neither Figure nor Axis object was provided."
         raise TypeError(msg)
@@ -459,14 +466,14 @@ def _draw_plot(
             _levels = (
                 combined_levels
                 if setup.combined_colorbar
-                else get_combined_levels([np_meshes[i, j]], mesh_property)
+                else get_combined_levels(np_meshes[i, j, None], mesh_property)
             )
             subplot(np_meshes[i, j], mesh_property, np_axs[i, j], _levels)
 
     x_id, y_id = get_projection(
         np_meshes[0, 0]
     )  # One mesh is sufficient, it should be the same for all of them
-    np_axs = label_spatial_axes(np_axs, np.array([x_id, y_id]))
+    label_spatial_axes(np_axs, "xyz"[x_id], "xyz"[y_id])
     np_axs[0, 0].set_title(setup.title_center, loc="center", y=1.02)
     np_axs[0, 0].set_title(setup.title_left, loc="left", y=1.02)
     np_axs[0, 0].set_title(setup.title_right, loc="right", y=1.02)
@@ -481,7 +488,7 @@ def _draw_plot(
                 stacklevel=4,
             )
         else:
-            cb_axs = np.ravel(fig.axes).tolist()
+            cb_axs = np.ravel(np.asarray(fig.axes)).tolist()
             add_colorbars(
                 fig, cb_axs, mesh_property, combined_levels, pad=0.05 / shape[1]
             )
@@ -497,7 +504,7 @@ def _draw_plot(
             for i in range(shape[0]):
                 for j in range(shape[1]):
                     _levels = get_combined_levels(
-                        [np_meshes[i, j]], mesh_property
+                        np_meshes[i, j, None], mesh_property
                     )
                     add_colorbars(fig, np_axs[i, j], mesh_property, _levels)
     return fig
@@ -532,10 +539,14 @@ def update_font_sizes(
         subax_xlim = subax.get_xlim()
         subax_ylim = subax.get_ylim()
         subax.set_xticks(
-            subax.get_xticks(), subax.get_xticklabels(), fontsize=fontsize
+            subax.get_xticks(),
+            [label.get_text() for label in subax.get_xticklabels()],
+            fontsize=fontsize,
         )
         subax.set_yticks(
-            subax.get_yticks(), subax.get_yticklabels(), fontsize=fontsize
+            subax.get_yticks(),
+            [label.get_text() for label in subax.get_yticklabels()],
+            fontsize=fontsize,
         )
         subax.set_xlim(subax_xlim)
         subax.set_ylim(subax_ylim)
@@ -549,7 +560,7 @@ def plot(
     mesh_property: Union[Property, str],
     fig: Optional[mfigure.Figure] = None,
     ax: Optional[plt.Axes] = None,
-) -> mfigure.Figure:
+) -> Optional[mfigure.Figure]:
     """
     Plot the property field of meshes with default settings.
 
@@ -576,18 +587,20 @@ def plot(
     n_axs = shape[0] * shape[1]
     if ax is None and fig is None:
         _fig, _ax = _fig_init(rows=shape[0], cols=shape[1], aspect=fig_aspect)
-        fig = _draw_plot(meshes, mesh_property, fig=_fig, ax=_ax)
+        fig = _draw_plot(meshes, mesh_property, fig=_fig, axes=_ax)
+        assert isinstance(fig, plt.Figure)
         for ax, aspect in zip(fig.axes[: n_axs + 1], ax_aspects):
             ax.set_aspect(1.0 / aspect)
     elif ax is not None and fig is None:
-        _draw_plot(meshes, mesh_property, ax=ax)
+        _draw_plot(meshes, mesh_property, axes=ax)
         ax.set_aspect(1.0 / ax_aspects[0])
     elif ax is None and fig is not None:
         fig = _draw_plot(meshes, mesh_property, fig=fig)
+        assert isinstance(fig, plt.Figure)
         for ax, aspect in zip(fig.axes[: n_axs + 1], ax_aspects):
             ax.set_aspect(1.0 / aspect)
     elif ax is not None and fig is not None:
-        _draw_plot(meshes, mesh_property, fig=fig, ax=ax)
+        _draw_plot(meshes, mesh_property, fig=fig, axes=ax)
         for ax, aspect in zip(fig.axes[: n_axs + 1], ax_aspects):
             ax.set_aspect(1.0 / aspect)
     return fig
@@ -606,8 +619,8 @@ def plot_probe(
     linestyles: Optional[list] = None,
     ax: Optional[plt.Axes] = None,
     fill_between: bool = False,
-    **kwargs,
-) -> mfigure.Figure:
+    **kwargs: Any,
+) -> Optional[mfigure.Figure]:
     """
     Plot the transient property on the observation points in the MeshSeries.
 
diff --git a/ogstools/meshplotlib/plot_features.py b/ogstools/meshplotlib/plot_features.py
index 5b358b30c..9a5ee911c 100644
--- a/ogstools/meshplotlib/plot_features.py
+++ b/ogstools/meshplotlib/plot_features.py
@@ -74,7 +74,9 @@ def plot_element_edges(ax: plt.Axes, surf: pv.DataSet, projection: int) -> None:
         ]
         verts = setup.length.transform(np.delete(cell_pts, projection, -1))
         lw = 0.5 * setup.rcParams_scaled["lines.linewidth"]
-        pc = PolyCollection(verts, fc="None", ec="black", lw=lw)
+        pc = PolyCollection(
+            verts, fc="None", ec="black", lw=lw  # type: ignore[arg-type]
+        )
         ax.add_collection(pc)
 
 
@@ -131,9 +133,9 @@ def plot_streamlines(
     lw = 2.5 * val_norm / max(1e-16, np.max(val_norm))
     lw *= setup.rcParams_scaled["lines.linewidth"]
     x_g, y_g = setup.length.transform(np.meshgrid(x, y))
-    plot_args = [x_g, y_g, val[..., 0], val[..., 1]]
     if plot_type == "streamlines":
-        ax.streamplot(*plot_args, color="k", linewidth=lw, density=1.5)
+        ax.streamplot(x_g, y_g, val[..., 0], val[..., 1],
+                      color="k", linewidth=lw, density=1.5)  # fmt: skip
     else:
         line_args = (
             dict(  # noqa: C408
@@ -142,7 +144,8 @@ def plot_streamlines(
             if plot_type == "lines"
             else {}
         )
-        ax.quiver(*plot_args, **line_args, scale=1 / 0.03)
+        scale = 1.0 / 0.03
+        ax.quiver(x_g, y_g, val[..., 0], val[..., 1], **line_args, scale=scale)
 
 
 def plot_on_top(
@@ -161,7 +164,7 @@ def plot_on_top(
     x_vals = df_pts.groupby("x")["x"].agg(np.mean).to_numpy()
     y_vals = df_pts.groupby("x")["y"].agg(np.max).to_numpy()
     contour_vals = [y + scaling * contour(x) for y, x in zip(y_vals, x_vals)]
-    ax.set_ylim(top=setup.length.transform(np.max(contour_vals)))
+    ax.set_ylim(top=float(setup.length.transform(np.max(contour_vals))))
     ax.fill_between(
         setup.length.transform(x_vals),
         setup.length.transform(y_vals),
@@ -172,7 +175,7 @@ def plot_on_top(
 
 def plot_contour(
     ax: plt.Axes, mesh: pv.DataSet, style: str, lw: int, projection: int = 2
-):
+) -> None:
     contour = mesh.extract_surface().strip(join=True)
     x_id, y_id = np.delete([0, 1, 2], projection)
     x, y = 1e-3 * contour.points[contour.lines[1:]].T[[x_id, y_id]]
diff --git a/ogstools/meshplotlib/plot_setup.py b/ogstools/meshplotlib/plot_setup.py
index d36e7f327..6c933c6b7 100644
--- a/ogstools/meshplotlib/plot_setup.py
+++ b/ogstools/meshplotlib/plot_setup.py
@@ -1,7 +1,7 @@
 """Plot configuration setup."""
 
 from dataclasses import dataclass
-from typing import Union
+from typing import Optional, Union
 
 from ogstools.propertylib.property import Scalar
 
@@ -23,9 +23,9 @@ class PlotSetup:
     "The resolution (dots per inch) for the figure."
     fig_scale: float
     "A scaling factor for the figure."
-    min_ax_aspect: float
+    min_ax_aspect: Optional[float]
     "Minimum aspect ratio of subplots."
-    max_ax_aspect: float
+    max_ax_aspect: Optional[float]
     "Maximum aspect ratio of subplots."
     invert_colorbar: bool
     "A boolean indicating whether to invert the colorbar."
@@ -38,11 +38,11 @@ class PlotSetup:
     num_levels: int
     """The aimed number of levels / bins of the colorbar. See
     :obj:`ogstools.meshplotlib.levels`"""
-    num_streamline_interp_pts: int
+    num_streamline_interp_pts: Optional[int]
     "The number of interpolation points for streamlines."
-    p_max: float
+    p_max: Optional[float]
     "The fixed upper limit for the current scale."
-    p_min: float
+    p_min: Optional[float]
     "The fixed lower limit for the current scale."
     rcParams: dict
     """Matplotlib runtime configuration. See
@@ -78,7 +78,7 @@ class PlotSetup:
         return params
 
     @classmethod
-    def from_dict(cls: type["PlotSetup"], obj: dict):
+    def from_dict(cls: type["PlotSetup"], obj: dict) -> "PlotSetup":
         """Create a PlotSetup instance from a dictionary."""
         return cls(
             fig_scale=obj["fig_scale"],
diff --git a/ogstools/meshplotlib/utils.py b/ogstools/meshplotlib/utils.py
index e31ce760c..8cdc448c9 100644
--- a/ogstools/meshplotlib/utils.py
+++ b/ogstools/meshplotlib/utils.py
@@ -2,12 +2,13 @@ from typing import Optional
 
 import matplotlib.pyplot as plt
 import numpy as np
+from cycler import Cycler
 
 
 def justified_labels(points: np.ndarray) -> list[str]:
     "Formats an array of points to a list of aligned str."
 
-    def fmt(val: float):
+    def fmt(val: float) -> str:
         return f"{val:.2f}".rstrip("0").rstrip(".")
 
     col_lens = np.max(
@@ -24,7 +25,7 @@ def get_style_cycler(
     min_number_of_styles: int,
     colors: Optional[Optional[list]] = None,
     linestyles: Optional[list] = None,
-) -> plt.cycler:
+) -> Cycler:
     if colors is None:
         colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
     if linestyles is None:
diff --git a/ogstools/msh2vtu/__init__.py b/ogstools/msh2vtu/__init__.py
index ac2a24d6e..251124ed3 100644
--- a/ogstools/msh2vtu/__init__.py
+++ b/ogstools/msh2vtu/__init__.py
@@ -1,7 +1,7 @@
 # Author: Dominik Kern (TU Bergakademie Freiberg)
 import logging
 from pathlib import Path
-from typing import Union
+from typing import Any, Union
 
 import meshio
 import numpy as np
@@ -18,7 +18,7 @@ __all__ = [
 ]
 
 
-def my_remove_orphaned_nodes(my_mesh: meshio.Mesh):
+def my_remove_orphaned_nodes(my_mesh: meshio.Mesh) -> None:
     """Auxiliary function to remove points not belonging to any cell"""
 
     # find connected points and derive mapping from all points to them
@@ -51,11 +51,9 @@ def my_remove_orphaned_nodes(my_mesh: meshio.Mesh):
     # cell data are not affected by point changes
     my_mesh.cells = output_cell_blocks  # update cells
 
-    return
-
 
 # print info for mesh: statistics and data field names
-def print_info(mesh: meshio.Mesh):
+def print_info(mesh: meshio.Mesh) -> None:
     N, D = mesh.points.shape
     logging.info("%d points in %d dimensions", N, D)
     cell_info = "cells: "
@@ -70,13 +68,15 @@ def print_info(mesh: meshio.Mesh):
 
 # function to create node connectivity list, i.e. store for each node (point) to
 # which element (cell) it belongs
-def find_cells_at_nodes(cells, node_count, cell_start_index):
+def find_cells_at_nodes(
+    cells: Any, node_count: int, cell_start_index: int
+) -> list[set]:
     # depending on the numbering of mixed meshes in OGS one may think of an
     # object-oriented way to add elements (of different type) to node
     # connectivity
 
     # initialize list of sets
-    node_connectivity = [set() for _ in range(node_count)]
+    node_connectivity: list[set] = [set() for _ in range(node_count)]
     cell_index = cell_start_index
     for cell in cells:
         for node in cell:
@@ -95,7 +95,7 @@ def find_cells_at_nodes(cells, node_count, cell_start_index):
 
 # function to find out to which domain elements a boundary element belongs
 def find_connected_domain_cells(
-    boundary_cells_values, domain_cells_at_node
+    boundary_cells_values: Any, domain_cells_at_node: list[set[int]]
 ) -> tuple[np.ndarray, np.ndarray]:
     warned_gt1 = False  # to avoid flood of warnings
     warned_lt1 = False  # to avoid flood of warnings
@@ -160,7 +160,7 @@ def msh2vtu(
     keep_ids: bool = False,
     ascii: bool = False,
     log_level: Union[int, str] = "DEBUG",
-):
+) -> int:
     """
     Convert a gmsh mesh (.msh) to an unstructured grid file (.vtu).
 
diff --git a/ogstools/propertylib/property.py b/ogstools/propertylib/property.py
index 26a2381ba..0ee376e00 100644
--- a/ogstools/propertylib/property.py
+++ b/ogstools/propertylib/property.py
@@ -7,7 +7,7 @@ via pint.
 
 from collections.abc import Sequence
 from dataclasses import dataclass, replace
-from typing import Callable, Union
+from typing import Any, Callable, Union
 
 import numpy as np
 import pyvista as pv
@@ -46,15 +46,15 @@ class Property:
     categoric: bool = False
     """Does this property only have categoric values?"""
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         if not self.output_name:
             self.output_name = self.data_name
 
     @property
-    def type_name(self):
+    def type_name(self) -> str:
         return type(self).__name__
 
-    def replace(self, **changes):
+    def replace(self: "Property", **changes: Any) -> "Property":
         """
         Create a new Property object with modified attributes.
 
@@ -68,7 +68,9 @@ class Property:
         return replace(self, **changes)
 
     @classmethod
-    def from_property(cls, new_property: "Property", **changes):
+    def from_property(  # type: ignore[no-untyped-def]
+        cls, new_property: "Property", **changes: Any
+    ):
         "Create a new Property object with modified attributes."
         return cls(
             data_name=new_property.data_name,
@@ -151,7 +153,7 @@ class Property:
         """
         return self.data_name == self.mask
 
-    def get_mask(self):
+    def get_mask(self) -> "Property":
         """
         :returns: A property representing this properties mask.
         """
diff --git a/ogstools/studies/convergence/convergence.py b/ogstools/studies/convergence/convergence.py
index 9cc39ac75..c5282ecfa 100644
--- a/ogstools/studies/convergence/convergence.py
+++ b/ogstools/studies/convergence/convergence.py
@@ -148,7 +148,7 @@ def convergence_metrics(
     :returns:           A pandas Dataframe containing all metrics.
     """
 
-    def _data(m: pv.UnstructuredGrid):
+    def _data(m: pv.UnstructuredGrid) -> np.ndarray:
         return mesh_property.magnitude.transform(
             m.point_data[mesh_property.data_name]
         )
@@ -304,7 +304,9 @@ def convergence_metrics_evolution(
         .to(units[1])
         .magnitude
     )
-    p_metrics_per_t = np.concatenate(([time_vals], p_metrics_per_t.T)).T
+    p_metrics_per_t = np.concatenate(
+        (np.asarray([time_vals]), p_metrics_per_t.T)
+    ).T
     columns = ["timevalue"] + [
         f"{t} ({x})"
         for t in ["abs. error", "rel. error", "p"]
diff --git a/ogstools/studies/convergence/examples/steady_state_diffusion.py b/ogstools/studies/convergence/examples/steady_state_diffusion.py
index 88be208ea..aba217301 100644
--- a/ogstools/studies/convergence/examples/steady_state_diffusion.py
+++ b/ogstools/studies/convergence/examples/steady_state_diffusion.py
@@ -6,15 +6,15 @@ import numpy as np
 import pyvista as pv
 
 
-def _c_k(k):
+def _c_k(k: float) -> float:
     return 0.5 * (2 * k - 1) * np.pi
 
 
-def _a_k(k):
+def _a_k(k: float) -> float:
     return 2 / (_c_k(k) ** 2 * np.cosh(_c_k(k)))
 
 
-def _h(points):
+def _h(points: np.ndarray) -> np.ndarray:
     result = np.ones(len(points))
     for k in np.arange(1, 100):
         c_k_val = _c_k(k)
-- 
GitLab