Lotka-Volterra ODE Benchmark

We benchmark ODE solvers on the Lotka-Volterra predator-prey model. This is a non-stiff system of equations defined as:

\[ \begin{aligned} \frac{dx}{dt} &= a x - b x y \\ \frac{dy}{dt} &= d x y - c y \end{aligned} \]

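where \(x\) is the number of prey, \(y\) is the number of predators, and \(a\), \(b\), \(c\), and \(d\) are positive real parameters that describe the interaction between the two species. All implementations below use \(a = 2/3\), \(b = 4/3\), \(c = d = 1\), start from \(x(0) = y(0) = 1\), and integrate to \(t = 10\).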

We will benchmark the following solvers against different tolerances:

  • Diffsol’s BDF & TSIT45 methods

  • CasADi’s CVODE solver

  • Diffrax’s Tsit5 solver

  • Julia’s DifferentialEquations.jl FBDF & Tsit5 methods

The following groups of solvers are similar, so they should be compared against each other:

  • Diffsol’s BDF, CasADi’s CVODE and DifferentialEquations.jl’s FBDF methods.

  • Diffsol’s TSIT45, DifferentialEquations.jl Tsit5 and Diffrax’s Tsit5 methods are all explicit Runge-Kutta methods with identical tableaus.

Benchmark Setup

For each solver, we perform as much initial setup as possible outside of the timing loop to ensure a fair comparison, using a function called setup. The actual benchmark is performed in a function called bench.

The code for the Diffsol solvers is shown below:

def lokta_volterra_ode_str():
    """Return the DiffSL model source and end time for the Lotka-Volterra benchmark.

    The model declares constants ``a = 2/3``, ``b = 4/3``, ``c = 1`` and
    ``d = 1``, starts both populations at 1.0, and uses the right-hand side
    ``(a*x - b*x*y, -c*y + d*x*y)``.

    Returns:
        A ``(code, t_final)`` tuple: the DiffSL source string and the final
        integration time (10.0).
    """
    t_final = 10.0
    code = """
    a { 2.0 / 3.0 }
    b { 4.0 / 3.0 }
    c { 1.0 }
    d { 1.0 }
    u_i {
        x = 1.0,
        y = 1.0,
    }
    F_i {
        a * x - b * x * y,
        -c * y + d * x * y,
    }
    """
    return code, t_final
import pydiffsol as ds
import numpy as np
from diffsol_lotka_volterra import lokta_volterra_ode_str
from diffsol_robertson import robertson_ode_str


def setup(ngroups: int, tol: float, method: str, problem: str):
    """Build a configured pydiffsol ``Ode`` for one benchmark case.

    Args:
        ngroups: Problem size. Values below 20 select a dense nalgebra
            matrix; anything larger selects a sparse faer matrix.
        tol: Used for both the relative and the absolute tolerance.
        method: One of ``"bdf"``, ``"esdirk34"``, ``"tr_bdf2"`` or
            ``"tsit5"`` (the last maps to diffsol's ``tsit45``).
        problem: ``"robertson_ode"`` or ``"lotka_volterra_ode"``.

    Returns:
        An ``(ode, t_final)`` tuple, as unpacked by ``bench``.

    Raises:
        ValueError: If ``method`` or ``problem`` is not recognised.
    """
    matrix_type = ds.nalgebra_dense if ngroups < 20 else ds.faer_sparse

    # Map benchmark method names to pydiffsol attribute names; the attribute
    # is only looked up for the method actually requested.
    solver_attrs = {
        "bdf": "bdf",
        "esdirk34": "esdirk34",
        "tr_bdf2": "tr_bdf2",
        "tsit5": "tsit45",
    }
    if method not in solver_attrs:
        raise ValueError(f"Unknown method: {method}")
    ode_solver = getattr(ds, solver_attrs[method])

    if problem == "lotka_volterra_ode":
        code, t_final = lokta_volterra_ode_str()
    elif problem == "robertson_ode":
        code, t_final = robertson_ode_str(ngroups=ngroups)
    else:
        raise ValueError(f"Unknown problem: {problem}")

    ode = ds.Ode(
        code,
        matrix_type=matrix_type,
        scalar_type=ds.f64,
        ode_solver=ode_solver,
    )
    # Same tolerance for relative and absolute error (rtol is set first).
    ode.rtol = ode.atol = tol
    return ode, t_final


def bench(model):
    """Solve one configured Diffsol problem and return its final state.

    Args:
        model: The ``(ode, t_final)`` tuple returned by ``setup``.

    Returns:
        The last column of the solution's ``ys`` — presumably the state
        vector at ``t_final`` (assumes ``ys`` is laid out as
        (nstates, ntimes); TODO confirm against pydiffsol).
    """
    ode, t_final = model
    # No input parameters are supplied; only t_final is requested as an
    # output time.
    solution = ode.solve_dense(np.array([]), np.array([t_final]))
    return solution.ys[:, -1]

The code for the CasADi solver is shown below:

import casadi
import numpy as np

def setup_lokta_volterra_ode():
    """Build the CasADi ODE description for the Lotka-Volterra benchmark.

    Returns:
        A ``(ode, t_final, x0)`` tuple where ``ode`` is a dict with the
        symbolic state vector under ``"x"`` and the right-hand side under
        ``"ode"``, ``t_final`` is 10.0 and ``x0`` is ``np.ones(2)``.
    """
    a, b, c, d = 2.0 / 3.0, 4.0 / 3.0, 1.0, 1.0

    prey = casadi.MX.sym("x")
    predator = casadi.MX.sym("y")

    # Right-hand side: (a*x - b*x*y, -c*y + d*x*y).
    dprey = a * prey - b * prey * predator
    dpredator = -c * predator + d * prey * predator

    ode = {
        "x": casadi.vertcat(prey, predator),
        "ode": casadi.vertcat(dprey, dpredator),
    }
    return (ode, 10.0, np.ones(2))
import casadi
import numpy as np
from casadi_lotka_volterra import setup_lokta_volterra_ode
from casadi_robertson import setup_robertson_ode

def setup(ngroups: int, tol: float, problem: str):
    if problem == "robertson_ode":
        (ode, t_final, x0) = setup_robertson_ode(ngroups)
    elif problem == "lotka_volterra_ode":
        (ode, t_final, x0) = setup_lokta_volterra_ode()
    else:
        raise ValueError(f"Unknown problem: {problem}")
    F = casadi.integrator(
        "F", "cvodes", ode, 0.0, t_final, {"abstol": tol, "reltol": tol}
    )
    return F, x0


def bench(model) -> np.ndarray:
    """Run the CasADi integrator from ``x0`` and return the last column of ``xf``.

    Args:
        model: The ``(F, x0)`` tuple returned by ``setup``.

    NOTE(review): ``F`` is a CasADi function, so the value returned here is
    presumably a CasADi matrix rather than the ``np.ndarray`` named in the
    annotation — confirm before relying on numpy semantics.
    """
    integrator, x0 = model
    outputs = integrator(x0=x0)
    final_states = outputs["xf"]
    return final_states[:, -1]

The code for the Diffrax solver is shown below:

import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax.numpy as jnp

class LotkaVolterra(eqx.Module):
    """Lotka-Volterra right-hand side as an Equinox module for Diffrax.

    ``ngroups`` is stored as a field but is not read by ``__call__``.
    """

    ngroups: int

    def __call__(self, t, y, args):
        """Return ``(a*x - b*x*y, -c*y + d*x*y)`` flattened, with ``x, y = y[0], y[1]``.

        ``t`` and ``args`` are part of the Diffrax vector-field signature
        and are unused.
        """
        alpha, beta, gamma, delta = 2.0 / 3.0, 4.0 / 3.0, 1.0, 1.0
        prey, predator = y[0], y[1]

        dprey = alpha * prey - beta * prey * predator
        dpredator = -gamma * predator + delta * prey * predator
        # Stack the two derivatives and flatten to a single 1-D state vector.
        return jnp.stack([dprey, dpredator]).ravel()
from functools import partial
import diffrax
import jax
import jax.numpy as jnp
from diffrax_lotka_volterra import LotkaVolterra
from diffrax_robertson import RobertsonOde

# Enable 64-bit (float64) precision in JAX; this is required for solving
# problems with tolerances of 1e-8
# (see https://docs.kidger.site/diffrax/examples/stiff_ode/).
# Set at import time so it is already in effect when setup/bench create arrays.
jax.config.update("jax_enable_x64", True)

def setup(ngroups: int, tol: float, method: str, problem: str):
    if problem == "robertson_ode":
        t_final = 1e10
        y0 = jnp.concatenate([jnp.ones(ngroups), jnp.zeros(2 * ngroups)])
        problem = RobertsonOde(ngroups=ngroups)
    elif problem == "lotka_volterra_ode":
        y0 = jnp.ones(2)
        t_final = 10.0
        problem = LotkaVolterra(ngroups=1)
    else:
        raise ValueError(f"Unknown problem: {problem}")
    if method == "kvaerno5":
        solver = diffrax.Kvaerno5()
    elif method == "tsit5":
        solver = diffrax.Tsit5()
    else:
        raise ValueError(f"Unknown method: {method}")
    return (problem, tol, t_final, solver, HashableArrayWrapper(y0))


# https://github.com/jax-ml/jax/issues/4572#issuecomment-709809897
def some_hash_function(x):
    """Return a cheap integer hash of an array: its element sum truncated by ``int``.

    Equal arrays have equal sums and therefore equal hashes; unequal arrays
    may collide.
    """
    total = jnp.sum(x)
    return int(total)


class HashableArrayWrapper:
    """Wrap a JAX array so it can be part of a ``jax.jit`` static argument.

    Static arguments must be hashable and comparable with ``==`` returning a
    plain bool; this wrapper provides both for the array it holds in ``val``.
    """

    def __init__(self, val):
        self.val = val

    def __hash__(self):
        return some_hash_function(self.val)

    def __eq__(self, other):
        # ``jnp.array_equal`` returns False for mismatched shapes instead of
        # raising a broadcasting error (as ``jnp.equal`` does), and ``bool``
        # converts the 0-d JAX result into the Python bool ``__eq__`` should
        # return.
        return isinstance(other, HashableArrayWrapper) and bool(
            jnp.array_equal(self.val, other.val)
        )


@partial(jax.jit, static_argnames=["model"])
def bench(model) -> jnp.ndarray:
    """Integrate the configured problem from t=0 to ``t_final`` with Diffrax.

    ``model`` is the tuple built by ``setup`` and is a static jit argument,
    so each distinct configuration is traced and compiled separately.

    Returns:
        The last entry of ``sol.ys`` (the last saved state).
    """
    ode_model, tol, t_final, solver, y0_wrapped = model
    # dt0=None lets the PID controller choose the initial step size.
    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(ode_model),
        solver,
        0.0,
        t_final,
        None,
        y0_wrapped.val,
        stepsize_controller=diffrax.PIDController(rtol=tol, atol=tol),
    )
    return solution.ys[-1]

The code for the DifferentialEquations.jl solvers is shown below:

import DifferentialEquations as DE
import ModelingToolkit as MTK


function setup_lotka_volterra_ode(ngroups)
    # `ngroups` is not used here; presumably it is accepted so this helper has
    # the same call signature as `setup_robertson_ode(ngroups)`.

    # In-place right-hand side with p = (a, b, c, d):
    #   dx/dt = a*x - b*x*y,  dy/dt = -c*y + d*x*y
    function lotka_volterra!(du, u, p, t)
        a, b, c, d = p
        x = u[1]
        y = u[2]
        du[1] = a*x - b*x*y
        du[2] = -c*y + d*x*y
        return nothing
    end

    params = [2.0 / 3.0, 4.0 / 3.0, 1.0, 1.0]
    tspan = (0.0, 10.0)
    prob = DE.ODEProblem(lotka_volterra!, [1.0, 1.0], tspan, params)
    return prob, tspan
end
import DifferentialEquations as DE
import ModelingToolkit as MTK
using OrdinaryDiffEqBDF: FBDF
using OrdinaryDiffEqSDIRK: KenCarp3, TRBDF2
using OrdinaryDiffEqTsit5: Tsit5
include("diffeq_robertson.jl")
include("diffeq_lokta_volterra.jl")


function setup(ngroups, tol, method, problem)
    # Build the plain numeric ODEProblem for the requested benchmark.
    if problem == "lotka_volterra_ode"
        raw_prob, tspan = setup_lotka_volterra_ode(ngroups)
    elseif problem == "robertson_ode"
        raw_prob, tspan = setup_robertson_ode(ngroups)
    else
        error("Unknown problem: $problem")
    end

    # Convert to a ModelingToolkit system, compile it, and rebuild the
    # problem with jac=true (sparse when ngroups >= 20).
    @MTK.mtkcompile sys = MTK.modelingtoolkitize(raw_prob)
    prob = DE.ODEProblem(sys, [], tspan, jac=true, sparse=ngroups >= 20)

    # Map benchmark method names to solver constructors.
    solvers = Dict(
        "bdf" => FBDF,
        "kencarp3" => KenCarp3,
        "tr_bdf2" => TRBDF2,
        "tsit5" => Tsit5,
    )
    haskey(solvers, method) || error("Unknown method: $method")
    alg = solvers[method]()

    return (prob, alg, tol, tspan)
end

function bench(model)
    # Solve the prepared problem and return the state at the final time,
    # matching what the Python implementations return.
    (prob, alg, tol, tspan) = model
    sol = DE.solve(prob, alg=alg, reltol = tol, abstol = tol, saveat=tspan[2])
    # `sol.u` is a Vector of saved state vectors (with a scalar `saveat` it
    # may also include the initial state), so index its last element to get
    # the final state vector; `sol.u[:, end]` would copy every saved state.
    return sol.u[end]
end

Differences between implementations

There are a few key differences between the Diffrax, CasADi, Diffsol and DifferentialEquations.jl implementations that may affect the performance of the solvers. The main differences are:

  • JIT compilation & function evaluation overhead: The Diffsol implementation uses the DiffSL JIT compiler to compile the model equations to optimised native code. Diffrax uses JAX to JIT compile the model equations and solvers to XLA-optimised code. Julia is a JIT-compiled language, so the DifferentialEquations.jl implementation is also JIT compiled to native code using LLVM. CasADi is the only implementation that does not use JIT compilation; instead, it builds a computational graph of operations to evaluate the model equations, which adds overhead to each function evaluation.

Results

The benchmarks were run on:
  • A Dell PowerEdge R7525 2U rack server with dual AMD EPYC 7343 (3.2 GHz, 16-core) CPUs and 128 GB of memory

  • A MacBook Pro (14-inch, 2023) with an M2 Pro chip, 16 GB of memory and 12 cores (8 performance and 4 efficiency)

The results are shown below:

benchmark_lotka_volterra_ode_rack_server.svg benchmark_lotka_volterra_ode_macbook.svg

The Diffsol implementation of the TSIT45 method performs best across all tolerances and hardware setups. The DifferentialEquations.jl and Diffrax implementations perform similarly, with DifferentialEquations.jl slightly faster, particularly on the rack server.

The BDF/FBDF methods follow a similar trend, with the Diffsol implementation significantly outperforming both the CasADi and DifferentialEquations.jl implementations across all tolerances and hardware setups.