Skip to content
Snippets Groups Projects
Commit b4c4a06b authored by Christoph Lehmann's avatar Christoph Lehmann Committed by Dmitri Naumov
Browse files

[MaL] use std::function

parent cc666761
No related branches found
No related tags found
No related merge requests found
...@@ -23,8 +23,8 @@ ...@@ -23,8 +23,8 @@
namespace MathLib namespace MathLib
{ {
template <unsigned NumEquations, typename... FunctionArguments> template <unsigned NumEquations>
std::unique_ptr<OdeSolver<NumEquations, FunctionArguments...>> createOdeSolver( std::unique_ptr<OdeSolver<NumEquations>> createOdeSolver(
BaseLib::ConfigTree const& config); BaseLib::ConfigTree const& config);
/** /**
...@@ -41,25 +41,21 @@ std::unique_ptr<OdeSolver<NumEquations, FunctionArguments...>> createOdeSolver( ...@@ -41,25 +41,21 @@ std::unique_ptr<OdeSolver<NumEquations, FunctionArguments...>> createOdeSolver(
* *
* This way the \c Implementation does not need to be templated. * This way the \c Implementation does not need to be templated.
*/ */
template <typename Implementation, unsigned NumEquations, template <typename Implementation, unsigned NumEquations>
typename... FunctionArguments> class ConcreteOdeSolver final : public OdeSolver<NumEquations>,
class ConcreteOdeSolver final private Implementation
: public OdeSolver<NumEquations, FunctionArguments...>,
private Implementation
{ {
public: public:
using Interface = OdeSolver<NumEquations, FunctionArguments...>; using Interface = OdeSolver<NumEquations>;
using Function = typename Interface::Function; using Function = typename Interface::Function;
using JacobianFunction = typename Interface::JacobianFunction; using JacobianFunction = typename Interface::JacobianFunction;
void setFunction(Function f, JacobianFunction df, void setFunction(Function f, JacobianFunction df) override
FunctionArguments&... args) override
{ {
Implementation::setFunction( Implementation::setFunction(
std::unique_ptr< std::unique_ptr< // TODO unique_ptr not needed
detail::Handles<NumEquations, FunctionArguments...>>{ detail::Handles<NumEquations>>{
new detail::Handles<NumEquations, FunctionArguments...>{ new detail::Handles<NumEquations>{f, df}});
f, df, args...}});
} }
void setTolerance(const std::array<double, NumEquations>& abstol, void setTolerance(const std::array<double, NumEquations>& abstol,
...@@ -111,19 +107,17 @@ private: ...@@ -111,19 +107,17 @@ private:
{ {
} }
friend std::unique_ptr<OdeSolver<NumEquations, FunctionArguments...>> friend std::unique_ptr<OdeSolver<NumEquations>>
createOdeSolver<NumEquations, FunctionArguments...>( createOdeSolver<NumEquations>(BaseLib::ConfigTree const& config);
BaseLib::ConfigTree const& config);
}; };
template <unsigned NumEquations, typename... FunctionArguments> template <unsigned NumEquations>
std::unique_ptr<OdeSolver<NumEquations, FunctionArguments...>> createOdeSolver( std::unique_ptr<OdeSolver<NumEquations>> createOdeSolver(
BaseLib::ConfigTree const& config) BaseLib::ConfigTree const& config)
{ {
#ifdef CVODE_FOUND #ifdef CVODE_FOUND
return std::unique_ptr<OdeSolver<NumEquations, FunctionArguments...>>( return std::unique_ptr<OdeSolver<NumEquations>>(
new ConcreteOdeSolver<CVodeSolver, NumEquations, FunctionArguments...>( new ConcreteOdeSolver<CVodeSolver, NumEquations>(config));
config));
#else #else
return nullptr; return nullptr;
#endif // CVODE_FOUND #endif // CVODE_FOUND
......
...@@ -10,61 +10,15 @@ ...@@ -10,61 +10,15 @@
#ifndef MATHLIB_ODE_HANDLES_H #ifndef MATHLIB_ODE_HANDLES_H
#define MATHLIB_ODE_HANDLES_H #define MATHLIB_ODE_HANDLES_H
#include "OdeSolverTypes.h"
namespace MathLib namespace MathLib
{ {
namespace detail namespace detail
{ {
/// Function handles for N equations and arbitrary arguments. /// Function handles for N equations.
template <unsigned N, typename... FunctionArguments>
struct Handles;
/// Function handles for N equations and single argument.
template <unsigned N, typename FunctionArgument>
struct Handles<N, FunctionArgument> : public MathLib::FunctionHandles
{
using Function = MathLib::Function<N, FunctionArgument>;
using JacobianFunction = MathLib::JacobianFunction<N, FunctionArgument>;
Handles(Function& f, JacobianFunction& df, FunctionArgument& arg)
: f(f), df(df), _data(arg)
{
}
bool call(const double t, const double* const y,
double* const ydot) override
{
// looks like f and df could be any callable object with suitable
// signature
// consider omission of data pointer and switch to std::function or
// alike
if (f)
return f(t, MappedConstVector<N>{y}, MappedVector<N>{ydot}, _data);
return true;
}
bool callJacobian(const double t, const double* const y, double* const ydot,
double* const jac) override
{
if (df)
return df(t,
MappedConstVector<N>{y},
MappedVector<N>{ydot},
MappedMatrix<N, N>{jac},
_data);
return true;
}
bool hasJacobian() const override { return df != nullptr; }
unsigned getNumEquations() const override { return N; }
private:
Function f = nullptr;
JacobianFunction df = nullptr;
FunctionArgument& _data;
};
/// Function handles for N equations and no arguments.
template <unsigned N> template <unsigned N>
struct Handles<N> : public MathLib::FunctionHandles struct Handles : public MathLib::FunctionHandles
{ {
using Function = MathLib::Function<N>; using Function = MathLib::Function<N>;
using JacobianFunction = MathLib::JacobianFunction<N>; using JacobianFunction = MathLib::JacobianFunction<N>;
...@@ -73,11 +27,8 @@ struct Handles<N> : public MathLib::FunctionHandles ...@@ -73,11 +27,8 @@ struct Handles<N> : public MathLib::FunctionHandles
bool call(const double t, const double* const y, bool call(const double t, const double* const y,
double* const ydot) override double* const ydot) override
{ {
if (f) if (f) return f(t, MappedConstVector<N>{y}, MappedVector<N>{ydot});
{ return false;
return f(t, MappedConstVector<N>{y}, MappedVector<N>{ydot});
}
return true;
} }
bool callJacobian(const double t, const double* const y, double* const ydot, bool callJacobian(const double t, const double* const y, double* const ydot,
...@@ -86,15 +37,15 @@ struct Handles<N> : public MathLib::FunctionHandles ...@@ -86,15 +37,15 @@ struct Handles<N> : public MathLib::FunctionHandles
if (df) if (df)
return df(t, return df(t,
MappedConstVector<N>{y}, MappedConstVector<N>{y},
MappedVector<N>{ydot}, MappedConstVector<N>{ydot},
MappedMatrix<N, N>{jac}); MappedMatrix<N, N>{jac});
return true; return false;
} }
bool hasJacobian() const override { return df != nullptr; } bool hasJacobian() const override { return df != nullptr; }
unsigned getNumEquations() const override { return N; } unsigned getNumEquations() const override { return N; }
Function f = nullptr; Function f;
JacobianFunction df = nullptr; JacobianFunction df;
}; };
} // namespace detail } // namespace detail
......
...@@ -23,16 +23,14 @@ namespace MathLib ...@@ -23,16 +23,14 @@ namespace MathLib
* It provides type-safe and array-bounds checked access to external * It provides type-safe and array-bounds checked access to external
* ODE solver libraries. However, it is agnostic to the specific solver used. * ODE solver libraries. However, it is agnostic to the specific solver used.
*/ */
template <unsigned NumEquations, typename... FunctionArguments> template <unsigned NumEquations>
class OdeSolver class OdeSolver
{ {
public: public:
using Function = MathLib::Function<NumEquations, FunctionArguments...>; using Function = MathLib::Function<NumEquations>;
using JacobianFunction = using JacobianFunction = MathLib::JacobianFunction<NumEquations>;
MathLib::JacobianFunction<NumEquations, FunctionArguments...>;
virtual void setFunction(Function f, JacobianFunction df, virtual void setFunction(Function f, JacobianFunction df) = 0;
FunctionArguments&... args) = 0;
virtual void setTolerance(const std::array<double, NumEquations>& abstol, virtual void setTolerance(const std::array<double, NumEquations>& abstol,
const double reltol) = 0; const double reltol) = 0;
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#ifndef MATHLIB_ODE_ODESOLVERTYPES_H #ifndef MATHLIB_ODE_ODESOLVERTYPES_H
#define MATHLIB_ODE_ODESOLVERTYPES_H #define MATHLIB_ODE_ODESOLVERTYPES_H
#include <functional>
#include <Eigen/Core> #include <Eigen/Core>
namespace MathLib namespace MathLib
...@@ -27,18 +28,15 @@ using MappedVector = MappedMatrix<N, 1>; ...@@ -27,18 +28,15 @@ using MappedVector = MappedMatrix<N, 1>;
template <int N> template <int N>
using MappedConstVector = MappedConstMatrix<N, 1>; using MappedConstVector = MappedConstMatrix<N, 1>;
template <unsigned N, typename... FunctionArguments> template <unsigned N>
using Function = bool (*)(const double t, using Function = std::function<bool(
MappedConstVector<N> const y, const double t, MappedConstVector<N> const y, MappedVector<N> ydot)>;
MappedVector<N> ydot,
FunctionArguments&... arg); template <unsigned N>
using JacobianFunction = std::function<bool(const double t,
template <unsigned N, typename... FunctionArguments> MappedConstVector<N> const y,
using JacobianFunction = bool (*)(const double t, MappedConstVector<N> ydot,
MappedConstVector<N> const y, MappedMatrix<N, N> jac)>;
MappedVector<N> ydot,
MappedMatrix<N, N> jac,
FunctionArguments&... arg);
// This is an internal detail // This is an internal detail
class FunctionHandles class FunctionHandles
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment