Lotka-Volterra ODE Benchmark
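We benchmark ODE solvers on the Lotka-Volterra predator-prey model. This is a non-stiff system of equations defined as:

\[
\begin{aligned}
\frac{dx}{dt} &= a x - b x y, \\
\frac{dy}{dt} &= -c y + d x y,
\end{aligned}
\]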
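where \(x\) is the number of prey, \(y\) is the number of predators, and \(a\), \(b\), \(c\), and \(d\) are positive real parameters that describe the interaction between the two species. In all of the implementations below, \(a = 2/3\), \(b = 4/3\), \(c = d = 1\), the initial condition is \(x = y = 1\), and the system is integrated from \(t = 0\) to \(t = 10\).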
We will benchmark the following solvers against different tolerances:
Diffsol’s BDF & TSIT45 methods
CasADi’s CVODE solver
Diffrax’s Tsit5 solver
Julia’s DifferentialEquations.jl FBDF & Tsit5 methods
The following solvers are similar, so they should be compared against each other:
Diffsol BDF, CasADi CVODE & DifferentialEquations.jl FBDF methods.
Diffsol’s TSIT45, DifferentialEquations.jl’s Tsit5 and Diffrax’s Tsit5 methods, which are all explicit Runge-Kutta methods with identical tableaux.
Benchmark Setup
For each solver, we perform as much initial setup as possible outside the timing loop, in a function called `setup`, to ensure a fair comparison. The benchmark itself is performed in a function called `bench`.
The code for the Diffsol solvers is shown below:
def lokta_volterra_ode_str():
    """Return the DiffSL model source for the Lotka-Volterra system.

    The model uses a = 2/3, b = 4/3, c = d = 1 and starts from x = y = 1.

    Returns:
        A ``(code, t_final)`` tuple: the DiffSL source string and the final
        integration time (10.0).
    """
    final_time = 10.0
    diffsl_source = """
a { 2.0 / 3.0 }
b { 4.0 / 3.0 }
c { 1.0 }
d { 1.0 }
u_i {
x = 1.0,
y = 1.0,
}
F_i {
a * x - b * x * y,
-c * y + d * x * y,
}
"""
    return diffsl_source, final_time
import pydiffsol as ds
import numpy as np
from diffsol_lotka_volterra import lokta_volterra_ode_str
from diffsol_robertson import robertson_ode_str
def setup(ngroups: int, tol: float, method: str, problem: str):
    """Create a configured pydiffsol ``Ode`` for one benchmark case.

    Args:
        ngroups: Problem-size parameter. Values below 20 select the dense
            ``nalgebra_dense`` matrix type and larger values select
            ``faer_sparse``. It is also forwarded to the Robertson builder.
        tol: Used as both the relative and the absolute tolerance.
        method: One of ``"bdf"``, ``"esdirk34"``, ``"tr_bdf2"`` or
            ``"tsit5"`` (the last maps to diffsol's ``tsit45``).
        problem: ``"robertson_ode"`` or ``"lotka_volterra_ode"``.

    Returns:
        ``(ode, t_final)``, the tuple consumed by ``bench``.

    Raises:
        ValueError: If ``method`` or ``problem`` is not recognised.
    """
    matrix_type = ds.nalgebra_dense if ngroups < 20 else ds.faer_sparse

    # Benchmark method name -> pydiffsol attribute name. The attribute is
    # looked up only for the selected method.
    solver_attr_names = {
        "bdf": "bdf",
        "esdirk34": "esdirk34",
        "tr_bdf2": "tr_bdf2",
        "tsit5": "tsit45",
    }
    if method not in solver_attr_names:
        raise ValueError(f"Unknown method: {method}")
    ode_solver = getattr(ds, solver_attr_names[method])

    if problem == "lotka_volterra_ode":
        code, t_final = lokta_volterra_ode_str()
    elif problem == "robertson_ode":
        code, t_final = robertson_ode_str(ngroups=ngroups)
    else:
        raise ValueError(f"Unknown problem: {problem}")

    ode = ds.Ode(
        code,
        matrix_type=matrix_type,
        scalar_type=ds.f64,
        ode_solver=ode_solver,
    )
    ode.rtol = tol
    ode.atol = tol
    return ode, t_final
def bench(model):
    """Solve the configured ODE up to ``t_final`` and return the final state.

    Args:
        model: The ``(ode, t_final)`` tuple returned by ``setup``.

    Returns:
        The last column of the dense solution's ``ys``. This presumably has
        layout (n_states, n_times) - TODO confirm against pydiffsol docs.
    """
    ode, t_final = model
    # The DiffSL models used here declare no input parameters.
    no_params = np.array([])
    eval_times = np.array([t_final])
    solution = ode.solve_dense(no_params, eval_times)
    return solution.ys[:, -1]
The code for the CasADi solver is shown below:
import casadi
import numpy as np
def setup_lokta_volterra_ode():
    """Build the CasADi ODE description of the Lotka-Volterra system.

    Returns:
        ``(ode, t_final, x0)``, where
        - ``ode`` is a dict with the state vector under ``"x"`` and the
          right-hand side under ``"ode"``, as passed to
          ``casadi.integrator``;
        - ``t_final`` is 10.0;
        - ``x0`` is the initial state ``[1.0, 1.0]``.
    """
    prey = casadi.MX.sym("x")
    predator = casadi.MX.sym("y")
    a, b, c, d = 2.0 / 3.0, 4.0 / 3.0, 1.0, 1.0

    # Right-hand side expressions for d(prey)/dt and d(predator)/dt.
    prey_rate = a * prey - b * prey * predator
    predator_rate = -c * predator + d * prey * predator

    ode = {
        "x": casadi.vertcat(prey, predator),
        "ode": casadi.vertcat(prey_rate, predator_rate),
    }
    return (ode, 10.0, np.ones(2))
import casadi
import numpy as np
from casadi_lotka_volterra import setup_lokta_volterra_ode
from casadi_robertson import setup_robertson_ode
def setup(ngroups: int, tol: float, problem: str):
    """Create a CasADi CVODES integrator for one benchmark case.

    Args:
        ngroups: Forwarded to the Robertson problem builder. It is unused
            for Lotka-Volterra.
        tol: Used as both ``abstol`` and ``reltol``.
        problem: ``"robertson_ode"`` or ``"lotka_volterra_ode"``.

    Returns:
        ``(F, x0)``: the integrator over ``[0, t_final]`` and the initial state.

    Raises:
        ValueError: If ``problem`` is not recognised.
    """
    builders = {
        "robertson_ode": lambda: setup_robertson_ode(ngroups),
        "lotka_volterra_ode": setup_lokta_volterra_ode,
    }
    if problem not in builders:
        raise ValueError(f"Unknown problem: {problem}")
    ode, t_final, x0 = builders[problem]()

    solver_options = {"abstol": tol, "reltol": tol}
    integrator = casadi.integrator("F", "cvodes", ode, 0.0, t_final, solver_options)
    return integrator, x0
def bench(model) -> np.ndarray:
    """Run the integrator once and return the state at the final time.

    Args:
        model: The ``(F, x0)`` tuple returned by ``setup``.

    Returns:
        The last column of the integrator's ``"xf"`` output.
        NOTE(review): CasADi integrators presumably return a ``DM`` here
        rather than an ``np.ndarray`` - confirm the annotation.
    """
    integrator, initial_state = model
    outputs = integrator(x0=initial_state)
    final_states = outputs["xf"]
    return final_states[:, -1]
The code for the Diffrax solver is shown below:
import equinox as eqx # https://github.com/patrick-kidger/equinox
import jax.numpy as jnp
class LotkaVolterra(eqx.Module):
    """Lotka-Volterra vector field as an Equinox module for ``diffrax.ODETerm``.

    Uses a = 2/3, b = 4/3, c = d = 1. ``ngroups`` is stored as a field (to
    mirror the other problem modules) but is not read by ``__call__``.
    """

    ngroups: int

    def __call__(self, t, y, args):
        """Return ``[dx/dt, dy/dt]`` as a flat array.

        ``y`` is indexed as ``[prey, predator]``. ``t`` and ``args`` belong
        to the diffrax vector-field signature and are unused.
        """
        a, b, c, d = 2.0 / 3.0, 4.0 / 3.0, 1.0, 1.0
        prey = y[0]
        predator = y[1]
        prey_rate = a * prey - b * prey * predator
        predator_rate = -c * predator + d * prey * predator
        return jnp.ravel(jnp.vstack([prey_rate, predator_rate]))
from functools import partial
import diffrax
import jax
import jax.numpy as jnp
from diffrax_lotka_volterra import LotkaVolterra
from diffrax_robertson import RobertsonOde
# Enable 64-bit precision in JAX, which is required for solving problems
# with tolerances of 1e-8
# (see https://docs.kidger.site/diffrax/examples/stiff_ode/).
# This runs at import time, i.e. before `setup`/`bench` below can be
# called to build any arrays.
jax.config.update("jax_enable_x64", True)
def setup(ngroups: int, tol: float, method: str, problem: str):
    """Assemble the static configuration tuple consumed by ``bench``.

    Args:
        ngroups: Size parameter for the Robertson problem. Lotka-Volterra
            always uses ``ngroups=1``.
        tol: Used as both ``rtol`` and ``atol`` of the PID controller.
        method: ``"kvaerno5"`` or ``"tsit5"``.
        problem: ``"robertson_ode"`` or ``"lotka_volterra_ode"``.

    Returns:
        ``(vector_field, tol, t_final, solver, wrapped_y0)``. The initial
        state is wrapped in ``HashableArrayWrapper`` so the whole tuple can
        be passed as a static ``jax.jit`` argument.

    Raises:
        ValueError: If ``problem`` or ``method`` is not recognised.
    """
    if problem == "lotka_volterra_ode":
        y0 = jnp.ones(2)
        t_final = 10.0
        vector_field = LotkaVolterra(ngroups=1)
    elif problem == "robertson_ode":
        t_final = 1e10
        y0 = jnp.concatenate([jnp.ones(ngroups), jnp.zeros(2 * ngroups)])
        vector_field = RobertsonOde(ngroups=ngroups)
    else:
        raise ValueError(f"Unknown problem: {problem}")

    solver_classes = {"kvaerno5": diffrax.Kvaerno5, "tsit5": diffrax.Tsit5}
    if method not in solver_classes:
        raise ValueError(f"Unknown method: {method}")
    solver = solver_classes[method]()

    return (vector_field, tol, t_final, solver, HashableArrayWrapper(y0))
# https://github.com/jax-ml/jax/issues/4572#issuecomment-709809897
def some_hash_function(x):
    """Return a cheap content hash of array ``x`` based on its element sum.

    Arrays that compare equal element-wise have equal sums, so they hash
    equally. This includes -0.0 vs 0.0, because ``hash(-0.0) == hash(0.0)``.
    Hashing the float sum directly has two advantages over truncating it
    with ``int()``:
    - a NaN or infinite sum no longer raises (``int(nan)`` raises
      ValueError, ``int(inf)`` raises OverflowError);
    - sums such as 1.2 and 1.7 no longer collide.

    Returns:
        A Python ``int``.
    """
    return hash(float(jnp.sum(x)))
class HashableArrayWrapper:
    """Wrap an array so it can be part of a ``jax.jit`` static argument.

    Static arguments must be hashable and comparable. Two wrappers compare
    equal only when their arrays have the same dtype, the same shape and
    equal elements. The hash combines the shape with ``some_hash_function``,
    so equal wrappers always hash equally.
    """

    def __init__(self, val):
        # The wrapped array; the jitted ``bench`` reads it back via ``.val``.
        self.val = val

    def __hash__(self):
        # Including the shape separates arrays whose element sums coincide.
        return hash((jnp.shape(self.val), some_hash_function(self.val)))

    def __eq__(self, other):
        if not isinstance(other, HashableArrayWrapper):
            return False
        mine = jnp.asarray(self.val)
        theirs = jnp.asarray(other.val)
        # A static-argument cache hit reuses the trace built for the other
        # array, so arrays with equal values but different dtypes must not
        # compare equal.
        if mine.dtype != theirs.dtype:
            return False
        # array_equal returns False for mismatched shapes. Broadcasting
        # jnp.equal would instead raise or wrongly report equality.
        # bool() gives Python a plain truth value rather than a JAX array.
        return bool(jnp.array_equal(mine, theirs))
@partial(jax.jit, static_argnames=["model"])
def bench(model) -> jnp.ndarray:
    """Integrate from t=0 to ``t_final`` with Diffrax and return the final state.

    Args:
        model: The static tuple ``(vector_field, tol, t_final, solver,
            wrapped_y0)`` built by ``setup``. The whole tuple is a static
            jit argument, so each distinct configuration is traced once.

    Returns:
        The last entry of ``sol.ys``.
    """
    vector_field, tol, t_final, solver, wrapped_y0 = model
    term = diffrax.ODETerm(vector_field)
    controller = diffrax.PIDController(rtol=tol, atol=tol)
    solution = diffrax.diffeqsolve(
        term,
        solver,
        0.0,
        t_final,
        None,  # dt0=None: the step-size controller chooses the first step.
        wrapped_y0.val,
        stepsize_controller=controller,
    )
    return solution.ys[-1]
The code for the DifferentialEquations.jl solvers is shown below:
import DifferentialEquations as DE
import ModelingToolkit as MTK
# Build the Lotka-Volterra ODEProblem (a = 2/3, b = 4/3, c = d = 1,
# u0 = [1, 1], t in [0, 10]) and return it together with its time span.
# `ngroups` is accepted so the signature matches `setup_robertson_ode(ngroups)`,
# but it is not used here.
function setup_lotka_volterra_ode(ngroups)
    # In-place right-hand side: writes [dx/dt, dy/dt] into `du`.
    function lotka_volterra!(du, u, p, t)
        a, b, c, d = p
        prey, predator = u[1], u[2]
        du[1] = a * prey - b * prey * predator
        du[2] = -c * predator + d * prey * predator
        return nothing
    end
    params = [2.0 / 3.0, 4.0 / 3.0, 1.0, 1.0]
    u0 = ones(2)
    tspan = (0.0, 10.0)
    return DE.ODEProblem(lotka_volterra!, u0, tspan, params), tspan
end
import DifferentialEquations as DE
import ModelingToolkit as MTK
using OrdinaryDiffEqBDF: FBDF
using OrdinaryDiffEqSDIRK: KenCarp3, TRBDF2
using OrdinaryDiffEqTsit5: Tsit5
include("diffeq_robertson.jl")
include("diffeq_lokta_volterra.jl")
# Build the `(prob, alg, tol, tspan)` tuple consumed by `bench`.
# The raw problem is passed through ModelingToolkit so that the final
# ODEProblem is built with an analytic Jacobian (`jac = true`). That Jacobian
# is sparse when `ngroups >= 20`.
# Errors on an unknown `problem` or `method`.
function setup(ngroups, tol, method, problem)
    raw_prob, tspan = if problem == "lotka_volterra_ode"
        setup_lotka_volterra_ode(ngroups)
    elseif problem == "robertson_ode"
        setup_robertson_ode(ngroups)
    else
        error("Unknown problem: $problem")
    end
    @MTK.mtkcompile sys = MTK.modelingtoolkitize(raw_prob)
    prob = DE.ODEProblem(sys, [], tspan; jac = true, sparse = ngroups >= 20)

    # Benchmark method name => algorithm constructor (called only for the
    # selected method).
    alg_constructors = Dict(
        "bdf" => FBDF,
        "kencarp3" => KenCarp3,
        "tr_bdf2" => TRBDF2,
        "tsit5" => Tsit5,
    )
    haskey(alg_constructors, method) || error("Unknown method: $method")
    alg = alg_constructors[method]()

    return (prob, alg, tol, tspan)
end
# Solve the configured problem once and return the state at the final time,
# matching what the Python `bench` implementations return.
function bench(model)
    # `model` is the `(prob, alg, tol, tspan)` tuple returned by `setup`.
    (prob, alg, tol, tspan) = model
    sol = DE.solve(prob, alg = alg, reltol = tol, abstol = tol, saveat = tspan[2])
    # `sol.u` is a Vector of state vectors, one per saved time point. With a
    # scalar `saveat` it may also hold the initial state, so return the last
    # entry (the state at `tspan[2]`). Indexing it as `sol.u[:, end]` would
    # treat the Vector as an n×1 matrix and return every saved state instead.
    return sol.u[end]
end
Differences between implementations
There are a few key differences between the Diffrax, CasADi, Diffsol and DifferentialEquations.jl implementations that may affect solver performance. The main differences are:
JIT compilation & function evaluation overhead: The Diffsol implementation uses the DiffSL JIT compiler to compile the model equations to optimised native code. Diffrax uses JAX to JIT compile the model equations and solvers to XLA-optimised code. Julia is a JIT-compiled language, so the DifferentialEquations.jl implementation is also JIT compiled to native code using LLVM. CasADi is the only implementation that does not use JIT compilation; instead, it builds a computational graph of operations to evaluate the model equations, which adds overhead to each function evaluation.
Results
The benchmarks were run on:
A Dell PowerEdge R7525 2U rack server with dual AMD EPYC 7343 3.2 GHz 16-core CPUs and 128 GB of memory
A MacBook Pro (14-inch, 2023) with an M2 Pro chip, 16 GB of memory and 12 cores (8 performance and 4 efficiency)
The results are shown below:
The Diffsol implementation of the TSIT45 method performs best across all tolerances and hardware setups. The DifferentialEquations.jl and Diffrax implementations perform similarly, with DifferentialEquations.jl performing slightly better, particularly on the rack server.
The BDF/FBDF methods follow a similar trend, with the Diffsol implementation significantly outperforming both the CasADi and DifferentialEquations.jl implementations across all tolerances and hardware setups.