Robertson ODE Benchmark

This is a benchmark of a stiff ODE system using a classic test case, the Robertson (1966) problem, which models the kinetics of an autocatalytic reaction. It can be written as an ODE system given by the following equations:

\[\begin{split}\begin{aligned} \frac{dx}{dt} &= -0.04x + 10^4 y z \\ \frac{dy}{dt} &= 0.04x - 3 \cdot 10^7 y^2 - 10^4 y z \\ \frac{dz}{dt} &= 3 \cdot 10^7 y^2 \\ \end{aligned}\end{split}\]

with initial conditions:

\[\begin{split}\begin{aligned} x(0) &= 1 \\ y(0) &= 0 \\ z(0) &= 0 \\ \end{aligned}\end{split}\]

This problem is stiff because its rate constants span roughly nine orders of magnitude (from 0.04 to 3 × 10^7), so the solution evolves on widely varying timescales.
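
One way to see the stiffness is to look at the eigenvalues of the Jacobian. The sketch below evaluates the analytic Jacobian of a single group at a state chosen by hand for illustration (x = 0.5, y = 10^-5, z = 0.5), not one taken from a computed solution. The helper names are hypothetical. Because x + y + z is conserved, one eigenvalue is always zero. At this state, the other two differ in magnitude by more than five orders.

```python
import numpy as np

# Rate constants from the Robertson equations above
K1, K2, K3 = 0.04, 3e7, 1e4


def robertson_jacobian(x: float, y: float, z: float) -> np.ndarray:
    """Analytic Jacobian of (dx/dt, dy/dt, dz/dt) with respect to (x, y, z)."""
    return np.array(
        [
            [-K1, K3 * z, K3 * y],
            [K1, -2.0 * K2 * y - K3 * z, -K3 * y],
            [0.0, 2.0 * K2 * y, 0.0],
        ]
    )


def stiffness_ratio(jac: np.ndarray) -> float:
    """Ratio of the largest to the smallest nonzero eigenvalue magnitude."""
    mags = np.abs(np.linalg.eigvals(jac))
    # One eigenvalue is exactly zero because x + y + z is conserved;
    # drop it (and any round-off version of it) before taking the ratio.
    nonzero = mags[mags > 1e-10 * mags.max()]
    return float(nonzero.max() / nonzero.min())


# Illustrative, hand-picked state (hypothetical, not from a solver run)
J = robertson_jacobian(x=0.5, y=1e-5, z=0.5)
print(f"stiffness ratio ~ {stiffness_ratio(J):.1e}")
```

An explicit method must keep its step size small enough to resolve the fastest of these modes. An implicit method can take steps sized to the slow dynamics, which is why the benchmark focuses on implicit solvers.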

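We can extend this problem to a larger system by creating ngroups independent copies (groups) of the Robertson equations, which gives 3 · ngroups equations in total. In all the implementations below, the state vector stores the x variables of every group first, followed by all the y variables and then all the z variables. This lets us benchmark the performance of different ODE solvers as the size of the system increases.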
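
The following is a minimal NumPy sketch of the grouped right-hand side. It assumes the same state ordering as the implementations below (all x values, then all y, then all z). The functions robertson_rhs and initial_state are illustrative names and are not part of any benchmarked library. Because the groups are uncoupled, stacking copies of one group's state reproduces that group's right-hand side in every group. Within each group, the three derivatives also sum to zero.

```python
import numpy as np

K1, K2, K3 = 0.04, 3e7, 1e4


def robertson_rhs(u: np.ndarray, ngroups: int) -> np.ndarray:
    """Right-hand side for ngroups independent Robertson systems.

    The state is laid out as [x_1..x_n, y_1..y_n, z_1..z_n].
    """
    x = u[:ngroups]
    y = u[ngroups : 2 * ngroups]
    z = u[2 * ngroups :]
    dx = -K1 * x + K3 * y * z
    dy = K1 * x - K2 * y**2 - K3 * y * z
    dz = K2 * y**2
    return np.concatenate([dx, dy, dz])


def initial_state(ngroups: int) -> np.ndarray:
    """x = 1 and y = z = 0 for every group."""
    return np.concatenate([np.ones(ngroups), np.zeros(2 * ngroups)])


# Repeat one group's (x, y, z) state four times using the block layout.
single = np.array([0.9, 1e-5, 0.1])
ngroups = 4
u = np.concatenate([np.full(ngroups, v) for v in single])
f = robertson_rhs(u, ngroups)
# Each column of f.reshape(3, ngroups) equals robertson_rhs(single, 1).
```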

Solvers

We benchmark the following solvers:

  • Diffsol’s BDF, ESDIRK34 & TR-BDF2 methods

  • CasADi’s CVODE solver

  • Diffrax’s Kvaerno5 & Tsit5 solvers

  • Julia’s DifferentialEquations.jl FBDF, KenCarp3, & TRBDF2 methods

The following groups of solvers use similar methods, so they are the most meaningful to compare against each other:

  • Diffsol's BDF, CasADi's CVODE and DifferentialEquations.jl's FBDF methods.

  • Diffsol's TR-BDF2 and DifferentialEquations.jl's TRBDF2.

  • Diffsol's ESDIRK34, DifferentialEquations.jl's KenCarp3 and Diffrax's Kvaerno5 are different methods, but they are all SDIRK-type implicit Runge-Kutta methods. Diffrax's Tsit5 is an explicit Runge-Kutta method, so it does not belong in this group.

Benchmark Setup

For each solver, as much of the initial setup as possible is done outside the timing loop, in a function called setup, to ensure a fair comparison. Only the actual benchmark, performed in a function called bench, is timed.
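
The setup/bench split can be driven by a small timing harness. The sketch below is hypothetical and assumes things the text does not state: the time_benchmark name, the single warm-up call and the use of the median over repeats. It is not a description of the harness that produced the results.

```python
import statistics
import time
from typing import Any, Callable


def time_benchmark(
    setup: Callable[..., Any],
    bench: Callable[[Any], Any],
    *setup_args: Any,
    repeats: int = 5,
) -> float:
    """Return the median wall-clock time of bench(model), in seconds.

    setup() runs once, outside the timed region; only bench() is timed.
    """
    model = setup(*setup_args)
    bench(model)  # warm-up call, e.g. to trigger any JIT compilation
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        bench(model)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)
```

For the Diffsol functions, a call might look like time_benchmark(setup, bench, 100, 1e-6, "bdf", "robertson_ode"). For the Diffrax bench, JAX dispatches work asynchronously. The timed call would therefore need to be jax.block_until_ready(bench(model)) to include the full solve.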

The code for the Diffsol solvers is shown below:

def robertson_ode_str(ngroups: int):
    u_i = (
        f"(0:{ngroups}): x = 1,\n"
        f"({ngroups}:{2 * ngroups}): y = 0,\n"
        f"({2 * ngroups}:{3 * ngroups}): z = 0,\n"
    )
    code = (
        """
        k1 { 0.04 }
        k2 { 30000000 }
        k3 { 10000 }
        u_i {
        """
        + u_i
        + """
        }
        F_i {
            -k1 * x_i + k3 * y_i * z_i,
            k1 * x_i - k2 * y_i * y_i - k3 * y_i * z_i,
            k2 * y_i * y_i,
        }
        """
    )
    t_final = 1e10
    return code, t_final


# Benchmark module; robertson_ode_str above lives in diffsol_robertson.py
import pydiffsol as ds
import numpy as np
from diffsol_lotka_volterra import lokta_volterra_ode_str
from diffsol_robertson import robertson_ode_str


def setup(ngroups: int, tol: float, method: str, problem: str):
    if ngroups < 20:
        matrix_type = ds.nalgebra_dense
    else:
        matrix_type = ds.faer_sparse
    if method == "bdf":
        method = ds.bdf
    elif method == "esdirk34":
        method = ds.esdirk34
    elif method == "tr_bdf2":
        method = ds.tr_bdf2
    elif method == "tsit5":
        method = ds.tsit45
    else:
        raise ValueError(f"Unknown method: {method}")

    if problem == "robertson_ode":
        code, t_final = robertson_ode_str(ngroups=ngroups)
    elif problem == "lotka_volterra_ode":
        code, t_final = lokta_volterra_ode_str()
    else:
        raise ValueError(f"Unknown problem: {problem}")

    ode = ds.Ode(
        code,
        matrix_type=matrix_type,
        scalar_type=ds.f64,
        ode_solver=method,
    )
    ode.rtol = tol
    ode.atol = tol

    return ode, t_final


def bench(model):
    ode, t_final = model
    params = np.array([])
    ys = ode.solve_dense(params, np.array([t_final])).ys
    return ys[:, -1]

Note that for ngroups < 20 the nalgebra dense matrix and LU solver are used, and for ngroups >= 20 the faer sparse matrix and LU solver are used.
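
To see how quickly the Jacobian becomes sparse, the sketch below builds its structural nonzero pattern. It assumes the block ordering used in the code above, and the function name is hypothetical. Each of the nine ngroups × ngroups blocks is diagonal. Two of them, the derivatives of dz/dt with respect to x and z, are zero. The Jacobian therefore has 7 · ngroups nonzeros out of 9 · ngroups² entries, a density of 7 / (9 · ngroups), or about 3.9% at ngroups = 20. This only shows how the fraction of nonzeros falls. The threshold of 20 itself is the benchmark's choice.

```python
import numpy as np


def robertson_jacobian_pattern(ngroups: int) -> np.ndarray:
    """Boolean structural-nonzero pattern of the grouped Robertson Jacobian.

    The state is ordered [x block, y block, z block]; every block is
    diagonal because group i only depends on x_i, y_i and z_i.
    """
    n = ngroups
    # (row block, column block) pairs containing nonzeros:
    # dx/dt and dy/dt depend on x, y, z; dz/dt depends on y only.
    blocks = [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 1)]
    pattern = np.zeros((3 * n, 3 * n), dtype=bool)
    idx = np.arange(n)
    for rb, cb in blocks:
        pattern[rb * n + idx, cb * n + idx] = True
    return pattern


for ngroups in (10, 20, 100, 1000):
    p = robertson_jacobian_pattern(ngroups)
    print(
        f"ngroups={ngroups:5d}: nnz={int(p.sum()):6d} of {p.size:9d} "
        f"entries ({100 * p.mean():.2f}% nonzero)"
    )
```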

The code for the CasADi solver is shown below:

import casadi
import numpy as np


def setup_robertson_ode(ngroups: int):
    x = casadi.MX.sym("x", ngroups)
    y = casadi.MX.sym("y", ngroups)
    z = casadi.MX.sym("z", ngroups)
    k1 = 0.04
    k2 = 30000000
    k3 = 10000

    # Expression for ODE right-hand side
    f0 = -k1 * x + k3 * y * z
    f1 = k1 * x - k2 * y**2 - k3 * y * z
    f2 = k2 * y**2

    ode = {}  # ODE declaration
    ode["x"] = casadi.vertcat(x, y, z)  # states
    ode["ode"] = casadi.vertcat(f0, f1, f2)  # right-hand side
    x0 = np.zeros(3 * ngroups)
    x0[:ngroups] = 1.0
    return (ode, 1e10, x0)


# Benchmark module; setup_robertson_ode above lives in casadi_robertson.py
import casadi
import numpy as np
from casadi_lotka_volterra import setup_lokta_volterra_ode
from casadi_robertson import setup_robertson_ode

def setup(ngroups: int, tol: float, problem: str):
    if problem == "robertson_ode":
        (ode, t_final, x0) = setup_robertson_ode(ngroups)
    elif problem == "lotka_volterra_ode":
        (ode, t_final, x0) = setup_lokta_volterra_ode()
    else:
        raise ValueError(f"Unknown problem: {problem}")
    F = casadi.integrator(
        "F", "cvodes", ode, 0.0, t_final, {"abstol": tol, "reltol": tol}
    )
    return F, x0


def bench(model) -> np.ndarray:
    F, x0 = model
    return F(x0=x0)["xf"][:, -1]

The code for the Diffrax solver is shown below:

import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax.numpy as jnp

class RobertsonOde(eqx.Module):
    ngroups: int

    def __call__(self, t, y, args):
        k1 = 0.04
        k2 = 30000000.0
        k3 = 10000.0

        xs = slice(0, self.ngroups)
        ys = slice(self.ngroups, 2 * self.ngroups)
        zs = slice(2 * self.ngroups, 3 * self.ngroups)
        f0 = -k1 * y[xs] + k3 * y[ys] * y[zs]
        f1 = k1 * y[xs] - k2 * y[ys] ** 2 - k3 * y[ys] * y[zs]
        f2 = k2 * y[ys] ** 2
        return jnp.vstack([f0, f1, f2]).flatten()


# Benchmark module; RobertsonOde above lives in diffrax_robertson.py
from functools import partial
import diffrax
import jax
import jax.numpy as jnp
from diffrax_lotka_volterra import LotkaVolterra
from diffrax_robertson import RobertsonOde

# Enable 64-bit precision in JAX, which is required for solving problems
# with tolerances of 1e-8
# (see https://docs.kidger.site/diffrax/examples/stiff_ode/)
jax.config.update("jax_enable_x64", True)

def setup(ngroups: int, tol: float, method: str, problem: str):
    if problem == "robertson_ode":
        t_final = 1e10
        y0 = jnp.concatenate([jnp.ones(ngroups), jnp.zeros(2 * ngroups)])
        problem = RobertsonOde(ngroups=ngroups)
    elif problem == "lotka_volterra_ode":
        y0 = jnp.ones(2)
        t_final = 10.0
        problem = LotkaVolterra(ngroups=1)
    else:
        raise ValueError(f"Unknown problem: {problem}")
    if method == "kvaerno5":
        solver = diffrax.Kvaerno5()
    elif method == "tsit5":
        solver = diffrax.Tsit5()
    else:
        raise ValueError(f"Unknown method: {method}")
    return (problem, tol, t_final, solver, HashableArrayWrapper(y0))


# https://github.com/jax-ml/jax/issues/4572#issuecomment-709809897
def some_hash_function(x):
    return int(jnp.sum(x))


class HashableArrayWrapper:
    def __init__(self, val):
        self.val = val

    def __hash__(self):
        return some_hash_function(self.val)

    def __eq__(self, other):
        return (isinstance(other, HashableArrayWrapper) and jnp.all(jnp.equal(self.val, other.val)))


@partial(jax.jit, static_argnames=["model"])
def bench(model) -> jnp.ndarray:
    (model, tol, t_final, solver, y0) = model
    terms = diffrax.ODETerm(model)
    stepsize_controller = diffrax.PIDController(rtol=tol, atol=tol)

    t0 = 0.0
    t1 = t_final
    dt0 = None
    sol = diffrax.diffeqsolve(
        terms,
        solver,
        t0,
        t1,
        dt0,
        y0.val,
        stepsize_controller=stepsize_controller,
    )
    return sol.ys[-1]

The code for the DifferentialEquations.jl solvers is shown below:

import DifferentialEquations as DE
import ModelingToolkit as MTK


function setup_robertson_ode(ngroups)
    function rober!(du, u, p, t)
        k₁, k₂, k₃ = p
        y₁ = @view u[1:ngroups]
        y₂ = @view u[ngroups+1:2*ngroups]
        y₃ = @view u[2*ngroups+1:3*ngroups]
        dy₁ = @view du[1:ngroups]
        dy₂ = @view du[ngroups+1:2*ngroups]
        dy₃ = @view du[2*ngroups+1:3*ngroups]
        dy₁ .= -k₁ .* y₁ .+ k₃ .* y₂ .* y₃
        dy₂ .= k₁ .* y₁ .- k₂ .* y₂ .^2 .- k₃ .* y₂ .* y₃
        dy₃ .= k₂ .* y₂ .^2
        nothing
    end
    u0 = vcat(ones(ngroups), zeros(2*ngroups))
    p = [0.04, 3e7, 1e4]
    tspan = (0.0, 1e10)
    prob = DE.ODEProblem(rober!, u0, tspan, p)
    return prob, tspan
end


# Benchmark module; setup_robertson_ode above lives in diffeq_robertson.jl
import DifferentialEquations as DE
import ModelingToolkit as MTK
using OrdinaryDiffEqBDF: FBDF
using OrdinaryDiffEqSDIRK: KenCarp3, TRBDF2
using OrdinaryDiffEqTsit5: Tsit5
include("diffeq_robertson.jl")
include("diffeq_lokta_volterra.jl")


function setup(ngroups, tol, method, problem)
    if problem == "robertson_ode"
        prob, tspan = setup_robertson_ode(ngroups)
    elseif problem == "lotka_volterra_ode"
        prob, tspan = setup_lotka_volterra_ode(ngroups)
    else
        error("Unknown problem: $problem")
    end
    @MTK.mtkcompile sys = MTK.modelingtoolkitize(prob)
    prob = DE.ODEProblem(sys, [], tspan, jac=true, sparse=ngroups >= 20)
    if method == "bdf"
        alg = FBDF()
    elseif method == "kencarp3"
        alg = KenCarp3()
    elseif method == "tr_bdf2"
        alg = TRBDF2()
    elseif method == "tsit5"
        alg = Tsit5()
    else
        error("Unknown method: $method")
    end
    return (prob, alg, tol, tspan)
end

function bench(model)
    (prob, alg, tol, tspan) = model
    sol = DE.solve(prob, alg; reltol=tol, abstol=tol, saveat=tspan[2])
    # sol.u is a vector of saved states; return the state at t_final
    return sol.u[end]
end

Differences between implementations

There are a few key differences between the Diffrax, CasADi, Diffsol and DifferentialEquations.jl implementations that may affect solver performance. The main differences are:

  • Sparse vs dense matrices: The CasADi implementation uses sparse matrices. The Diffsol and DifferentialEquations.jl implementations use dense matrices for ngroups < 20 and sparse matrices for ngroups >= 20, which should give Diffsol (and DifferentialEquations.jl) an advantage for smaller problems. The Diffrax implementation always uses dense matrices. This is a disadvantage for larger problems, because the Jacobian becomes very sparse as the problem size grows.

  • Multithreading: For the MacBook run, each library was free to use multiple threads according to its default settings. For the rack server, each library was limited to 20 threads (using RAYON_NUM_THREADS=20 OMP_NUM_THREADS=20 JULIA_NUM_THREADS=20). In the Diffsol implementation, only the faer sparse LU solver and matrix take advantage of multiple threads. The nalgebra LU solver and matrix, and the DiffSL-generated code, are all single-threaded. Diffrax uses JAX, which takes advantage of multiple threads (CPU only; no GPUs were used in these benchmarks). CasADi uses multithreading via OpenMP and the SUNDIALS solver. It is unclear whether DifferentialEquations.jl uses multithreading for single ODE runs, although it supports using multiple threads for ensemble runs.
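
The thread limits on the rack server were applied through environment variables. One way to do this from a Python driver script is sketched below. It assumes each library's benchmark runs in its own child process, and the helper names are hypothetical. This is not necessarily how the benchmark harness itself sets the limits. Each library reads these variables when its thread pool or runtime initialises, so the sketch sets them in the child's environment before the child starts.

```python
import os
import subprocess
import sys
from typing import Dict, Sequence

THREAD_VARS = ("RAYON_NUM_THREADS", "OMP_NUM_THREADS", "JULIA_NUM_THREADS")


def thread_limited_env(nthreads: int) -> Dict[str, str]:
    """Copy of the current environment with the thread-count variables set."""
    env = dict(os.environ)
    for var in THREAD_VARS:
        env[var] = str(nthreads)
    return env


def run_limited(cmd: Sequence[str], nthreads: int = 20) -> subprocess.CompletedProcess:
    """Run a benchmark command in a child process with a thread cap."""
    return subprocess.run(list(cmd), env=thread_limited_env(nthreads), check=True)


if __name__ == "__main__":
    # Hypothetical usage: a child Python process that prints its limit
    run_limited([sys.executable, "-c", "import os; print(os.environ['OMP_NUM_THREADS'])"])
```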

Results

The benchmarks were run on:

  • A Dell PowerEdge R7525 2U rack server with dual AMD EPYC 7343 3.2 GHz 16-core CPUs and 128 GB of memory

  • A MacBook Pro (14-inch, 2023) with an M2 Pro chip (12 cores: 8 performance and 4 efficiency) and 16 GB of memory

The results are shown below:

Figure: Robertson ODE benchmark results on the rack server (benchmark_robertson_ode_rack_server.svg) and on the MacBook (benchmark_robertson_ode_macbook.svg).

The Diffsol implementation significantly outperforms the other implementations for smaller problem sizes, especially on the rack server. At these sizes, the dense matrix and solver used by Diffsol provide an advantage over the sparse solver used by CasADi. CasADi also incurs extra overhead on every function evaluation, because it traverses a graph of operations to compute each right-hand side or Jacobian evaluation. The DiffSL JIT compiler instead compiles the model to native code using the LLVM backend, applying low-level optimisations that are not available to CasADi. Diffrax is also significantly slower than Diffsol for smaller problems. This might be due to (a) Diffrax being built on JAX, a machine-learning framework that is not optimised for solving stiff ODEs, or (b) the use of double precision, which again is not a common use case for ML libraries.

As the problem size grows, the dense solver used by Diffrax becomes less efficient, and its runtime starts to diverge from the other methods. The performance of CasADi improves rapidly relative to Diffsol as the problem size increases, and for n > 256 CasADi becomes faster than the Diffsol BDF method on the rack server. On the MacBook, the CasADi solver never becomes faster than the Diffsol BDF method; instead, the two methods converge in performance. This is likely due to the better multithreading performance of the CVODE solver used by CasADi on the rack server, which has more CPU cores available.

The DifferentialEquations.jl implementations are slower than the Diffsol implementation across all problem sizes, and slower than CasADi at larger problem sizes, although the DifferentialEquations.jl FBDF method is faster than CasADi for smaller problems.