Skip to content
Snippets Groups Projects
Forked from ogs / ogs
9865 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
UnifiedMatrixSetters.cpp 4.36 KiB
/**
 * \file
 * \copyright
 * Copyright (c) 2012-2021, OpenGeoSys Community (http://www.opengeosys.org)
 *            Distributed under a Modified BSD License.
 *              See accompanying file LICENSE.txt or
 *              http://www.opengeosys.org/project/license
 *
 */

#include <cassert>
#include "UnifiedMatrixSetters.h"

#ifdef USE_PETSC

// Global PETScMatrix/PETScVector //////////////////////////////////////////

#include <numeric>
#include "MathLib/LinAlg/PETSc/PETScVector.h"
#include "MathLib/LinAlg/PETSc/PETScMatrix.h"

namespace MathLib
{
/// Writes the given values into the leading entries of \c v.
///
/// Entry \c i of \c v (global index \c i, starting at 0) receives the
/// \c i-th element of \c values; everything is forwarded to
/// PETScVector::set in one call.
///
/// \param v      vector to modify
/// \param values values for indices 0, 1, ..., values.size()-1
void setVector(PETScVector& v,
               std::initializer_list<double> values)
{
    // Consecutive indices 0..n-1, one per supplied value.
    std::vector<PETScVector::IndexType> indices(values.size());
    std::iota(indices.begin(), indices.end(), 0);

    std::vector<double> const entries(values.begin(), values.end());

    v.set(indices, entries);
}

/// Writes a single value into \c v at position \c index.
///
/// \param v      vector to modify
/// \param index  position passed verbatim to PETScVector::set
/// \param value  value to store at \c index
void setVector(PETScVector& v,
               MatrixVectorTraits<PETScVector>::Index const index,
               double const value)
{
    // The index is forwarded unchanged; negative indices get no special
    // treatment here.
    // TODO handle negative indices
    v.set(index, value);
}

/// Overwrites \c m with the given values.
///
/// Implemented as "clear, then accumulate": every entry of \c m is zeroed
/// first, so the subsequent addToMatrix call acts as an assignment.
///
/// \param m       matrix to overwrite
/// \param values  new entries, forwarded unchanged to addToMatrix
void setMatrix(PETScMatrix& m,
               std::initializer_list<double> values)
{
    m.setZero();            // discard previous contents ...
    addToMatrix(m, values); // ... so adding is equivalent to setting
}

/// Overwrites the whole of \c m with the dense matrix \c tmp.
///
/// \c m is zeroed first, then \c tmp is added in a single block covering
/// rows 0..rows-1 and columns 0..cols-1, which gives assignment semantics.
///
/// \param m    target matrix; its dimensions must equal those of \c tmp
///             (checked by assert only, i.e. not in release builds)
/// \param tmp  dense source values
void setMatrix(PETScMatrix& m, Eigen::MatrixXd const& tmp)
{
    using IndexType = PETScMatrix::IndexType;
    using RowMajorMatrix =
        Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

    auto const n_rows = tmp.rows();
    auto const n_cols = tmp.cols();
    assert(n_rows == m.getNumberOfRows() && n_cols == m.getNumberOfColumns());

    m.setZero();

    // Index lists spanning every row and every column of the matrix.
    std::vector<IndexType> all_rows(n_rows);
    std::vector<IndexType> all_cols(n_cols);
    std::iota(all_rows.begin(), all_rows.end(), 0);
    std::iota(all_cols.begin(), all_cols.end(), 0);

    // PETSc wants the dense block in row-major order, but Eigen::MatrixXd is
    // column-major by default, so copy into row-major storage first.
    RowMajorMatrix row_major = tmp;

    m.add(all_rows, all_cols, row_major);
}

void addToMatrix(PETScMatrix& m,
                 std::initializer_list<double> values)