A simple ODE Class
A small illustration of using the Armadillo C++ linear algebra library for solving an ordinary differential equation of the form
\[ X'(t) = F(t, X(t), U(t)). \]
The abstract superclass Solver defines the methods solve (for approximating the solution at user-defined time points) and solveint (for interpolating user-defined input functions on a finer grid before solving). As an illustration, a simple (classic fourth-order) Runge-Kutta solver is derived in the class RK4.
The first step is to define the ODE, here a simple one-dimensional ODE \(X'(t) = \theta_1\theta_2\cdot\{U(t)-X(t)\}\) with a single input \(U(t)\) and two parameters \(\theta = (\theta_1, \theta_2)\):
// Right-hand side of X'(t) = theta(0)*theta(1)*(U(t) - X(t)), where U(t) is the
// first input variable (second element of `input`, after the time).
rowvec dX(const rowvec &input,   // time (first element) and additional input variables
          const rowvec &x,       // state variables
          const rowvec &theta) { // parameters
  const double gain = theta(0) * theta(1);
  const double gap = input(1) - x(0);  // U(t) - X(t)
  return rowvec{ gain * gap };
}
The ODE may then be solved using the following syntax:
odesolver::RK4 solver{dX};                               // Runge-Kutta solver for the ODE defined by dX
const arma::mat res = solver.solve(input, init, theta);  // one row of states per row of input
with the step size defined implicitly by input (the first column is the time variable and the following columns are the optional input variables), the initial conditions defined by init, and the parameters passed in theta.
The header file (odesolver.h):
#ifndef _ODESOLVER_H_
#define _ODESOLVER_H_
#ifndef RARMA
#define MATHLIB_STANDALONE
#include <armadillo>
#include "Rmath.h"
#endif
#if defined(RARMA)
#include <RcppArmadillo.h>
#endif
namespace odesolver {
using odefunc = std::function<arma::mat(arma::mat input, arma::mat x, arma::mat theta)>; // Type definition
/*!
Abstract class for ODE Solver
*/
class Solver {
protected:
odefunc F;
public:
Solver(odefunc F) { this->F = F; }
virtual ~Solver() {}
virtual arma::mat solve(const arma::mat &input, arma::mat init, arma::mat theta) = 0;
arma::mat solveint(const arma::mat &input, arma::mat init, arma::mat theta, double tau=1.0e-1, bool reduce=true);
};
/*!
Clasisc Runge-Kutta solver
*/
class RK4 : public Solver { // Basic 4th order Runge-Kutta solver
public:
using Solver::Solver;
arma::mat solve(const arma::mat &input, arma::mat init, arma::mat theta);
};
} // namespace odesolver
arma::uvec approx(const arma::mat &time, // Sorted time points (Ascending)
const arma::mat &newtime,
unsigned type=0); // (0: nearest, 1: right, 2: left)
arma::mat interpolate(const arma::mat &input, // first column is time
double tau, // Time-step
bool locf=false); // Last-observation-carried forward, otherwise linear interpolation
#endif /* _ODESOLVER_H_ */
The implementation (.cpp) file:
#include <algorithm>
#include <cmath>
#include <stdexcept>
#include <vector>
#include "odesolver.h"
using namespace arma;
namespace odesolver {
/*!
  Map each element of `newtime` to an index into the sorted vector `time`.

  @param time     time points sorted in ascending order
  @param newtime  query time points (read element by element)
  @param type     0: nearest time point (ties go to the later point),
                  1: right, i.e. first time point >= the query,
                  2: left,  i.e. last time point <= the query.
                  Queries after the last time point map to the last index,
                  queries before the first time point map to index 0.
  @return         one index into `time` per element of `newtime`
  @throws std::invalid_argument if `time` is empty while `newtime` is not
*/
arma::uvec approx(const arma::mat &time, // Sorted time points (Ascending)
                  const arma::mat &newtime,
                  unsigned type) { // (0: nearest, 1: right, 2: left)
  const arma::uword n = time.n_elem;
  arma::uvec idx(newtime.n_elem);
  if (newtime.n_elem == 0) return idx;
  if (n == 0)
    throw std::invalid_argument("approx: 'time' must contain at least one time point");

  const double vmax = time(n - 1);
  for (arma::uword i = 0; i < newtime.n_elem; ++i) {
    const double t = newtime(i);
    arma::uword pos = n - 1;  // at or beyond the last time point
    if (t < vmax) {
      // First time point not smaller than t; it exists because t < vmax.
      const auto it = std::lower_bound(time.begin(), time.end(), t);
      pos = static_cast<arma::uword>(it - time.begin());
      // At pos == 0 there is no earlier point, so every type clamps to 0
      // (previously 'left' underflowed to -1 here).
      if (pos > 0) {
        if (type == 0) {
          if (std::fabs(t - time(pos - 1)) < std::fabs(t - time(pos))) --pos;
        } else if (type == 2) {
          if (t < time(pos)) --pos;
        }
      }
    }
    idx(i) = pos;
  }
  return idx;
}
/*!
  Resample `input` on a regular time grid.

  The grid starts at the first time point t0 and advances in steps of `tau`.
  Its last row is always the last row of `input` (time tn), so the final step
  may be shorter than `tau`.

  @param input  first column: time points sorted in ascending order;
                remaining columns: input variables
  @param tau    time step of the new grid (must be positive and finite)
  @param locf   if true, each grid point takes the last observation at or
                before it (last observation carried forward); otherwise the
                columns are linearly interpolated between neighbouring observations
  @return       matrix with the columns of `input`; the first column is the grid
  @throws std::invalid_argument on empty input, invalid `tau`, or tn < t0
*/
arma::mat interpolate(const arma::mat &input, double tau, bool locf) {
  if (input.n_rows == 0 || input.n_cols == 0)
    throw std::invalid_argument("interpolate: 'input' must be non-empty");
  if (!(tau > 0.0) || !std::isfinite(tau))
    throw std::invalid_argument("interpolate: 'tau' must be positive and finite");

  const vec time = input.col(0);
  const uword n = time.n_elem;
  const double t0 = time(0);
  const double tn = time(n - 1);
  if (!(tn >= t0))
    throw std::invalid_argument("interpolate: time points must be sorted in ascending order");

  const double steps = std::ceil((tn - t0) / tau);
  if (!std::isfinite(steps))
    throw std::invalid_argument("interpolate: too many grid points for the given 'tau'");
  const uword N = static_cast<uword>(steps) + 1;

  mat res(N, input.n_cols);
  uword cur = 0;  // index of the last observation at or before the current grid time
  for (uword i = 0; i + 1 < N; ++i) {
    // Computed directly rather than accumulated, so rounding errors do not drift.
    const double t = t0 + static_cast<double>(i) * tau;
    // '<=' so that an observation lying exactly on t is used from t onwards,
    // consistent with the last row below.
    while (cur + 1 < n && time(cur + 1) <= t) ++cur;
    if (locf || cur + 1 == n) {
      res.row(i) = input.row(cur);
    } else {
      // time(cur) <= t < time(cur+1): the denominator is strictly positive,
      // even when the time column contains duplicates.
      const double w = (t - time(cur)) / (time(cur + 1) - time(cur));
      res.row(i) = input.row(cur) + w * (input.row(cur + 1) - input.row(cur));
    }
    res(i, 0) = t;
  }
  res.row(N - 1) = input.row(n - 1);
  return res;
}
/*!
  Classic fourth-order Runge-Kutta integration over the time points of `input`.

  Each step goes from row i to row i+1 of `input`. The step size is the
  difference of the time column, and the input at the half step is the
  midpoint of the two rows.

  @param input first column: time points; following columns: input variables
  @param init  initial state (row or column vector)
  @param theta parameters passed to F
  @return      one row per row of `input`; row 0 is the initial state
*/
arma::mat RK4::solve(const arma::mat &input, arma::mat init, arma::mat theta) {
  const uword n = input.n_rows;
  // conv_to accepts the initial state as either a row or a column vector.
  rowvec y = arma::conv_to<rowvec>::from(init);
  mat res(n, y.n_elem);
  if (n == 0) return res;  // no time points; avoids the unsigned n-1 wrap-around
  res.row(0) = y;
  for (uword i = 0; i + 1 < n; ++i) {
    const rowvec u0 = input.row(i);
    const rowvec dinput = input.row(i + 1) - u0;
    const double tau = dinput(0);
    const rowvec umid = u0 + dinput / 2;
    const rowvec u1 = u0 + dinput;
    const rowvec f1 = tau * F(u0, y, theta);
    const rowvec f2 = tau * F(umid, y + f1 / 2, theta);
    const rowvec f3 = tau * F(umid, y + f2 / 2, theta);
    const rowvec f4 = tau * F(u1, y + f3, theta);
    y += (f1 + 2 * f2 + 2 * f3 + f4) / 6;
    res.row(i + 1) = y;
  }
  return res;
}
/*!
  Solve the ODE after resampling `input` on a grid with step `tau`
  (last observation carried forward between the original time points).
  With `reduce`, only the grid rows nearest to the original time points are
  returned; otherwise the full fine-grid solution is returned.
*/
arma::mat Solver::solveint(const arma::mat &input, arma::mat init, arma::mat theta, double tau, bool reduce) {
  const mat grid = interpolate(input, tau, true);
  mat trajectory = solve(grid, init, theta);
  if (!reduce) {
    return trajectory;
  }
  // For every original time point, pick the nearest fine-grid row (type 0).
  const uvec nearest = odesolver::approx(grid.col(0), input.col(0), 0);
  return mat(trajectory.rows(nearest));
}
} // namespace odesolver