From 1ef741517ff9e39bd1218a142390e845d1dc005f Mon Sep 17 00:00:00 2001
From: Dmitri Naumov <dmitri.naumov@ufz.de>
Date: Fri, 9 Aug 2024 13:52:07 +0200
Subject: [PATCH] [PL] Add staggered scheme assembly to parallel asm

This is mostly a copy of the VectorMatrixAssembler.
---
 .../ParallelVectorMatrixAssembler.cpp         | 121 ++++++++++++++++--
 1 file changed, 113 insertions(+), 8 deletions(-)

diff --git a/ProcessLib/Assembly/ParallelVectorMatrixAssembler.cpp b/ProcessLib/Assembly/ParallelVectorMatrixAssembler.cpp
index e46377bb59d..14b9d24a6b9 100644
--- a/ProcessLib/Assembly/ParallelVectorMatrixAssembler.cpp
+++ b/ProcessLib/Assembly/ParallelVectorMatrixAssembler.cpp
@@ -24,6 +24,7 @@
 #include "ProcessLib/Assembly/MatrixAssemblyStats.h"
 #include "ProcessLib/Assembly/MatrixElementCache.h"
 #include "ProcessLib/Assembly/MatrixOutput.h"
+#include "ProcessLib/CoupledSolutionsForStaggeredScheme.h"
 
 namespace
 {
@@ -59,6 +60,52 @@ void assembleWithJacobianOneElement(
     cache.add(local_b_data, local_Jac_data, indices);
 }
 
+void assembleWithJacobianForStaggeredSchemeOneElement(
+    const std::size_t mesh_item_id,
+    ProcessLib::LocalAssemblerInterface& local_assembler,
+    std::vector<NumLib::LocalToGlobalIndexMap const*> const& dof_tables,
+    const double t, const double dt, std::vector<GlobalVector*> const& x,
+    std::vector<GlobalVector*> const& x_prev, int const process_id,
+    std::vector<double>& local_b_data, std::vector<double>& local_Jac_data,
+    ProcessLib::AbstractJacobianAssembler& jacobian_assembler,
+    ProcessLib::Assembly::MultiMatrixElementCache& cache)
+{  // Assembles b and Jacobian of one mesh item for process_id into cache.
+    std::vector<std::vector<GlobalIndexType>> indices_of_processes;
+    indices_of_processes.reserve(dof_tables.size());  // one per DOF table
+    transform(cbegin(dof_tables), cend(dof_tables),
+              back_inserter(indices_of_processes),
+              [&](auto const* dof_table)
+              { return NumLib::getIndices(mesh_item_id, *dof_table); });
+    // Extract the coupled local solutions of x for the gathered indices.
+    auto local_coupled_xs =
+        ProcessLib::getCoupledLocalSolutions(x, indices_of_processes);
+    auto const local_x = MathLib::toVector(local_coupled_xs);
+    // Same extraction for x_prev.
+    auto local_coupled_x_prevs =
+        ProcessLib::getCoupledLocalSolutions(x_prev, indices_of_processes);
+    auto const local_x_prev = MathLib::toVector(local_coupled_x_prevs);
+    // Only the indices belonging to process_id are used for the cache.
+    std::vector<GlobalIndexType> const& indices =
+        indices_of_processes[process_id];
+    // The buffers are passed in by the caller; discard their old content.
+    local_b_data.clear();
+    local_Jac_data.clear();
+    // The Jacobian assembler is expected to fill both buffers.
+    jacobian_assembler.assembleWithJacobianForStaggeredScheme(
+        local_assembler, t, dt, local_x, local_x_prev, process_id, local_b_data,
+        local_Jac_data);
+
+    if (local_Jac_data.empty())
+    {
+        OGS_FATAL(
+            "No Jacobian has been assembled! This might be due to "
+            "programming errors in the local assembler of the "
+            "current process.");
+    }
+    // Store the local contributions in the cache under those indices.
+    cache.add(local_b_data, local_Jac_data, indices);
+}
+
 /// Returns a vector of active element ids with corresponding local assemblers.
 std::vector<
     std::tuple<std::ptrdiff_t,
@@ -129,6 +176,49 @@ void runAssemblyForEachLocalAssembler(
     }
 }
 
+void runStaggeredAssemblyForEachLocalAssembler(
+    std::vector<std::tuple<
+        std::ptrdiff_t,
+        std::reference_wrapper<ProcessLib::LocalAssemblerInterface>>> const&
+        ids_local_assemblers,
+    std::vector<NumLib::LocalToGlobalIndexMap const*> const& dof_tables,
+    double const t, double const dt, std::vector<GlobalVector*> const& x,
+    std::vector<GlobalVector*> const& x_prev, int const process_id,
+    std::vector<double>& local_b_data, std::vector<double>& local_Jac_data,
+    ProcessLib::AbstractJacobianAssembler& jac_asm, ThreadException& exception,
+    ProcessLib::Assembly::MultiMatrixElementCache& cache,
+    auto local_matrix_output)
+{  // Runs the staggered per-element assembly inside an OpenMP for-loop.
+    // due to MSVC++ error:
+    // error C3016: 'element_id': index variable in OpenMP 'for'
+    // statement must have signed integral type
+    std::ptrdiff_t n_elements =
+        static_cast<std::ptrdiff_t>(ids_local_assemblers.size());
+#pragma omp for nowait
+    for (std::ptrdiff_t i = 0; i < n_elements; ++i)
+    {
+        if (exception)  // presumably true once capture() has been called
+        {
+            continue;  // an OpenMP for-loop must not be left via break
+        }
+        auto [element_id, loc_asm] = ids_local_assemblers[i];
+        // Exceptions must not escape the OpenMP loop body: capture, move on.
+        try
+        {
+            assembleWithJacobianForStaggeredSchemeOneElement(
+                element_id, loc_asm, dof_tables, t, dt, x, x_prev, process_id,
+                local_b_data, local_Jac_data, jac_asm, cache);
+        }
+        catch (...)
+        {
+            exception.capture();
+            continue;
+        }
+        // Reached only if the assembly of this element did not throw.
+        local_matrix_output(element_id);
+    }
+}
+
 int getNumberOfThreads()
 {
     char const* const num_threads_env = std::getenv("OGS_ASM_THREADS");
@@ -200,10 +290,7 @@ void ParallelVectorMatrixAssembler::assembleWithJacobian(
         OGS_FATAL("Different number of DOF tables and solution vectors.");
     }
 
-    auto const& dof_table = *dof_tables[process_id];
-    auto const& x = *xs[process_id];
-    auto const& x_prev = *x_prevs[process_id];
-
+    std::size_t const number_of_processes = xs.size();
     // algorithm ///////////////////////////////////////////////////////////////
 
     auto stats = CumulativeStats<MultiStats>::create();
@@ -239,10 +326,28 @@ void ParallelVectorMatrixAssembler::assembleWithJacobian(
         };
 
         // TODO corner case: what if all elements on a submesh are deactivated?
-        runAssemblyForEachLocalAssembler(
-            collectActiveLocalAssemblers(local_assemblers, active_elements),
-            dof_table, t, dt, x, x_prev, local_b_data, local_Jac_data, *jac_asm,
-            exception, cache, local_matrix_output);
+
+        // Monolithic scheme
+        if (number_of_processes == 1)
+        {
+            assert(process_id == 0);
+            auto const& dof_table = *dof_tables[0];
+            auto const& x = *xs[0];
+            auto const& x_prev = *x_prevs[0];
+
+            runAssemblyForEachLocalAssembler(
+                collectActiveLocalAssemblers(local_assemblers, active_elements),
+                dof_table, t, dt, x, x_prev, local_b_data, local_Jac_data,
+                *jac_asm, exception, cache, local_matrix_output);
+        }
+        else  // Staggered scheme
+        {
+            runStaggeredAssemblyForEachLocalAssembler(
+                collectActiveLocalAssemblers(local_assemblers, active_elements),
+                dof_tables, t, dt, xs, x_prevs, process_id, local_b_data,
+                local_Jac_data, *jac_asm, exception, cache,
+                local_matrix_output);
+        }
     }
 
     stats->print();
-- 
GitLab