diff --git a/MathLib/ODE/OdeSolver.h b/MathLib/ODE/OdeSolver.h index 227d32d14f0b0c9eece1366599bb8d461c009a71..6425eea1befcb1e58cfa1db1de4d96508cec2f34 100644 --- a/MathLib/ODE/OdeSolver.h +++ b/MathLib/ODE/OdeSolver.h @@ -40,7 +40,7 @@ public: virtual void setTolerance(const double abstol, const double reltol) = 0; virtual void setFunction(Function f, JacobianFunction df, - FunctionArguments*... args) = 0; + FunctionArguments&... args) = 0; virtual void setIC(const double t0, const Arr& y0) = 0; diff --git a/MathLib/ODE/OdeSolverFactory.h b/MathLib/ODE/OdeSolverFactory.h index 4330db613c75a061012b4c1b6421bba672f8632c..b4fc23b4a91413c204e5df61daeb303616cd2939 100644 --- a/MathLib/ODE/OdeSolverFactory.h +++ b/MathLib/ODE/OdeSolverFactory.h @@ -31,6 +31,11 @@ struct Handles<N, FunctionArgument> : public MathLib::FunctionHandles using Function = MathLib::Function<N, FunctionArgument>; using JacobianFunction = MathLib::JacobianFunction<N, FunctionArgument>; + Handles(Function& f, JacobianFunction& df, FunctionArgument& arg) + : f(f), df(df), _data(arg) + { + } + bool call(const double t, const double* const y, double* const ydot) override { @@ -42,7 +47,7 @@ struct Handles<N, FunctionArgument> : public MathLib::FunctionHandles return f(t, Eigen::Map<const Eigen::Matrix<double, N, 1>>{y}, Eigen::Map<Eigen::Matrix<double, N, 1>>{ydot}, - *_data); + _data); return true; } @@ -54,24 +59,16 @@ struct Handles<N, FunctionArgument> : public MathLib::FunctionHandles Eigen::Map<const Eigen::Matrix<double, N, 1>>{y}, Eigen::Map<Eigen::Matrix<double, N, 1>>{ydot}, Eigen::Map<Eigen::Matrix<double, N, N>>{jac /*, order*/}, - *_data); + _data); return true; } bool hasJacobian() const override { return df != nullptr; } unsigned getNumEquations() const override { return N; } - void setArguments(FunctionArgument* arg) - { - assert(arg != nullptr); - _data = arg; - } - - // TODO: make private +private: Function f = nullptr; JacobianFunction df = nullptr; - -private: - FunctionArgument* _data = nullptr; + FunctionArgument& _data; }; template <unsigned N> @@ -80,6 +77,7 @@ struct Handles<N> : public MathLib::FunctionHandles using Function = MathLib::Function<N>; using JacobianFunction = MathLib::JacobianFunction<N>; + Handles(Function& f, JacobianFunction& df) : f(f), df(df) {} bool call(const double t, const double* const y, double* const ydot) override { @@ -107,7 +105,6 @@ struct Handles<N> : public MathLib::FunctionHandles bool hasJacobian() const override { return df != nullptr; } unsigned getNumEquations() const override { return N; } - void setArguments() const {} Function f = nullptr; JacobianFunction df = nullptr; }; @@ -148,7 +145,7 @@ public: void init() override { Implementation::init(NumEquations); - Implementation::setFunction(&_handles); + Implementation::setFunction(_handles.get()); } void setTolerance(const Arr& abstol, const double reltol) override @@ -162,11 +159,10 @@ public: } void setFunction(Function f, JacobianFunction df, - FunctionArguments*... args) override + FunctionArguments&... args) override { - _handles.f = f; - _handles.df = df; - _handles.setArguments(args...); + _handles.reset(new detail::Handles<NumEquations, FunctionArguments...>{ + f, df, args...}); } void setIC(const double t0, const Arr& y0) override @@ -197,7 +193,8 @@ private: { } - detail::Handles<NumEquations, FunctionArguments...> _handles; + std::unique_ptr<detail::Handles<NumEquations, FunctionArguments...>> + _handles; friend std::unique_ptr<OdeSolver<NumEquations, FunctionArguments...>> createOdeSolver<NumEquations, FunctionArguments...>( diff --git a/Tests/MathLib/TestCVode.cpp b/Tests/MathLib/TestCVode.cpp index 71e60951f6cf0878fa4f70145667964ac8709342..17a9beb7b830569e147b62ca83712e1e892f73ae 100644 --- a/Tests/MathLib/TestCVode.cpp +++ b/Tests/MathLib/TestCVode.cpp @@ -63,11 +63,10 @@ TEST(MathLibCVodeTest, Exponential) BaseLib::ConfigTree::onwarning); auto ode_solver = MathLib::createOdeSolver<1>(config); + ode_solver->setFunction(f, nullptr); ode_solver->init(); ode_solver->setTolerance(1e-8, 1e-6); - ode_solver->setFunction(f, nullptr); - ode_solver->setIC(t0, {y0}); ode_solver->preSolve(); @@ -102,12 +101,12 @@ TEST(MathLibCVodeTest, ExponentialExtraData) BaseLib::ConfigTree::onwarning); auto ode_solver = MathLib::createOdeSolver<1, ExtraData>(config); + ExtraData data; + ode_solver->setFunction(f_extra, nullptr, data); + ode_solver->init(); ode_solver->setTolerance(1e-8, 1e-6); - ExtraData data; - ode_solver->setFunction(f_extra, nullptr, &data); - ode_solver->setIC(t0, {y0}); ode_solver->preSolve(); @@ -159,11 +158,10 @@ TEST(MathLibCVodeTest, ExponentialWithJacobian) BaseLib::ConfigTree::onwarning); auto ode_solver = MathLib::createOdeSolver<1>(config); + ode_solver->setFunction(f, df); ode_solver->init(); ode_solver->setTolerance(1e-10, 1e-8); - ode_solver->setFunction(f, df); - ode_solver->setIC(t0, {y0}); ode_solver->preSolve(); @@ -201,11 +199,10 @@ TEST(MathLibCVodeTest, ExponentialWithJacobianNewton) auto ode_solver = MathLib::createOdeSolver<1>(config); + ode_solver->setFunction(f, df); ode_solver->init(); ode_solver->setTolerance(1e-6, 1e-6); - ode_solver->setFunction(f, df); - ode_solver->setIC(t0, {y0}); ode_solver->preSolve();