From a1ec240475077e9da1122efbd05689fe827dff3c Mon Sep 17 00:00:00 2001
From: Wenqing Wang <wenqing.wang@ufz.de>
Date: Thu, 23 Nov 2023 17:13:21 +0100
Subject: [PATCH] [TimeLoop] Use the new class StaggeredCoupling

---
 ProcessLib/TimeLoop.cpp | 168 ++++------------------------------------
 ProcessLib/TimeLoop.h   |  26 +------
 2 files changed, 18 insertions(+), 176 deletions(-)

diff --git a/ProcessLib/TimeLoop.cpp b/ProcessLib/TimeLoop.cpp
index b687c02f3d6..fb97d6f59a6 100644
--- a/ProcessLib/TimeLoop.cpp
+++ b/ProcessLib/TimeLoop.cpp
@@ -12,7 +12,6 @@
 
 #include <range/v3/algorithm/any_of.hpp>
 #include <range/v3/algorithm/contains.hpp>
-#include <range/v3/view/map.hpp>
 
 #include "BaseLib/Error.h"
 #include "BaseLib/RunTime.h"
@@ -20,15 +19,11 @@
 #include "NumLib/ODESolver/ConvergenceCriterionPerComponent.h"
 #include "NumLib/ODESolver/PETScNonlinearSolver.h"
 #include "NumLib/ODESolver/TimeDiscretizedODESystem.h"
+#include "NumLib/StaggeredCoupling/StaggeredCoupling.h"
 #include "ProcessData.h"
 
 namespace
 {
-bool isMonolithicProcess(ProcessLib::ProcessData const& process_data)
-{
-    return process_data.process.isMonolithicSchemeUsed();
-}
-
 void updateDeactivatedSubdomains(
     std::vector<std::unique_ptr<ProcessLib::ProcessData>> const&
         per_process_data,
@@ -268,38 +263,19 @@ NumLib::NonlinearSolverStatus solveOneTimeStepOneProcess(
     return nonlinear_solver_status;
 }
 
-TimeLoop::TimeLoop(std::vector<Output>&& outputs,
-                   std::vector<std::unique_ptr<ProcessData>>&& per_process_data,
-                   const int global_coupling_max_iterations,
-                   std::vector<std::unique_ptr<NumLib::ConvergenceCriterion>>&&
-                       global_coupling_conv_crit,
-                   std::map<std::string, int>&& local_coupling_processes,
-                   const double start_time, const double end_time)
+TimeLoop::TimeLoop(
+    std::vector<Output>&& outputs,
+    std::vector<std::unique_ptr<ProcessData>>&& per_process_data,
+    std::unique_ptr<NumLib::StaggeredCoupling>&& staggered_coupling,
+    const double start_time, const double end_time)
     : _outputs{std::move(outputs)},
       _per_process_data(std::move(per_process_data)),
       _start_time(start_time),
       _end_time(end_time),
-      _global_coupling_max_iterations(global_coupling_max_iterations),
-      _global_coupling_conv_crit(std::move(global_coupling_conv_crit)),
-      _local_coupling_processes(std::move(local_coupling_processes))
+      _staggered_coupling(std::move(staggered_coupling))
 {
 }
 
-void TimeLoop::setCoupledSolutions()
-{
-    for (auto const& process_data : _per_process_data)
-    {
-        auto const& x = *_process_solutions[process_data->process_id];
-
-        // Create a vector to store the solution of the last coupling iteration
-        auto& x0 = NumLib::GlobalVectorProvider::provider.getVector(x);
-        MathLib::LinAlg::copy(x, x0);
-
-        // append a solution vector of suitable size
-        _solutions_of_last_cpl_iteration.emplace_back(&x0);
-    }
-}
-
 bool computationOfChangeNeeded(
     NumLib::TimeStepAlgorithm const& timestep_algorithm, double const time)
 {
@@ -512,13 +488,9 @@ void TimeLoop::initialize()
     std::tie(_process_solutions, _process_solutions_prev) =
         setInitialConditions(_start_time, _per_process_data);
 
-    // All _per_process_data share the first process.
-    bool const is_staggered_coupling =
-        !isMonolithicProcess(*_per_process_data[0]);
-
-    if (is_staggered_coupling)
+    if (_staggered_coupling)
     {
-        setCoupledSolutions();
+        _staggered_coupling->initializeCoupledSolutions(_process_solutions);
     }
 
     updateDeactivatedSubdomains(_per_process_data, _start_time);
@@ -622,12 +594,9 @@ bool TimeLoop::preTsNonlinearSolvePostTs(double const t, double const dt,
 {
     preTimestepForAllProcesses(t, dt, _per_process_data, _process_solutions);
 
-    // All _per_process_data share the first process.
-    bool const is_staggered_coupling =
-        !isMonolithicProcess(*_per_process_data[0]);
     NumLib::NonlinearSolverStatus nonlinear_solver_status;
 
-    if (is_staggered_coupling)
+    if (_staggered_coupling)
     {
         nonlinear_solver_status =
             solveCoupledEquationSystemsByStaggeredScheme(t, dt, timesteps);
@@ -723,114 +692,12 @@ NumLib::NonlinearSolverStatus
 TimeLoop::solveCoupledEquationSystemsByStaggeredScheme(
     const double t, const double dt, const std::size_t timestep_id)
 {
-    // Coupling iteration
-    if (_global_coupling_max_iterations != 0)
-    {
-        // Set the flag of the first iteration be true.
-        for (auto& conv_crit : _global_coupling_conv_crit)
-        {
-            conv_crit->preFirstIteration();
-        }
-    }
-    auto resetCouplingConvergenceCriteria = [&]()
-    {
-        for (auto& conv_crit : _global_coupling_conv_crit)
-        {
-            conv_crit->reset();
-        }
-    };
-
-    NumLib::NonlinearSolverStatus nonlinear_solver_status{false, -1};
-    bool coupling_iteration_converged = true;
-    bool local_coupling_iteration_converged = true;
-    for (int global_coupling_iteration = 0;
-         global_coupling_iteration < _global_coupling_max_iterations;
-         global_coupling_iteration++, resetCouplingConvergenceCriteria())
-    {
-        // TODO(wenqing): use process name
-        coupling_iteration_converged = true;
-        bool local_iteration_converged = true;
-        for (auto const& process_data : _per_process_data)
-        {
-            auto const process_id = process_data->process_id;
-
-            bool const isLocalCouplingProcess = ranges::contains(
-                _local_coupling_processes | ranges::views::values, process_id);
-
-            if (!local_coupling_iteration_converged && !isLocalCouplingProcess)
-            {
-                coupling_iteration_converged = false;
-                continue;
-            }
-
-            BaseLib::RunTime time_timestep_process;
-            time_timestep_process.start();
-
-            nonlinear_solver_status = solveOneTimeStepOneProcess(
-                _process_solutions, _process_solutions_prev, timestep_id, t, dt,
-                *process_data, _outputs);
-            process_data->nonlinear_solver_status = nonlinear_solver_status;
-
-            INFO(
-                "[time] Solving process #{:d} took {:g} s in time step #{:d}  "
-                "coupling iteration #{:d}",
-                process_id, time_timestep_process.elapsed(), timestep_id,
-                global_coupling_iteration);
-
-            if (!nonlinear_solver_status.error_norms_met)
-            {
-                WARN(
-                    "The nonlinear solver failed in time step #{:d} at t = "
-                    "{:g} s for process #{:d}.",
-                    timestep_id, t, process_id);
-                _last_step_rejected = true;
-                return nonlinear_solver_status;
-            }
-
-            // Check the convergence of the coupling iteration
-            auto& x = *_process_solutions[process_id];
-            auto& x_old = *_solutions_of_last_cpl_iteration[process_id];
-            if (global_coupling_iteration > 0)
-            {
-                MathLib::LinAlg::axpy(x_old, -1.0, x);  // save dx to x_old
-                INFO(
-                    "------- Checking convergence criterion for coupled "
-                    "solution of process #{:d} -------",
-                    process_id);
-                _global_coupling_conv_crit[process_id]->checkDeltaX(x_old, x);
-                coupling_iteration_converged =
-                    coupling_iteration_converged &&
-                    _global_coupling_conv_crit[process_id]->isSatisfied();
-                if (isLocalCouplingProcess)
-                {
-                    local_iteration_converged =
-                        local_iteration_converged &&
-                        _global_coupling_conv_crit[process_id]->isSatisfied();
-                }
-            }
-            MathLib::LinAlg::copy(x, x_old);
-        }  // end of for (auto& process_data : _per_process_data)
-
-        local_coupling_iteration_converged = local_iteration_converged;
-        if (local_coupling_iteration_converged &&
-            coupling_iteration_converged && global_coupling_iteration > 0)
-        {
-            break;
-        }
-
-        if (!nonlinear_solver_status.error_norms_met)
-        {
-            return nonlinear_solver_status;
-        }
-    }
+    auto const nonlinear_solver_status =
+        _staggered_coupling->execute<ProcessData, Output>(
+            t, dt, timestep_id, _process_solutions, _process_solutions_prev,
+            _per_process_data, _outputs, &solveOneTimeStepOneProcess);
 
-    if (!coupling_iteration_converged || !local_coupling_iteration_converged)
-    {
-        WARN(
-            "The coupling iterations reaches its maximum number in time step "
-            "#{:d} at t = {:g} s",
-            timestep_id, t);
-    }
+    _last_step_rejected = nonlinear_solver_status.error_norms_met;
 
     {
         for (auto const& process_data : _per_process_data)
@@ -883,11 +750,6 @@ TimeLoop::~TimeLoop()
     {
         NumLib::GlobalVectorProvider::provider.releaseVector(*x);
     }
-
-    for (auto* x : _solutions_of_last_cpl_iteration)
-    {
-        NumLib::GlobalVectorProvider::provider.releaseVector(*x);
-    }
 }
 
 void TimeLoop::preOutputInitialConditions(const double t) const
diff --git a/ProcessLib/TimeLoop.h b/ProcessLib/TimeLoop.h
index ae6da8a59e5..43c5c112529 100644
--- a/ProcessLib/TimeLoop.h
+++ b/ProcessLib/TimeLoop.h
@@ -21,6 +21,7 @@
 namespace NumLib
 {
 class ConvergenceCriterion;
+class StaggeredCoupling;
 }
 
 namespace ChemistryLib
@@ -38,10 +39,7 @@ class TimeLoop
 public:
     TimeLoop(std::vector<Output>&& outputs,
              std::vector<std::unique_ptr<ProcessData>>&& per_process_data,
-             const int global_coupling_max_iterations,
-             std::vector<std::unique_ptr<NumLib::ConvergenceCriterion>>&&
-                 global_coupling_conv_crit,
-             std::map<std::string, int>&& local_coupling_processes,
+             std::unique_ptr<NumLib::StaggeredCoupling>&& staggered_coupling,
              const double start_time, const double end_time);
 
     void initialize();
@@ -67,13 +65,6 @@ public:
 private:
     bool preTsNonlinearSolvePostTs(double const t, double const dt,
                                    std::size_t const timesteps);
-    /**
-     * This function fills the vector of solutions of coupled processes of
-     * processes, _solutions_of_coupled_processes, and initializes the vector
-     * of solutions of the previous coupling iteration,
-     * _solutions_of_last_cpl_iteration.
-     */
-    void setCoupledSolutions();
 
     /**
      * \brief Member to solver non coupled systems of equations, which can be
@@ -146,17 +137,6 @@ private:
     int _repeating_times_of_rejected_step = 0;
     bool _last_step_rejected = false;
 
-    /// Maximum iterations of the global coupling.
-    const int _global_coupling_max_iterations;
-    /// Convergence criteria of processes for the global coupling iterations.
-    std::vector<std::unique_ptr<NumLib::ConvergenceCriterion>>
-        _global_coupling_conv_crit;
-
-    /// Processes that will be solved in a local iteration.
-    std::map<std::string, int> _local_coupling_processes;
-
-    /// Solutions of the previous coupling iteration for the convergence
-    /// criteria of the coupling iteration.
-    std::vector<GlobalVector*> _solutions_of_last_cpl_iteration;
+    std::unique_ptr<NumLib::StaggeredCoupling> _staggered_coupling;
 };
 }  // namespace ProcessLib
-- 
GitLab