From 6d6a01b53f4b08071a1b8e7cfd06d0ec95b692c1 Mon Sep 17 00:00:00 2001
From: Wenqing Wang <wenqing.wang@ufz.de>
Date: Tue, 28 Nov 2017 17:04:00 +0100
Subject: [PATCH] [ODE] Added an argument of equation_id to
 MathLib::getMatrixSpecifications

---
 NumLib/DOF/MatrixProviderUser.h               |  3 ++-
 NumLib/ODESolver/TimeDiscretizedODESystem.cpp | 20 +++++++++--------
 NumLib/ODESolver/TimeDiscretizedODESystem.h   | 22 ++++++++++---------
 ProcessLib/Process.cpp                        |  3 ++-
 ProcessLib/Process.h                          |  6 ++---
 ProcessLib/UncoupledProcessesTimeLoop.cpp     | 10 +++++----
 Tests/NumLib/ODEs.h                           | 18 +++++++++------
 Tests/NumLib/TestODEInt.cpp                   |  5 +++--
 8 files changed, 50 insertions(+), 37 deletions(-)

diff --git a/NumLib/DOF/MatrixProviderUser.h b/NumLib/DOF/MatrixProviderUser.h
index 4be6e8efe07..c0d3c9cccdc 100644
--- a/NumLib/DOF/MatrixProviderUser.h
+++ b/NumLib/DOF/MatrixProviderUser.h
@@ -19,7 +19,8 @@ namespace NumLib
 class MatrixSpecificationsProvider
 {
 public:
-    virtual MathLib::MatrixSpecifications getMatrixSpecifications() const = 0;
+    virtual MathLib::MatrixSpecifications getMatrixSpecifications(
+        const int equation_id) const = 0;
 
     virtual ~MatrixSpecificationsProvider() = default;
 };
diff --git a/NumLib/ODESolver/TimeDiscretizedODESystem.cpp b/NumLib/ODESolver/TimeDiscretizedODESystem.cpp
index aa4b0818444..eaad9a12bf8 100644
--- a/NumLib/ODESolver/TimeDiscretizedODESystem.cpp
+++ b/NumLib/ODESolver/TimeDiscretizedODESystem.cpp
@@ -40,19 +40,20 @@ namespace NumLib
 {
 TimeDiscretizedODESystem<ODESystemTag::FirstOrderImplicitQuasilinear,
                          NonlinearSolverTag::Newton>::
-    TimeDiscretizedODESystem(ODE& ode, TimeDisc& time_discretization)
+    TimeDiscretizedODESystem(const int equation_id,
+                             ODE& ode, TimeDisc& time_discretization)
     : _ode(ode),
       _time_disc(time_discretization),
       _mat_trans(createMatrixTranslator<ODETag>(time_discretization))
 {
     _Jac = &NumLib::GlobalMatrixProvider::provider.getMatrix(
-        _ode.getMatrixSpecifications(), _Jac_id);
+        _ode.getMatrixSpecifications(equation_id), _Jac_id);
     _M = &NumLib::GlobalMatrixProvider::provider.getMatrix(
-        _ode.getMatrixSpecifications(), _M_id);
+        _ode.getMatrixSpecifications(equation_id), _M_id);
     _K = &NumLib::GlobalMatrixProvider::provider.getMatrix(
-        _ode.getMatrixSpecifications(), _K_id);
+        _ode.getMatrixSpecifications(equation_id), _K_id);
     _b = &NumLib::GlobalVectorProvider::provider.getVector(
-        _ode.getMatrixSpecifications(), _b_id);
+        _ode.getMatrixSpecifications(equation_id), _b_id);
 }
 
 TimeDiscretizedODESystem<
@@ -152,17 +153,18 @@ void TimeDiscretizedODESystem<ODESystemTag::FirstOrderImplicitQuasilinear,
 
 TimeDiscretizedODESystem<ODESystemTag::FirstOrderImplicitQuasilinear,
                          NonlinearSolverTag::Picard>::
-    TimeDiscretizedODESystem(ODE& ode, TimeDisc& time_discretization)
+    TimeDiscretizedODESystem(const int equation_id, ODE& ode,
+                             TimeDisc& time_discretization)
     : _ode(ode),
       _time_disc(time_discretization),
       _mat_trans(createMatrixTranslator<ODETag>(time_discretization))
 {
     _M = &NumLib::GlobalMatrixProvider::provider.getMatrix(
-        ode.getMatrixSpecifications(), _M_id);
+        ode.getMatrixSpecifications(equation_id), _M_id);
     _K = &NumLib::GlobalMatrixProvider::provider.getMatrix(
-        ode.getMatrixSpecifications(), _K_id);
+        ode.getMatrixSpecifications(equation_id), _K_id);
     _b = &NumLib::GlobalVectorProvider::provider.getVector(
-        ode.getMatrixSpecifications(), _b_id);
+        ode.getMatrixSpecifications(equation_id), _b_id);
 }
 
 TimeDiscretizedODESystem<
diff --git a/NumLib/ODESolver/TimeDiscretizedODESystem.h b/NumLib/ODESolver/TimeDiscretizedODESystem.h
index a17cb034bef..9a9d351adb5 100644
--- a/NumLib/ODESolver/TimeDiscretizedODESystem.h
+++ b/NumLib/ODESolver/TimeDiscretizedODESystem.h
@@ -76,12 +76,12 @@ public:
      * \param ode the ODE to be wrapped.
      * \param time_discretization the time discretization to be used.
      */
-    explicit TimeDiscretizedODESystem(ODE& ode, TimeDisc& time_discretization);
+    explicit TimeDiscretizedODESystem(const int equation_id, ODE& ode,
+                                      TimeDisc& time_discretization);
 
     ~TimeDiscretizedODESystem() override;
 
-    void assemble(const GlobalVector& x_new_timestep)
-                  override;
+    void assemble(const GlobalVector& x_new_timestep) override;
 
     void getResidual(GlobalVector const& x_new_timestep,
                      GlobalVector& res) const override;
@@ -114,9 +114,10 @@ public:
     }
 
     TimeDisc& getTimeDiscretization() override { return _time_disc; }
-    MathLib::MatrixSpecifications getMatrixSpecifications() const override
+    MathLib::MatrixSpecifications getMatrixSpecifications(
+        const int equation_id) const override
     {
-        return _ode.getMatrixSpecifications();
+        return _ode.getMatrixSpecifications(equation_id);
     }
 
 private:
@@ -169,12 +170,12 @@ public:
      * \param ode the ODE to be wrapped.
      * \param time_discretization the time discretization to be used.
      */
-    explicit TimeDiscretizedODESystem(ODE& ode, TimeDisc& time_discretization);
+    explicit TimeDiscretizedODESystem(const int equation_id, ODE& ode,
+                                      TimeDisc& time_discretization);
 
     ~TimeDiscretizedODESystem() override;
 
-    void assemble(const GlobalVector& x_new_timestep)
-                  override;
+    void assemble(const GlobalVector& x_new_timestep) override;
 
     void getA(GlobalMatrix& A) const override
     {
@@ -212,9 +213,10 @@ public:
     }
 
     TimeDisc& getTimeDiscretization() override { return _time_disc; }
-    MathLib::MatrixSpecifications getMatrixSpecifications() const override
+    MathLib::MatrixSpecifications getMatrixSpecifications(
+        const int equation_id) const override
     {
-        return _ode.getMatrixSpecifications();
+        return _ode.getMatrixSpecifications(equation_id);
     }
 
 private:
diff --git a/ProcessLib/Process.cpp b/ProcessLib/Process.cpp
index e062d076324..53c7345ef8d 100644
--- a/ProcessLib/Process.cpp
+++ b/ProcessLib/Process.cpp
@@ -150,7 +150,8 @@ void Process::setInitialConditions(const unsigned pcs_id, double const t,
     }
 }
 
-MathLib::MatrixSpecifications Process::getMatrixSpecifications() const
+MathLib::MatrixSpecifications Process::getMatrixSpecifications(
+    const int /*equation_id*/) const
 {
     auto const& l = *_local_to_global_index_map;
     return {l.dofSizeWithoutGhosts(), l.dofSizeWithoutGhosts(),
diff --git a/ProcessLib/Process.h b/ProcessLib/Process.h
index 44bbe15c6f2..ab3c2c09e9a 100644
--- a/ProcessLib/Process.h
+++ b/ProcessLib/Process.h
@@ -71,13 +71,13 @@ public:
     void setInitialConditions(const unsigned pcs_id, const double t,
                               GlobalVector& x);
 
-    MathLib::MatrixSpecifications getMatrixSpecifications() const final;
+    virtual MathLib::MatrixSpecifications getMatrixSpecifications(
+        const int equation_id) const override;
 
     void setCoupledSolutionsForStaggeredScheme(
         CoupledSolutionsForStaggeredScheme* const coupled_solutions)
     {
         _coupled_solutions = coupled_solutions;
-
     }
 
     bool isMonolithicSchemeUsed() const { return _use_monolithic_scheme; }
@@ -236,9 +236,9 @@ protected:
     /// implemented in MathLib::GaussLegendre.
     unsigned const _integration_order;
 
-private:
     GlobalSparsityPattern _sparsity_pattern;
 
+private:
     /// Variables used by this process.  For the monolithic scheme or a
     /// single process, the size of the outer vector is one. For the
     /// staggered scheme, the size of the outer vector is the number of the
diff --git a/ProcessLib/UncoupledProcessesTimeLoop.cpp b/ProcessLib/UncoupledProcessesTimeLoop.cpp
index 6532c3d7480..4ba93cc8be1 100644
--- a/ProcessLib/UncoupledProcessesTimeLoop.cpp
+++ b/ProcessLib/UncoupledProcessesTimeLoop.cpp
@@ -126,6 +126,8 @@ struct SingleProcessData
     //! cast of \c tdisc_ode_sys to NumLib::InternalMatrixStorage
     NumLib::InternalMatrixStorage* mat_strg = nullptr;
 
+    int equation_id = 0;
+
     Process& process;
     ProcessOutput process_output;
 };
@@ -183,7 +185,7 @@ void setTimeDiscretizedODESystem(
 
         spd.tdisc_ode_sys = std::make_unique<
             NumLib::TimeDiscretizedODESystem<ODETag, Tag::Picard>>(
-            ode_sys, *spd.time_disc);
+            spd.equation_id, ode_sys, *spd.time_disc);
     }
     else if (dynamic_cast<NonlinearSolverNewton*>(&spd.nonlinear_solver))
     {
@@ -194,7 +196,7 @@ void setTimeDiscretizedODESystem(
         {
             spd.tdisc_ode_sys = std::make_unique<
                 NumLib::TimeDiscretizedODESystem<ODETag, Tag::Newton>>(
-                *ode_newton, *spd.time_disc);
+                spd.equation_id, *ode_newton, *spd.time_disc);
         }
         else
         {
@@ -368,7 +370,7 @@ std::vector<GlobalVector*> setInitialConditions(
 {
     std::vector<GlobalVector*> process_solutions;
 
-    unsigned pcs_idx = 0;
+    int pcs_idx = 0;
     for (auto& spd : per_process_data)
     {
         auto& pcs = spd->process;
@@ -380,7 +382,7 @@ std::vector<GlobalVector*> setInitialConditions(
         // append a solution vector of suitable size
         process_solutions.emplace_back(
             &NumLib::GlobalVectorProvider::provider.getVector(
-                ode_sys.getMatrixSpecifications()));
+                ode_sys.getMatrixSpecifications(pcs_idx)));
 
         auto& x0 = *process_solutions[pcs_idx];
         pcs.setInitialConditions(pcs_idx, t0, x0);
diff --git a/Tests/NumLib/ODEs.h b/Tests/NumLib/ODEs.h
index 099e7150ace..b270f5241de 100644
--- a/Tests/NumLib/ODEs.h
+++ b/Tests/NumLib/ODEs.h
@@ -38,9 +38,9 @@ public:
     }
 
     void assembleWithJacobian(const double t, GlobalVector const& x_curr,
-                              GlobalVector const& /*xdot*/, const double dxdot_dx,
-                              const double dx_dx, GlobalMatrix& M,
-                              GlobalMatrix& K, GlobalVector& b,
+                              GlobalVector const& /*xdot*/,
+                              const double dxdot_dx, const double dx_dx,
+                              GlobalMatrix& M, GlobalMatrix& K, GlobalVector& b,
                               GlobalMatrix& Jac) override
     {
         namespace LinAlg = MathLib::LinAlg;
@@ -51,13 +51,15 @@ public:
         LinAlg::finalizeAssembly(M);
         LinAlg::copy(M, Jac);
         LinAlg::scale(Jac, dxdot_dx);
-        if (dx_dx != 0.0) {
+        if (dx_dx != 0.0)
+        {
             LinAlg::finalizeAssembly(K);
             LinAlg::axpy(Jac, dx_dx, K);
         }
     }
 
-    MathLib::MatrixSpecifications getMatrixSpecifications() const override
+    MathLib::MatrixSpecifications getMatrixSpecifications(
+        const int /*equation_id*/) const override
     {
         return { N, N, nullptr, nullptr };
     }
@@ -140,7 +142,8 @@ public:
         }
     }
 
-    MathLib::MatrixSpecifications getMatrixSpecifications() const override
+    MathLib::MatrixSpecifications getMatrixSpecifications(
+        const int /*equation_id*/) const override
     {
         return { N, N, nullptr, nullptr };
     }
@@ -266,7 +269,8 @@ public:
         // INFO("Det J: %e <<<", J.determinant());
     }
 
-    MathLib::MatrixSpecifications getMatrixSpecifications() const override
+    MathLib::MatrixSpecifications getMatrixSpecifications(
+        const int /*equation_id*/) const override
     {
         return { N, N, nullptr, nullptr };
     }
diff --git a/Tests/NumLib/TestODEInt.cpp b/Tests/NumLib/TestODEInt.cpp
index b14c79d00ac..ca936a30a53 100644
--- a/Tests/NumLib/TestODEInt.cpp
+++ b/Tests/NumLib/TestODEInt.cpp
@@ -68,8 +68,9 @@ public:
 
         Solution sol;
 
+        const int equation_id = 0; 
         NumLib::TimeDiscretizedODESystem<ODE_::ODETag, NLTag>
-                ode_sys(ode, timeDisc);
+                ode_sys(equation_id, ode, timeDisc);
 
         auto linear_solver = createLinearSolver();
         auto conv_crit = std::make_unique<NumLib::ConvergenceCriterionDeltaX>(
@@ -90,7 +91,7 @@ public:
              delta_t);
 
         // initial condition
-        GlobalVector x0(ode.getMatrixSpecifications().nrows);
+        GlobalVector x0(ode.getMatrixSpecifications(equation_id).nrows);
         ODET::setIC(x0);
 
         sol.ts.push_back(t0);
-- 
GitLab