Skip to content
Snippets Groups Projects
Commit 6ad22e5d authored by Christoph Lehmann's avatar Christoph Lehmann
Browse files

[MaL] Separated compute() and solve() step of Eigen linear solvers

parent 60816372
No related branches found
No related tags found
No related merge requests found
......@@ -12,6 +12,7 @@
#include <Eigen/Sparse>
#include "BaseLib/Error.h"
#include "BaseLib/Logging.h"
#ifdef USE_MKL
......@@ -27,7 +28,6 @@
#endif
#include "EigenMatrix.h"
#include "EigenTools.h"
#include "EigenVector.h"
namespace MathLib
......@@ -41,9 +41,49 @@ public:
virtual ~EigenLinearSolverBase() = default;
//! Solves the linear equation system \f$ A x = b \f$ for \f$ x \f$.
virtual bool solve(Matrix& A, Vector const& b, Vector& x,
EigenOption& opt) = 0;
bool solve(Vector& b, Vector& x, EigenOption& opt)
{
#ifdef USE_EIGEN_UNSUPPORTED
if (scaling_)
{
b = scaling_->LeftScaling().cwiseProduct(b);
}
#endif
auto const success = solveImpl(b, x, opt);
if (scaling_)
{
x = scaling_->RightScaling().cwiseProduct(x);
}
return success;
}
bool compute(Matrix& A, EigenOption& opt)
{
#ifdef USE_EIGEN_UNSUPPORTED
if (opt.scaling)
{
INFO("-> scale");
scaling_ = std::make_unique<
Eigen::IterScaling<EigenMatrix::RawMatrixType>>();
scaling_->computeRef(A);
}
#endif
return computeImpl(A, opt);
}
protected:
virtual bool solveImpl(Vector const& b, Vector& x, EigenOption& opt) = 0;
virtual bool computeImpl(Matrix& A, EigenOption& opt) = 0;
private:
#ifdef USE_EIGEN_UNSUPPORTED
std::unique_ptr<Eigen::IterScaling<EigenMatrix::RawMatrixType>> scaling_;
#endif
};
namespace details
......@@ -53,10 +93,25 @@ template <class T_SOLVER>
class EigenDirectLinearSolver final : public EigenLinearSolverBase
{
public:
bool solve(Matrix& A, Vector const& b, Vector& x, EigenOption& opt) override
bool solveImpl(Vector const& b, Vector& x, EigenOption& opt) override
{
INFO("-> solve with Eigen direct linear solver {:s}",
EigenOption::getSolverName(opt.solver_type));
x = solver_.solve(b);
if (solver_.info() != Eigen::Success)
{
ERR("Failed during Eigen linear solve");
return false;
}
return true;
}
bool computeImpl(Matrix& A, EigenOption& opt) override
{
INFO("-> compute with Eigen direct linear solver {:s}",
EigenOption::getSolverName(opt.solver_type));
if (!A.isCompressed())
{
A.makeCompressed();
......@@ -69,13 +124,6 @@ public:
return false;
}
x = solver_.solve(b);
if (solver_.info() != Eigen::Success)
{
ERR("Failed during Eigen linear solve");
return false;
}
return true;
}
......@@ -186,9 +234,9 @@ template <class T_SOLVER>
class EigenIterativeLinearSolver final : public EigenLinearSolverBase
{
public:
bool solve(Matrix& A, Vector const& b, Vector& x, EigenOption& opt) override
bool computeImpl(Matrix& A, EigenOption& opt) override
{
INFO("-> solve with Eigen iterative linear solver {:s} (precon {:s})",
INFO("-> compute with Eigen iterative linear solver {:s} (precon {:s})",
EigenOption::getSolverName(opt.solver_type),
EigenOption::getPreconName(opt.precon_type));
solver_.setTolerance(opt.error_tolerance);
......@@ -207,18 +255,32 @@ public:
T_SOLVER>::setResidualUpdate(opt.residualupdate);
#endif
if (!A.isCompressed())
// matrix must be copied, because Eigen's linear solver stores a
// reference to it cf.
// https://eigen.tuxfamily.org/dox/classEigen_1_1IterativeSolverBase.html#a7dfa55c55e82d697bde227696a630914
A_ = A;
if (!A_.isCompressed())
{
A.makeCompressed();
A_.makeCompressed();
}
solver_.compute(A);
solver_.compute(A_);
if (solver_.info() != Eigen::Success)
{
ERR("Failed during Eigen linear solver initialization");
return false;
}
return true;
}
bool solveImpl(Vector const& b, Vector& x, EigenOption& opt) override
{
INFO("-> solve with Eigen iterative linear solver {:s} (precon {:s})",
EigenOption::getSolverName(opt.solver_type),
EigenOption::getPreconName(opt.precon_type));
x = solver_.solveWithGuess(b, x);
INFO("\t iteration: {:d}/{:d}", solver_.iterations(),
opt.max_iterations);
......@@ -235,6 +297,7 @@ public:
private:
T_SOLVER solver_;
Matrix A_;
#ifdef USE_EIGEN_UNSUPPORTED
void setRestart(int const restart) { setRestartImpl(solver_, restart); }
void setL(int const l) { setLImpl(solver_, l); }
......@@ -401,34 +464,28 @@ EigenLinearSolver::EigenLinearSolver(std::string const& /*solver_name*/,
EigenLinearSolver::~EigenLinearSolver() = default;
bool EigenLinearSolver::solve(EigenMatrix& A, EigenVector& b, EigenVector& x)
bool EigenLinearSolver::compute(EigenMatrix& A)
{
INFO("------------------------------------------------------------------");
INFO("*** Eigen solver computation");
INFO("*** Eigen solver compute()");
#ifdef USE_EIGEN_UNSUPPORTED
std::unique_ptr<Eigen::IterScaling<EigenMatrix::RawMatrixType>> scal;
if (option_.scaling)
{
INFO("-> scale");
scal =
std::make_unique<Eigen::IterScaling<EigenMatrix::RawMatrixType>>();
scal->computeRef(A.getRawMatrix());
b.getRawVector() = scal->LeftScaling().cwiseProduct(b.getRawVector());
}
#endif
auto const success = solver_->solve(A.getRawMatrix(), b.getRawVector(),
x.getRawVector(), option_);
#ifdef USE_EIGEN_UNSUPPORTED
if (scal)
{
x.getRawVector() = scal->RightScaling().cwiseProduct(x.getRawVector());
}
#endif
return solver_->compute(A.getRawMatrix(), option_);
}
bool EigenLinearSolver::solve(EigenVector& b, EigenVector& x)
{
INFO("------------------------------------------------------------------");
INFO("*** Eigen solver solve()");
return success;
return solver_->solve(b.getRawVector(), x.getRawVector(), option_);
INFO("------------------------------------------------------------------");
}
bool EigenLinearSolver::solve(EigenMatrix& A, EigenVector& b, EigenVector& x)
{
return solver_->compute(A.getRawMatrix(), option_) &&
solver_->solve(b.getRawVector(), x.getRawVector(), option_);
}
} // namespace MathLib
......@@ -46,6 +46,24 @@ public:
*/
EigenOption& getOption() { return option_; }
/**
* Performs the compute() step of the Eigen linear solver.
*
* I.e., computes the (LU) decomposition in case of a direct solver, or
* computes the preconditioner of an iterative solver.
*/
bool compute(EigenMatrix& A);
/**
* Solves the linear system for the given right-hand side \c b and initial
* guess \c x.
*
* \pre compute() must have been called before. (Not necessarily for every
* \c x and \c b separately, but for every new/changed matrix A.
*/
bool solve(EigenVector& b, EigenVector& x);
/// Computes and solves in a single call.
bool solve(EigenMatrix& A, EigenVector& b, EigenVector& x);
protected:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment