Skip to content
Snippets Groups Projects
Commit 02dd458b authored by Christoph Lehmann's avatar Christoph Lehmann Committed by Dmitri Naumov
Browse files

[MaL] renamed and moved FunctionHandles

parent 50d61d9e
No related branches found
No related tags found
No related merge requests found
...@@ -90,7 +90,7 @@ public: ...@@ -90,7 +90,7 @@ public:
~CVodeSolverImpl(); ~CVodeSolverImpl();
private: private:
void setFunction(std::unique_ptr<FunctionHandles>&& f); void setFunction(std::unique_ptr<detail::FunctionHandles>&& f);
void preSolve(); void preSolve();
void solve(const double t_end); void solve(const double t_end);
...@@ -113,7 +113,7 @@ private: ...@@ -113,7 +113,7 @@ private:
unsigned _num_equations; unsigned _num_equations;
void* _cvode_mem; void* _cvode_mem;
std::unique_ptr<FunctionHandles> _f; std::unique_ptr<detail::FunctionHandles> _f;
int _linear_multistep_method = CV_ADAMS; int _linear_multistep_method = CV_ADAMS;
int _nonlinear_solver_iteration = CV_FUNCTIONAL; int _nonlinear_solver_iteration = CV_FUNCTIONAL;
...@@ -178,8 +178,9 @@ CVodeSolverImpl::CVodeSolverImpl(const BaseLib::ConfigTree& config, ...@@ -178,8 +178,9 @@ CVodeSolverImpl::CVodeSolverImpl(const BaseLib::ConfigTree& config,
auto f_wrapped = [](const realtype t, const N_Vector y, N_Vector ydot, auto f_wrapped = [](const realtype t, const N_Vector y, N_Vector ydot,
void* function_handles) -> int void* function_handles) -> int
{ {
bool successful = static_cast<FunctionHandles*>(function_handles) bool successful =
->call(t, NV_DATA_S(y), NV_DATA_S(ydot)); static_cast<detail::FunctionHandles*>(function_handles)
->call(t, NV_DATA_S(y), NV_DATA_S(ydot));
return successful ? 0 : 1; return successful ? 0 : 1;
}; };
...@@ -206,7 +207,7 @@ void CVodeSolverImpl::setTolerance(const double abstol, const double reltol) ...@@ -206,7 +207,7 @@ void CVodeSolverImpl::setTolerance(const double abstol, const double reltol)
_reltol = reltol; _reltol = reltol;
} }
void CVodeSolverImpl::setFunction(std::unique_ptr<FunctionHandles>&& f) void CVodeSolverImpl::setFunction(std::unique_ptr<detail::FunctionHandles>&& f)
{ {
_f = std::move(f); _f = std::move(f);
assert(_num_equations == _f->getNumEquations()); assert(_num_equations == _f->getNumEquations());
...@@ -251,7 +252,7 @@ void CVodeSolverImpl::preSolve() ...@@ -251,7 +252,7 @@ void CVodeSolverImpl::preSolve()
// Caution: by calling the DENSE_COL() macro we assume that matrices // Caution: by calling the DENSE_COL() macro we assume that matrices
// are stored contiguously in memory! // are stored contiguously in memory!
bool successful = bool successful =
static_cast<FunctionHandles*>(function_handles) static_cast<detail::FunctionHandles*>(function_handles)
->callJacobian(t, NV_DATA_S(y), NV_DATA_S(ydot), ->callJacobian(t, NV_DATA_S(y), NV_DATA_S(ydot),
DENSE_COL(jac, 0)); DENSE_COL(jac, 0));
return successful ? 0 : 1; return successful ? 0 : 1;
...@@ -309,7 +310,7 @@ void CVodeSolver::setTolerance(const double abstol, const double reltol) ...@@ -309,7 +310,7 @@ void CVodeSolver::setTolerance(const double abstol, const double reltol)
_impl->setTolerance(abstol, reltol); _impl->setTolerance(abstol, reltol);
} }
void CVodeSolver::setFunction(std::unique_ptr<FunctionHandles>&& f) void CVodeSolver::setFunction(std::unique_ptr<detail::FunctionHandles>&& f)
{ {
_impl->setFunction(std::move(f)); _impl->setFunction(std::move(f));
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "BaseLib/ConfigTree.h" #include "BaseLib/ConfigTree.h"
#include "OdeSolverTypes.h" #include "OdeSolverTypes.h"
#include "FunctionHandles.h"
namespace MathLib namespace MathLib
{ {
...@@ -33,7 +34,7 @@ protected: ...@@ -33,7 +34,7 @@ protected:
void setTolerance(double const* const abstol, const double reltol); void setTolerance(double const* const abstol, const double reltol);
void setTolerance(const double abstol, const double reltol); void setTolerance(const double abstol, const double reltol);
void setFunction(std::unique_ptr<FunctionHandles>&& f); void setFunction(std::unique_ptr<detail::FunctionHandles>&& f);
void setIC(const double t0, double const* const y0); void setIC(const double t0, double const* const y0);
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "BaseLib/ConfigTree.h" #include "BaseLib/ConfigTree.h"
#include "OdeSolver.h" #include "OdeSolver.h"
#include "Handles.h" #include "FunctionHandles.h"
#ifdef CVODE_FOUND #ifdef CVODE_FOUND
#include "CVodeSolver.h" #include "CVodeSolver.h"
...@@ -53,9 +53,8 @@ public: ...@@ -53,9 +53,8 @@ public:
void setFunction(Function f, JacobianFunction df) override void setFunction(Function f, JacobianFunction df) override
{ {
Implementation::setFunction( Implementation::setFunction(
std::unique_ptr< // TODO unique_ptr not needed std::unique_ptr<detail::FunctionHandlesImpl<NumEquations>>{
detail::Handles<NumEquations>>{ new detail::FunctionHandlesImpl<NumEquations>{f, df}});
new detail::Handles<NumEquations>{f, df}});
} }
void setTolerance(const std::array<double, NumEquations>& abstol, void setTolerance(const std::array<double, NumEquations>& abstol,
......
...@@ -16,14 +16,31 @@ namespace MathLib ...@@ -16,14 +16,31 @@ namespace MathLib
{ {
namespace detail namespace detail
{ {
class FunctionHandles
{
public:
virtual bool call(const double t, double const* const y,
double* const ydot) = 0;
virtual bool callJacobian(const double t,
double const* const y,
double* const ydot,
double* const jac) = 0;
virtual bool hasJacobian() const = 0;
virtual unsigned getNumEquations() const = 0;
virtual ~FunctionHandles() = default;
};
/// Function handles for N equations. /// Function handles for N equations.
template <unsigned N> template <unsigned N>
struct Handles : public MathLib::FunctionHandles struct FunctionHandlesImpl : FunctionHandles
{ {
using Function = MathLib::Function<N>; using Function = MathLib::Function<N>;
using JacobianFunction = MathLib::JacobianFunction<N>; using JacobianFunction = MathLib::JacobianFunction<N>;
Handles(Function& f, JacobianFunction& df) : f(f), df(df) {} FunctionHandlesImpl(Function& f, JacobianFunction& df) : f(f), df(df) {}
bool call(const double t, const double* const y, bool call(const double t, const double* const y,
double* const ydot) override double* const ydot) override
{ {
......
...@@ -38,23 +38,6 @@ using JacobianFunction = std::function<bool(const double t, ...@@ -38,23 +38,6 @@ using JacobianFunction = std::function<bool(const double t,
MappedConstVector<N> ydot, MappedConstVector<N> ydot,
MappedMatrix<N, N> jac)>; MappedMatrix<N, N> jac)>;
// This is an internal detail } // namespace MathLib
class FunctionHandles
{
public:
virtual bool call(const double t, double const* const y,
double* const ydot) = 0;
virtual bool callJacobian(const double t,
double const* const y,
double* const ydot,
double* const jac) = 0;
virtual bool hasJacobian() const = 0;
virtual unsigned getNumEquations() const = 0;
virtual ~FunctionHandles() = default;
};
}
#endif // MATHLIB_ODE_ODESOLVERTYPES_H #endif // MATHLIB_ODE_ODESOLVERTYPES_H
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment