Skip to content
Snippets Groups Projects
Commit 954d5326 authored by Dmitri Naumov's avatar Dmitri Naumov
Browse files

[MaL] Consistently replace Arr with MappedVector.

All input arguments and return values are now MappedVector
or MappedMatrix (or const). Before it was mixed up
with std::array.

getYDot() now returns a value instead of modifying input
parameter.
parent 429c94a2
No related branches found
No related tags found
No related merge requests found
......@@ -97,7 +97,7 @@ private:
double const* getSolution() const { return NV_DATA_S(_y); }
double getTime() const { return _t; }
bool getYDot(const double t, double const* const y, double* const ydot);
void getYDot(const double t, double const* const y, double* const y_dot);
void setTolerance(const double* abstol, const double reltol);
void setTolerance(const double abstol, const double reltol);
void setIC(const double t0, double const* const y0);
......@@ -277,15 +277,11 @@ void CVodeSolverImpl::solve(const double t_end)
}
}
bool CVodeSolverImpl::getYDot(const double t, double const* const y,
double* const ydot)
void CVodeSolverImpl::getYDot(const double t, double const* const y,
double* const y_dot)
{
if (_f != nullptr)
{
return _f->call(t, y, ydot);
}
return false;
assert(_f != nullptr);
_f->call(t, y, y_dot);
}
CVodeSolverImpl::~CVodeSolverImpl()
......@@ -346,10 +342,10 @@ double const* CVodeSolverInternal::getSolution() const
return _impl->getSolution();
}
bool CVodeSolverInternal::getYDot(const double t, double const* const y,
double* const ydot) const
void CVodeSolverInternal::getYDot(const double t, double const* const y,
double* const y_dot) const
{
return _impl->getYDot(t, y, ydot);
_impl->getYDot(t, y, y_dot);
}
double CVodeSolverInternal::getTime() const
......
......@@ -42,8 +42,9 @@ protected:
double const* getSolution() const;
double getTime() const;
bool getYDot(const double t, double const* const y,
double* const ydot) const;
void getYDot(const double t,
double const* const y,
double* const y_dot) const;
~CVodeSolverInternal();
......
......@@ -27,8 +27,6 @@ template <unsigned NumEquations, typename... FunctionArguments>
class OdeSolver
{
public:
using Arr = std::array<double, NumEquations>;
using ConstArrRef = MappedConstVector<NumEquations>;
using Function = MathLib::Function<NumEquations, FunctionArguments...>;
using JacobianFunction =
MathLib::JacobianFunction<NumEquations, FunctionArguments...>;
......@@ -36,18 +34,23 @@ public:
virtual void setFunction(Function f, JacobianFunction df,
FunctionArguments&... args) = 0;
virtual void setTolerance(const Arr& abstol, const double reltol) = 0;
virtual void setTolerance(const std::array<double, NumEquations>& abstol,
const double reltol) = 0;
virtual void setTolerance(const double abstol, const double reltol) = 0;
virtual void setIC(const double t0, const Arr& y0) = 0;
virtual void setIC(const double t0,
std::array<double, NumEquations> const& y0) = 0;
virtual void setIC(const double t0,
Eigen::Matrix<double, NumEquations, 1> const& y0) = 0;
virtual void preSolve() = 0;
virtual void solve(const double t) = 0;
virtual unsigned getNumEquations() const { return NumEquations; }
virtual ConstArrRef getSolution() const = 0;
virtual MappedConstVector<NumEquations> getSolution() const = 0;
virtual double getTime() const = 0;
virtual Arr getYDot(const double t, const Arr& y) const = 0;
virtual Eigen::Matrix<double, NumEquations, 1> getYDot(
const double t, const MappedConstVector<NumEquations>& y) const = 0;
virtual ~OdeSolver() = default;
};
......
......@@ -132,8 +132,6 @@ class ConcreteOdeSolver final
{
public:
using Interface = OdeSolver<NumEquations, FunctionArguments...>;
using Arr = typename Interface::Arr;
using ConstArrRef = typename Interface::ConstArrRef;
using Function = typename Interface::Function;
using JacobianFunction = typename Interface::JacobianFunction;
......@@ -147,7 +145,8 @@ public:
f, df, args...}});
}
void setTolerance(const Arr& abstol, const double reltol) override
void setTolerance(const std::array<double, NumEquations>& abstol,
const double reltol) override
{
Implementation::setTolerance(abstol.data(), reltol);
}
......@@ -157,24 +156,32 @@ public:
Implementation::setTolerance(abstol, reltol);
}
void setIC(const double t0, const Arr& y0) override
void setIC(const double t0,
std::array<double, NumEquations> const& y0) override
{
Implementation::setIC(t0, y0.data());
}
void setIC(const double t0,
Eigen::Matrix<double, NumEquations, 1> const& y0) override
{
Implementation::setIC(t0, y0.data());
}
void preSolve() override { Implementation::preSolve(); }
void solve(const double t) override { Implementation::solve(t); }
ConstArrRef getSolution() const override
MappedConstVector<NumEquations> getSolution() const override
{
return ConstArrRef(Implementation::getSolution());
return MappedConstVector<NumEquations>{Implementation::getSolution()};
}
double getTime() const override { return Implementation::getTime(); }
Arr getYDot(const double t, const Arr& y) const override
Eigen::Matrix<double, NumEquations, 1> getYDot(
const double t, const MappedConstVector<NumEquations>& y) const override
{
Arr ydot;
Implementation::getYDot(t, y.data(), ydot.data());
return ydot;
Eigen::Matrix<double, NumEquations, 1> y_dot;
Implementation::getYDot(t, y.data(), y_dot.data());
return y_dot;
}
private:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment