From 6049c1d6c7839bbe5ef30175630a07986d1e1d1d Mon Sep 17 00:00:00 2001
From: Tobias Meisel <tobias.meisel@ufz.de>
Date: Wed, 27 Mar 2024 14:02:50 +0100
Subject: [PATCH] [logparser] Extend type annotations

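Adds explicit type hints to the pre_post_check decorator in
common_ogs_analyses.py (Callable[[Any], Any] for the wrapped analysis
functions, pd.DataFrame return types for the analyses) and to the
per-line match helpers in log_parser.py (type[Log] pattern classes,
Optional results). Below is a minimal, illustrative sketch of how the
annotated decorator is applied; the _check_input/_check_output calls are
elided, and the analysis function and column names are made up for the
example only:

    from typing import Any, Callable

    import pandas as pd


    def pre_post_check(interest: list[str], context: list[str]) -> Callable:
        def wrap(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
            def wrapped_f(df: Any) -> Any:
                # the real decorator checks 'interest'/'context' columns
                # before and after calling the analysis
                return func(df)

            return wrapped_f

        return wrap


    @pre_post_check(interest=["dx"], context=["time_step"])
    def my_analysis(df: pd.DataFrame) -> pd.DataFrame:
        # hypothetical analysis: aggregate 'dx' per 'time_step'
        return df.pivot_table(["dx"], ["time_step"])


    print(my_analysis(pd.DataFrame({"dx": [0.1, 0.2], "time_step": [1, 2]})))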
---
 ogstools/logparser/common_ogs_analyses.py | 18 +++++++++---------
 ogstools/logparser/log_parser.py          | 10 +++++-----
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/ogstools/logparser/common_ogs_analyses.py b/ogstools/logparser/common_ogs_analyses.py
index c6608c731..1ae1bd007 100644
--- a/ogstools/logparser/common_ogs_analyses.py
+++ b/ogstools/logparser/common_ogs_analyses.py
@@ -4,7 +4,7 @@
 #              http://www.opengeosys.org/project/license
 
 
-from typing import Callable
+from typing import Any, Callable
 
 import numpy as np
 import pandas as pd
@@ -55,8 +55,8 @@ def pre_post_check(interest: list[str], context: list[str]) -> Callable:
       based on the specified 'interest' and 'context'.
     """
 
-    def wrap(func):
-        def wrapped_f(df):
+    def wrap(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
+        def wrapped_f(df: Any) -> Any:
             _check_input(df, interest, context)
             pt = func(df)
             _check_output(pt, interest, context)
@@ -67,7 +67,7 @@ def pre_post_check(interest: list[str], context: list[str]) -> Callable:
     return wrap
 
 
-def analysis_time_step(df: pd.DataFrame):
+def analysis_time_step(df: pd.DataFrame) -> pd.DataFrame:
     """
     Analysis focusing on computation time per time step. It combines the time-step-specific measurements 'output time'
     and 'time step solution time' with the iteration-specific measurements 'assembly time', 'linear solver time' and 'Dirichlet time'.
@@ -88,7 +88,7 @@ def analysis_time_step(df: pd.DataFrame):
     return dfe
 
 
-def analysis_simulation(df: pd.DataFrame):
+def analysis_simulation(df: pd.DataFrame) -> pd.DataFrame:
     interest = ["execution_time"]  # 'start_time'
     context = ["mpi_process"]
     _check_input(df, interest, context)
@@ -98,7 +98,7 @@ def analysis_simulation(df: pd.DataFrame):
     return pt
 
 
-def analysis_convergence_newton_iteration(df: pd.DataFrame):
+def analysis_convergence_newton_iteration(df: pd.DataFrame) -> pd.DataFrame:
     dfe_newton_iteration = df.copy()
     interest = ["dx", "x", "dx_x"]
     if "coupling_iteration" in df:
@@ -141,7 +141,7 @@ def analysis_convergence_newton_iteration(df: pd.DataFrame):
     interest=["dx", "x", "dx_x"],
     context=["time_step", "coupling_iteration", "coupling_iteration_process"],
 )
-def analysis_convergence_coupling_iteration(df: pd.DataFrame):
+def analysis_convergence_coupling_iteration(df: pd.DataFrame) -> pd.DataFrame:
     # The coupling iteration column is modified specifically for the coupling iteration analysis; the modified data cannot be used for other analyses -> copy!
     dfe_convergence_coupling_iteration = df.copy()
     interest = ["dx", "x", "dx_x"]
@@ -169,7 +169,7 @@ def analysis_convergence_coupling_iteration(df: pd.DataFrame):
     return pt
 
 
-def time_step_vs_iterations(df: pd.DataFrame):
+def time_step_vs_iterations(df: pd.DataFrame) -> pd.DataFrame:
     interest = ["iteration_number"]
     context = ["time_step"]
     _check_input(df, interest, context)
@@ -178,7 +178,7 @@ def time_step_vs_iterations(df: pd.DataFrame):
     return pt
 
 
-def analysis_simulation_termination(df: pd.DataFrame):
+def analysis_simulation_termination(df: pd.DataFrame) -> pd.DataFrame:
     # To print messages in full, consider configuring the Jupyter notebook with:
     # pd.set_option('display.max_colwidth', None)
     interest = ["message"]
diff --git a/ogstools/logparser/log_parser.py b/ogstools/logparser/log_parser.py
index 299ddf912..f0f1e18f2 100644
--- a/ogstools/logparser/log_parser.py
+++ b/ogstools/logparser/log_parser.py
@@ -7,12 +7,12 @@ import re
 from pathlib import Path
 from typing import Any, Callable, Optional, Union
 
-from ogstools.logparser.ogs_regexes import ogs_regexes
+from ogstools.logparser.regexes import Log, ogs_regexes
 
 
 def _try_match_parallel_line(
-    line: str, line_nr: int, regex: re.Pattern, pattern_class
-):
+    line: str, line_nr: int, regex: re.Pattern, pattern_class: type[Log]
+) -> Optional[Any]:
     if match := regex.match(line):
         # Line , Process, Type specific
         ts = pattern_class.type_str()
@@ -30,8 +30,8 @@ def _try_match_parallel_line(
 
 
 def _try_match_serial_line(
-    line: str, line_nr: int, regex: re.Pattern, pattern_class
-):
+    line: str, line_nr: int, regex: re.Pattern, pattern_class: type[Log]
+) -> Optional[list[tuple[str, Log]]]:
     if match := regex.match(line):
         # Line , Process, Type specific
         ts = pattern_class.type_str()
-- 
GitLab