From a15413b1eb8c778b67c54977031f0bd66747c8e5 Mon Sep 17 00:00:00 2001
From: Haibing Shao <haibing.shao@ufz.de>
Date: Thu, 11 Jul 2024 13:51:25 +0200
Subject: [PATCH] add feature to normalize A matrix and RHS vector

implementation, normalize A and rhs

add feature in nonlinear solver

Comment out of writing matrices for debugging

Normalize linear equation system only for rectangular shaped A Matrix

add the switch in NonlinearSolver

add feature in matMultTranspose

only when A is not square

deactivate PETSc implementation

add the check function in EquationSystem

add override functions in ODE

add override functions in TimeDiscretizedODESystem
---
 MathLib/LinAlg/LinAlg.cpp                   | 47 +++++++++++++++++++++
 MathLib/LinAlg/LinAlg.h                     | 10 ++++-
 NumLib/ODESolver/EquationSystem.h           |  9 ++++
 NumLib/ODESolver/MatrixTranslator.cpp       | 19 +++++++++
 NumLib/ODESolver/MatrixTranslator.h         |  8 ++++
 NumLib/ODESolver/NonlinearSolver.cpp        |  5 +++
 NumLib/ODESolver/NonlinearSystem.h          |  6 +++
 NumLib/ODESolver/TimeDiscretizedODESystem.h | 15 +++++++
 Tests/NumLib/ODEs.h                         |  6 +++
 9 files changed, 124 insertions(+), 1 deletion(-)

diff --git a/MathLib/LinAlg/LinAlg.cpp b/MathLib/LinAlg/LinAlg.cpp
index 5e688a3fd1e..8be69e83804 100644
--- a/MathLib/LinAlg/LinAlg.cpp
+++ b/MathLib/LinAlg/LinAlg.cpp
@@ -167,6 +167,34 @@ void matMultAdd(PETScMatrix const& A, PETScVector const& v1,
                v3.getRawVector());
 }
 
+void linearSysNormalize(PETScMatrix const& /*A*/, PETScMatrix& /*new_A*/,
+                        PETScVector const& /*b*/, PETScVector& /*new_b*/)
+{
+    // The following block is deactivated, because there is no tests yet for the
+    // normalization operation in PETSc. This will be a task for later.
+    /*
+    assert(&A != &new_A);
+    assert(&b != &new_b);
+
+    PetscInt n_rows(0);
+    PetscInt n_cols(0);
+    MatGetSize(A.getRawMatrix(), &n_rows, &n_cols);
+    // only when A matrix is not square
+    if (n_rows != n_cols)
+    {
+        // new_b = A^T * b
+        MatMultTranspose(A.getRawMatrix(), b.getRawVector(),
+                         new_b.getRawVector());
+        // new_A = A^T * A
+        MatTranspose(A.getRawMatrix(), MAT_INITIAL_MATRIX,
+                     &(new_A.getRawMatrix()));
+    }
+    */
+    OGS_FATAL(
+        "Normalization operation is not implemented yet for PETSc library! "
+        "Program terminated.");
+}
+
 void finalizeAssembly(PETScMatrix& A)
 {
     A.finalizeAssembly(MAT_FINAL_ASSEMBLY);
@@ -313,6 +341,25 @@ void matMultAdd(EigenMatrix const& A, EigenVector const& v1,
         v2.getRawVector() + A.getRawMatrix() * v1.getRawVector();
 }
 
+void linearSysNormalize(EigenMatrix const& A, EigenMatrix& new_A,
+                        EigenVector const& b, EigenVector& new_b)
+{
+    // make sure that new_A and new_b are not the same memory
+    assert(&A != &new_A);
+    assert(&b != &new_b);
+
+    if (A.getRawMatrix().rows() == A.getRawMatrix().cols())
+    {
+        WARN(
+            "The number of rows and columns are the same for the LHS matrix."
+            "Are you sure you still need to normalize the LHS matrix and RHS "
+            "vector? ");
+    }
+
+    new_b.getRawVector() = A.getRawMatrix().transpose() * b.getRawVector();
+    new_A.getRawMatrix() = A.getRawMatrix().transpose() * A.getRawMatrix();
+}
+
 void finalizeAssembly(EigenMatrix& x)
 {
     x.getRawMatrix().makeCompressed();
diff --git a/MathLib/LinAlg/LinAlg.h b/MathLib/LinAlg/LinAlg.h
index f982306db82..206395e8f90 100644
--- a/MathLib/LinAlg/LinAlg.h
+++ b/MathLib/LinAlg/LinAlg.h
@@ -196,6 +196,11 @@ void matMult(PETScMatrix const& A, PETScVector const& x, PETScVector& y);
 void matMultAdd(PETScMatrix const& A, PETScVector const& v1,
                 PETScVector const& v2, PETScVector& v3);
 
+// new_A = A^T * A
+// new_b = A^T * b
+void linearSysNormalize(PETScMatrix const& A, PETScMatrix& new_A,
+                        PETScVector const& b, PETScVector& new_b);
+
 void finalizeAssembly(PETScMatrix& A);
 void finalizeAssembly(PETScVector& x);
 
@@ -264,7 +269,10 @@ void matMult(EigenMatrix const& A, EigenVector const& x, EigenVector& y);
 // v3 = A*v1 + v2
 void matMultAdd(EigenMatrix const& A, EigenVector const& v1,
                 EigenVector const& v2, EigenVector& v3);
-
+// new_A = A^T * A
+// new_b = A^T * b
+void linearSysNormalize(EigenMatrix const& A, EigenMatrix& new_A,
+                        EigenVector const& b, EigenVector& new_b);
 void finalizeAssembly(EigenMatrix& x);
 void finalizeAssembly(EigenVector& A);
 
diff --git a/NumLib/ODESolver/EquationSystem.h b/NumLib/ODESolver/EquationSystem.h
index 643d1bdb57a..c851ed8f92c 100644
--- a/NumLib/ODESolver/EquationSystem.h
+++ b/NumLib/ODESolver/EquationSystem.h
@@ -39,6 +39,15 @@ public:
      */
     virtual bool isLinear() const = 0;
 
+    /*! Check whether normalization of A and rhs is required.
+     *
+     * \remark
+     * In some processes, a normalization operation is required, to calculate
+     * A^T * A, and overwrite A; also calculate A^T * rhs and overwrite rhs.
+     * This parameter reflect whether such operation is required.
+     */
+    virtual bool requiresNormalization() const = 0;
+
     /*! Prepares a new iteration in the solution process of this equation.
      *
      * \param iter the current iteration number, starting from 1.
diff --git a/NumLib/ODESolver/MatrixTranslator.cpp b/NumLib/ODESolver/MatrixTranslator.cpp
index 5be630904e1..c86edbba043 100644
--- a/NumLib/ODESolver/MatrixTranslator.cpp
+++ b/NumLib/ODESolver/MatrixTranslator.cpp
@@ -43,6 +43,25 @@ void MatrixTranslatorGeneral<ODESystemTag::FirstOrderImplicitQuasilinear>::
     NumLib::GlobalVectorProvider::provider.releaseVector(tmp);
 }
 
+void MatrixTranslatorGeneral<ODESystemTag::FirstOrderImplicitQuasilinear>::
+    normalizeAandRhs(GlobalMatrix& A, GlobalVector& b) const
+{
+    namespace LinAlg = MathLib::LinAlg;
+
+    // check whether A is square?
+
+    GlobalMatrix new_A(A);
+    GlobalVector new_b(b);
+    LinAlg::copy(A, new_A);
+    LinAlg::copy(b, new_b);
+    // rhs = A^T * rhs
+    // A = A^T * A
+    LinAlg::linearSysNormalize(A, new_A, b, new_b);
+
+    LinAlg::copy(new_A, A);
+    LinAlg::copy(new_b, b);
+}
+
 void MatrixTranslatorGeneral<ODESystemTag::FirstOrderImplicitQuasilinear>::
     computeResidual(GlobalMatrix const& M, GlobalMatrix const& K,
                     GlobalVector const& b, double const dt,
diff --git a/NumLib/ODESolver/MatrixTranslator.h b/NumLib/ODESolver/MatrixTranslator.h
index b3752586c0b..60cc11dfcf5 100644
--- a/NumLib/ODESolver/MatrixTranslator.h
+++ b/NumLib/ODESolver/MatrixTranslator.h
@@ -49,6 +49,10 @@ public:
                             const GlobalVector& b, const GlobalVector& x_prev,
                             GlobalVector& rhs) const = 0;
 
+    //! Computes \f$ A = A^T \cdot A \f$, and
+    //! also \f$ rhs = A^T \cdot rhs \f$.
+    virtual void normalizeAandRhs(GlobalMatrix& A, GlobalVector& b) const = 0;
+
     /*! Computes \c res from \c M, \c K, \c b, \f$ \hat x \f$ and \f$ x_N \f$.
      * You might also want read the remarks on
      * \ref concept_time_discretization "time discretization".
@@ -105,6 +109,10 @@ public:
                     const GlobalVector& b, const GlobalVector& x_prev,
                     GlobalVector& rhs) const override;
 
+    //! Computes \f$ A = A^T \cdot A \f$, and
+    //! also \f$ rhs = A^T \cdot rhs \f$.
+    void normalizeAandRhs(GlobalMatrix& A, GlobalVector& b) const override;
+
     //! Computes \f$ r = M \cdot \hat x + K \cdot x_C - b \f$.
     void computeResidual(GlobalMatrix const& M, GlobalMatrix const& K,
                          GlobalVector const& b, double dt,
diff --git a/NumLib/ODESolver/NonlinearSolver.cpp b/NumLib/ODESolver/NonlinearSolver.cpp
index 324c11905c6..59c0886b55b 100644
--- a/NumLib/ODESolver/NonlinearSolver.cpp
+++ b/NumLib/ODESolver/NonlinearSolver.cpp
@@ -167,6 +167,11 @@ NonlinearSolverStatus NonlinearSolver<NonlinearSolverTag::Picard>::solve(
         sys.assemble(x_new, x_prev, process_id);
         sys.getA(A);
         sys.getRhs(*x_prev[process_id], rhs);
+
+        // Normalize the linear equation system, if required
+        if (sys.requiresNormalization())
+            sys.getAandRhsNormalized(A, rhs);
+
         INFO("[time] Assembly took {:g} s.", time_assembly.elapsed());
 
         // Subtract non-equilibrium initial residuum if set
diff --git a/NumLib/ODESolver/NonlinearSystem.h b/NumLib/ODESolver/NonlinearSystem.h
index e3b20f6e12d..82d9681525e 100644
--- a/NumLib/ODESolver/NonlinearSystem.h
+++ b/NumLib/ODESolver/NonlinearSystem.h
@@ -121,6 +121,12 @@ public:
     virtual void getRhs(GlobalVector const& x_prev,
                         GlobalVector& rhs) const = 0;
 
+    //! Writes the A_transposed times A into \c A
+    //! and also writes A_transposed times rhs into \c rhs
+    //! \pre getA() and getRhs must have been called before.
+    virtual void getAandRhsNormalized(GlobalMatrix& A,
+                                      GlobalVector& rhs) const = 0;
+
     //! Pre-compute known solutions and possibly store them internally.
     virtual void computeKnownSolutions(GlobalVector const& x,
                                        int const process_id) = 0;
diff --git a/NumLib/ODESolver/TimeDiscretizedODESystem.h b/NumLib/ODESolver/TimeDiscretizedODESystem.h
index 4fffd07cbcb..e93751dd9d8 100644
--- a/NumLib/ODESolver/TimeDiscretizedODESystem.h
+++ b/NumLib/ODESolver/TimeDiscretizedODESystem.h
@@ -120,6 +120,11 @@ public:
 
     bool isLinear() const override { return _ode.isLinear(); }
 
+    bool requiresNormalization() const override
+    {
+        return _ode.requiresNormalization();
+    }
+
     void preIteration(const unsigned iter, GlobalVector const& x) override
     {
         _ode.preIteration(iter, x);
@@ -207,6 +212,11 @@ public:
         _mat_trans->computeRhs(*_M, *_K, *_b, x_prev, rhs);
     }
 
+    void getAandRhsNormalized(GlobalMatrix& A, GlobalVector& rhs) const override
+    {
+        _mat_trans->normalizeAandRhs(A, rhs);
+    }
+
     void computeKnownSolutions(GlobalVector const& x,
                                int const process_id) override;
 
@@ -217,6 +227,11 @@ public:
 
     bool isLinear() const override { return _ode.isLinear(); }
 
+    bool requiresNormalization() const override
+    {
+        return _ode.requiresNormalization();
+    }
+
     void preIteration(const unsigned iter, GlobalVector const& x) override
     {
         _ode.preIteration(iter, x);
diff --git a/Tests/NumLib/ODEs.h b/Tests/NumLib/ODEs.h
index 98d7d61b869..2f3d2e66234 100644
--- a/Tests/NumLib/ODEs.h
+++ b/Tests/NumLib/ODEs.h
@@ -115,6 +115,8 @@ public:
         return true;
     }
 
+    bool requiresNormalization() const override { return false; }
+
     std::size_t const N = 2;
 };
 
@@ -205,6 +207,8 @@ public:
         return false;
     }
 
+    bool requiresNormalization() const override { return false; }
+
     std::size_t const N = 1;
 };
 
@@ -333,6 +337,8 @@ public:
         return false;
     }
 
+    bool requiresNormalization() const override { return false; }
+
     std::size_t const N = 2;
 };
 
-- 
GitLab