A small illustration of using the Armadillo C++ linear algebra
library to solve an ordinary differential equation (ODE) of the form
\[ X'(t) = F(t,X(t),U(t)).\]
The abstract base class Solver
defines the methods solve (which approximates the solution at
user-defined time points) and solveint (which first interpolates the user-defined
input functions on a finer grid and solves the ODE there). As an illustration, a classic
fourth-order Runge-Kutta solver is derived in the class RK4.
The first step is to define the ODE, here a simple one-dimensional ODE \(X'(t) = \theta\cdot\{U(t)-X(t)\}\) with a single input \(U(t)\):
rowvec dX(const rowvec &input, // time (first element) and additional input variables
const rowvec &x, // state variables
const rowvec &theta) { // parameters
rowvec res = { theta(0)*theta(1)*(input(1)-x(0)) };
return( res );
}
The ODE can then be solved using the following syntax:
// Wrap the right-hand side dX in a fourth-order Runge-Kutta solver.
odesolver::RK4 MyODE(dX);
// input: time in column 0, input variables in later columns; init: state at the
// first time point; theta: parameters. res holds one state row per row of input.
arma::mat res = MyODE.solve(input, init, theta);
where the step sizes are defined implicitly by input (the first column holds the time points
and the following columns hold the optional input variables), the initial
condition (the state at the first time point) is given by init, and theta holds the parameters.
The header file (odesolver.h):
#ifndef ODESOLVER_H_
#define ODESOLVER_H_

#ifdef RARMA
#include <RcppArmadillo.h>
#else
#define MATHLIB_STANDALONE
#include <armadillo>
#include "Rmath.h"
#endif

#include <functional>
#include <utility>

namespace odesolver {

/*!
  Right-hand side of the ODE X'(t) = F(t, X(t), U(t)).
  Called as F(input, x, theta): input is one row of the input matrix (time in
  the first element, followed by the input variables), x is the current state
  and theta the parameter vector. Returns the derivative of the state.
*/
using odefunc = std::function<arma::mat(arma::mat input, arma::mat x, arma::mat theta)>;

/*!
  Abstract base class for ODE solvers.
*/
class Solver {
 protected:
  odefunc F;  //!< Right-hand side of the ODE.

 public:
  //! Store the right-hand side of the ODE.
  Solver(odefunc rhs) : F(std::move(rhs)) {}
  virtual ~Solver() = default;

  /*!
    Approximate the solution at the time points given by the rows of input.
    \param input Time in the first column, input variables in the remaining columns.
    \param init  State at the first time point.
    \param theta Parameters passed through to F.
    \return Expected to hold one row of state values per row of input.
  */
  virtual arma::mat solve(const arma::mat &input, arma::mat init, arma::mat theta) = 0;

  /*!
    Resample input onto a regular grid with step tau (last observation carried
    forward), solve on that grid, and if reduce is true keep only the grid rows
    nearest to the original time points.
  */
  arma::mat solveint(const arma::mat &input, arma::mat init, arma::mat theta,
                     double tau = 1.0e-1, bool reduce = true);
};

/*!
  Classic fourth-order Runge-Kutta solver.
*/
class RK4 : public Solver {  // Basic 4th order Runge-Kutta solver
 public:
  using Solver::Solver;
  arma::mat solve(const arma::mat &input, arma::mat init, arma::mat theta) override;
};

/*!
  Map each element of newtime to an index into the ascending grid time.
  type: 0 = nearest grid point, 1 = right (first >= value), 2 = left (last <= value).
*/
arma::uvec approx(const arma::mat &time,     // Sorted time points (Ascending)
                  const arma::mat &newtime,
                  unsigned type = 0);        // (0: nearest, 1: right, 2: left)

/*!
  Resample input (first column is time) onto a regular grid with step tau.
*/
arma::mat interpolate(const arma::mat &input, // first column is time
                      double tau,             // Time-step
                      bool locf = false);     // Last-observation-carried forward, otherwise linear interpolation

}  // namespace odesolver

#endif  // ODESOLVER_H_
The implementation (.cpp) file:
#include "odesolver.h"

#include <algorithm>
#include <cmath>
#include <stdexcept>
#include <vector>

using namespace arma;
namespace odesolver {
/*!
  Map each value of newtime to an index into the ascending grid time.

  \param time    Grid time points sorted ascending (elements are read in storage
                 order). Must be non-empty.
  \param newtime Time points to locate on the grid.
  \param type    0: nearest grid point (ties go to the later point),
                 1: first grid point >= the value ("right"),
                 2: grid point with the largest time <= the value ("left").
  \return One index per element of newtime. Indices are clamped to
          [0, time.n_elem-1]: values at or beyond the last grid point map to the
          last index, values at or before the first grid point map to 0.
  \throws std::invalid_argument if time is empty.
*/
arma::uvec approx(const arma::mat &time,     // Sorted time points (Ascending)
                  const arma::mat &newtime,
                  unsigned type) {           // (0: nearest, 1: right, 2: left)
  if (time.n_elem == 0) {
    throw std::invalid_argument("approx: 'time' must contain at least one element");
  }
  const arma::uword last = time.n_elem - 1;
  const double vmax = time(last);
  arma::uvec idx(newtime.n_elem);
  for (arma::uword i = 0; i < newtime.n_elem; i++) {
    const double x = newtime(i);
    arma::uword pos = 0;
    if (x >= vmax) {
      pos = last;
    } else {
      // First grid point >= x.
      const auto it = std::lower_bound(time.begin(), time.end(), x);
      pos = static_cast<arma::uword>(it - time.begin());
      // pos == 0 means x lies at or before time(0): every mode clamps to index 0,
      // so the unsigned index can never be decremented below zero.
      if (pos > 0) {
        if (type == 0) {
          // Nearest: step back only if the left neighbour is strictly closer.
          if (std::fabs(x - time(pos - 1)) < std::fabs(x - time(pos))) pos--;
        } else if (type == 2 && x < time(pos)) {
          // Left: x lies strictly between time(pos-1) and time(pos).
          pos--;
        }
      }
    }
    idx(i) = pos;
  }
  return idx;
}
/*!
  Resample input onto a regular time grid with step tau.

  \param input First column holds ascending time points; the remaining columns
               hold the input variables observed at those times. Must have at
               least one row.
  \param tau   Grid step; must be positive.
  \param locf  true: last observation carried forward (the value at a grid time
               is the latest observation at or before it); false: linear
               interpolation between neighbouring observations.
  \return Matrix with the columns of input whose rows are at t0, t0+tau,
          t0+2*tau, ... (before the last time tn), followed by the final row of
          input at tn. The first column holds the grid times.
  \throws std::invalid_argument if input has no rows, tau is not positive, or
          the last time point precedes the first.
*/
arma::mat interpolate(const arma::mat &input, double tau, bool locf) {
  const arma::uword n = input.n_rows;
  if (n == 0) {
    throw std::invalid_argument("interpolate: 'input' must have at least one row");
  }
  if (!(tau > 0)) {
    throw std::invalid_argument("interpolate: 'tau' must be positive");
  }
  const double t0 = input(0, 0);
  const double tn = input(n - 1, 0);
  if (!(tn >= t0)) {
    throw std::invalid_argument("interpolate: time points must be ascending");
  }
  // Grid points t0 + k*tau for k = 0..N-2, plus the end point tn.
  const arma::uword N = static_cast<arma::uword>(std::ceil((tn - t0) / tau)) + 1;
  arma::mat out(N, input.n_cols);
  arma::uword cur = 0;  // index of the last observation with time <= grid time
  for (arma::uword i = 0; i + 1 < N; i++) {
    // Compute grid times directly instead of accumulating tau (avoids drift).
    const double t = t0 + static_cast<double>(i) * tau;
    while (cur + 1 < n && input(cur + 1, 0) <= t) cur++;
    arma::rowvec row = input.row(cur);
    if (!locf && cur + 1 < n) {
      // Here input(cur,0) <= t < input(cur+1,0), so the denominator is strictly
      // positive even when the time column contains duplicates.
      const double w = (t - input(cur, 0)) / (input(cur + 1, 0) - input(cur, 0));
      row += w * (input.row(cur + 1) - input.row(cur));
    }
    row(0) = t;
    out.row(i) = row;
  }
  out.row(N - 1) = input.row(n - 1);
  return out;
}
/*!
  Integrate the ODE with the classic fourth-order Runge-Kutta scheme, taking one
  step between each pair of consecutive rows of input.

  \param input Time in the first column, input variables in the remaining
               columns; inputs at step midpoints are linearly interpolated.
  \param init  State at the first time point (row or column vector).
  \param theta Parameters passed through to F.
  \return Matrix with one row per row of input holding the state at that time
          (an empty 0 x p matrix when input has no rows).
*/
arma::mat RK4::solve(const arma::mat &input, arma::mat init, arma::mat theta) {
  const arma::uword n = input.n_rows;
  // Accept the initial state as either a row or a column vector.
  arma::rowvec y = arma::vectorise(init, 1);
  arma::mat res(n, y.n_elem);
  if (n == 0) return res;  // no time points: nothing to integrate
  res.row(0) = y;
  for (arma::uword i = 0; i + 1 < n; i++) {
    const arma::rowvec x0 = input.row(i);
    const arma::rowvec dx = input.row(i + 1) - x0;
    const double h = dx(0);  // step size taken from the time column
    const arma::rowvec k1 = h * F(x0, y, theta);
    const arma::rowvec k2 = h * F(x0 + dx / 2, y + k1 / 2, theta);
    const arma::rowvec k3 = h * F(x0 + dx / 2, y + k2 / 2, theta);
    const arma::rowvec k4 = h * F(x0 + dx, y + k3, theta);
    y += (k1 + 2 * k2 + 2 * k3 + k4) / 6;
    res.row(i + 1) = y;
  }
  return res;
}
/*!
  Solve the ODE on a refined time grid.

  The input is resampled onto a regular grid of step tau with the last
  observation carried forward, and solve() is run on that grid. When reduce is
  true, only the grid rows nearest to the original time points of input are
  returned; otherwise the full fine-grid solution is returned.
*/
arma::mat Solver::solveint(const arma::mat &input, arma::mat init, arma::mat theta, double tau, bool reduce) {
  const arma::mat fine = interpolate(input, tau, true);
  arma::mat path = solve(fine, init, theta);
  if (!reduce) {
    return path;
  }
  // For every original time point, pick the nearest fine-grid row.
  const arma::uvec keep = odesolver::approx(fine.col(0), input.col(0), 0);
  return path.rows(keep);
}
} // namespace odesolver