Submitted by **martenlienen** t3_zfvb8h
in **MachineLearning**

#
**martenlienen**
OP
t1_izejs85 wrote

Reply to comment by **MathChief** in **[R] torchode: A Parallel ODE Solver for PyTorch** by **martenlienen**

In these benchmarks we compare the same Runge-Kutta solver (5th order Dormand-Prince) implemented in all of these libraries. None of these libraries actually propose any new stepping methods. The point is to make ODE solvers available in popular deep learning frameworks to enable deep continuous models such as neural ODEs and continuous normalizing flows. The particular appeal of torchode is its optimized implementation and the fact that it runs multiple independent instances of an ODE solver in parallel when you train on batches, i.e. each instance is solved with its own step size and its own step accept/reject decisions. This avoids a performance pitfall of the usual batching approach, which can lead to many unnecessary solver steps in batched training of models with varying stiffness, as we show in the Van der Pol experiment.
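The batching pitfall can be made concrete with a minimal pure-Python sketch. This is not torchode's implementation. To keep it short it swaps Dormand-Prince for the much simpler embedded Heun/Euler pair. The function names, tolerances and the stiffness values mu = 1 and mu = 200 are illustrative assumptions. `solve_shared` mimics the usual batching approach: one step size for the whole batch, and a step is accepted only if every instance accepts it. `solve_independent` gives each instance its own step size and accept/reject decisions, which is the behavior described above.

```python
import math


def vdp(mu):
    """Van der Pol right-hand side: y1' = y2, y2' = mu*(1 - y1^2)*y2 - y1."""
    def f(t, y):
        y1, y2 = y
        return (y2, mu * (1.0 - y1 * y1) * y2 - y1)
    return f


def heun_euler_step(f, t, y, h, rtol, atol):
    """Embedded Heun (order 2) / Euler (order 1) step.

    Returns the order-2 proposal and the RMS of the scaled error estimate.
    """
    k1 = f(t, y)
    k2 = f(t + h, tuple(yi + h * ki for yi, ki in zip(y, k1)))
    y_new = tuple(yi + 0.5 * h * (a + b) for yi, a, b in zip(y, k1, k2))
    total = 0.0
    for yi, yn, a, b in zip(y, y_new, k1, k2):
        err = 0.5 * h * (b - a)  # Heun result minus Euler result
        r = err / (atol + rtol * max(abs(yi), abs(yn)))
        total += r * r
    return y_new, math.sqrt(total / len(y))


def step_factor(e, safety=0.9, fac_min=0.2, fac_max=5.0):
    """Step-size update for an error estimate of order h^2 (exponent -1/2)."""
    if e == 0.0:
        return fac_max
    return min(fac_max, max(fac_min, safety * e ** -0.5))


def solve_independent(fs, y0s, t_end, rtol=1e-3, atol=1e-6, h0=1e-3):
    """Each instance keeps its own t, h and accept/reject decisions.

    Returns the number of attempted steps for each instance.
    """
    counts = []
    for f, y0 in zip(fs, y0s):
        t, y, h, n = 0.0, tuple(y0), h0, 0
        while t < t_end:
            last = h >= t_end - t
            if last:
                h = t_end - t  # do not step past t_end
            y_new, e = heun_euler_step(f, t, y, h, rtol, atol)
            n += 1
            if e <= 1.0:  # accept this instance's step
                t, y = (t_end if last else t + h), y_new
            h *= step_factor(e)
        counts.append(n)
    return counts


def solve_shared(fs, y0s, t_end, rtol=1e-3, atol=1e-6, h0=1e-3):
    """Naive batching: one t and one h for the whole batch.

    The batch error is the max over instances, so a step is accepted only
    if it is acceptable for every instance. Returns the attempted steps.
    """
    t, ys, h, n = 0.0, [tuple(y0) for y0 in y0s], h0, 0
    while t < t_end:
        last = h >= t_end - t
        if last:
            h = t_end - t
        results = [heun_euler_step(f, t, y, h, rtol, atol)
                   for f, y in zip(fs, ys)]
        e = max(err for _, err in results)
        n += 1
        if e <= 1.0:  # accept for the whole batch at once
            t, ys = (t_end if last else t + h), [yn for yn, _ in results]
        h *= step_factor(e)
    return n


fs = [vdp(1.0), vdp(200.0)]  # one non-stiff and one stiff instance
y0s = [(2.0, 0.0), (2.0, 0.0)]
easy, stiff = solve_independent(fs, y0s, t_end=5.0)
shared = solve_shared(fs, y0s, t_end=5.0)
print(f"independent: {easy} and {stiff} steps; shared: {shared} steps each")
```

With one non-stiff and one stiff instance in the batch, the shared controller should drag the non-stiff instance through roughly as many steps as the stiff one. The independent version lets it finish in far fewer.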

#
**MathChief**
t1_izgdz7x wrote

One more question: for the Van der Pol benchmark

Using the old faithful `ode45` in MATLAB (Runge-Kutta) to run the test problem you listed in the poster:

```matlab
vdp1 = @(t, y) [y(2); 25*(1-y(1)^2)*y(2)-y(1)];
tic;
[t,y] = ode45(vdp1,[0 5],[1.2, 0.1]);
toc;
```

It takes only 0.007329 seconds to march 389 steps using FP64 on a CPU. What is the unit of the loop time? What is the bottleneck of the implemented algorithm?

#
**martenlienen**
OP
t1_izi8k3g wrote

First, the difference in steps is probably due to different tolerances in the step size controller.
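The tolerance explanation can be checked with a small pure-Python Dormand-Prince 5(4) integrator on the same Van der Pol problem (mu = 25, y0 = [1.2, 0.1], t in [0, 5]). It counts accepted and rejected steps at two tolerance settings. The loose setting uses MATLAB's documented `ode45` defaults (RelTol 1e-3, AbsTol 1e-6). The RMS error norm, the initial step `h0` and the controller constants are assumptions of this sketch. They do not reproduce the exact controller of `ode45` or torchode, so the absolute counts will differ from both.

```python
import math

# Standard Dormand-Prince 5(4) tableau.
C = [0.0, 1/5, 3/10, 4/5, 8/9, 1.0, 1.0]
A = [
    [],
    [1/5],
    [3/40, 9/40],
    [44/45, -56/15, 32/9],
    [19372/6561, -25360/2187, 64448/6561, -212/729],
    [9017/3168, -355/33, 46732/5247, 49/176, -5103/18656],
    [35/384, 0.0, 500/1113, 125/192, -2187/6784, 11/84],
]
B5 = [35/384, 0.0, 500/1113, 125/192, -2187/6784, 11/84, 0.0]  # 5th order
B4 = [5179/57600, 0.0, 7571/16695, 393/640, -92097/339200, 187/2100, 1/40]


def dopri5_step(f, t, y, h):
    """Return the 5th-order solution and the (5th - 4th order) error vector."""
    n = len(y)
    ks = []
    for i in range(7):
        yi = [y[j] + h * sum(A[i][s] * ks[s][j] for s in range(i))
              for j in range(n)]
        ks.append(f(t + C[i] * h, yi))
    y5 = [y[j] + h * sum(B5[s] * ks[s][j] for s in range(7)) for j in range(n)]
    err = [h * sum((B5[s] - B4[s]) * ks[s][j] for s in range(7))
           for j in range(n)]
    return y5, err


def count_steps(f, y0, t_end, rtol, atol, h0=1e-4):
    """Integrate on [0, t_end]; return (accepted, rejected) step counts."""
    t, y, h = 0.0, list(y0), h0
    accepted = rejected = 0
    while t < t_end:
        last = h >= t_end - t
        if last:
            h = t_end - t  # do not step past t_end
        y_new, err = dopri5_step(f, t, y, h)
        sq = 0.0
        for yo, yn, e in zip(y, y_new, err):
            r = e / (atol + rtol * max(abs(yo), abs(yn)))
            sq += r * r
        norm = math.sqrt(sq / len(y))
        if norm <= 1.0:
            t, y = (t_end if last else t + h), y_new
            accepted += 1
        else:
            rejected += 1
        # The error estimate is O(h^5), hence the exponent -1/5.
        fac = 5.0 if norm == 0.0 else 0.9 * norm ** -0.2
        h *= min(5.0, max(0.2, fac))
    return accepted, rejected


def vdp25(t, y):
    """The Van der Pol test problem from the MATLAB snippet (mu = 25)."""
    return [y[1], 25.0 * (1.0 - y[0] * y[0]) * y[1] - y[0]]


loose = count_steps(vdp25, [1.2, 0.1], 5.0, rtol=1e-3, atol=1e-6)
tight = count_steps(vdp25, [1.2, 0.1], 5.0, rtol=1e-6, atol=1e-9)
print("rtol=1e-3 (accepted, rejected):", loose)
print("rtol=1e-6 (accepted, rejected):", tight)
```

Tightening the tolerances should raise the step count noticeably. Two solvers that share the same tableau but use different tolerances or error norms can therefore report quite different numbers of steps.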

The loop time is measured in milliseconds. Of course, that is much slower than what you got in MATLAB. The difference is that we ran all benchmarks on a GPU, because that is the usual mode for deep learning. For the VdP equation, a GPU is certainly inappropriate unless your only interest is benchmarking the inner loop of an ODE solver on a GPU. I think you can get numbers similar to your MATLAB code with diffrax and JIT compilation on a CPU. However, you won't get them with torchode, because PyTorch's JIT is not as good as JAX's, and specifically this line is really slow on CPUs. Nonetheless, after comparing several alternatives we chose this approach because, as I said, in practice only GPU performance matters for most of deep learning.

#
**MathChief**
t1_izjayhi wrote

Cool. Thanks for the explanation.

#
**MathChief**
t1_izeuyop wrote

Sounds nice. Thanks for the explanation.

#
**Rodot**
t1_izhk7o5 wrote

Also, people have used neural networks to solve the equations of radiative transfer and have used them in scientific papers.
