Submitted by martenlienen t3_zfvb8h in MachineLearning


We have developed a new ODE solver suite for PyTorch that eliminates some unintended side effects that can occur in batched training with adaptive step sizes by tracking a separate solver state for each sample in a batch. Additionally, torchode can speed up your neural ODE or continuous normalizing flow by minimizing solver overhead through several implementation optimizations, such as combined operations (einsum, addcmul), polynomial evaluation via Horner's rule, and JIT compilation. See the paper for details.
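Horner's rule, one of the optimizations mentioned above, evaluates a degree-n polynomial with n multiplications and n additions by nesting the terms. A minimal plain-Python sketch, assuming coefficients ordered from the highest-degree term down to the constant (`horner` is a hypothetical helper, not part of torchode's API):

```python
def horner(coeffs, x):
    """Evaluate c[0]*x**n + c[1]*x**(n-1) + ... + c[n] via Horner's rule.

    coeffs are ordered from the highest-degree term to the constant term.
    """
    result = 0.0
    for c in coeffs:
        # One multiply and one add per coefficient: ((c0*x + c1)*x + c2)...
        result = result * x + c
    return result


# 2x^3 - 6x^2 + 2x - 1 at x = 3: 54 - 54 + 6 - 1 = 5
print(horner([2.0, -6.0, 2.0, -1.0], 3.0))  # 5.0
```

In PyTorch, each nested step `result * x + c` could be written as `torch.addcmul(c, result, x)`, which computes `c + result * x` in one fused call; whether torchode combines the two optimizations exactly this way is an assumption here.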

I am happy to answer questions here on reddit. If you are a NeurIPS (+workshops) attendee, it would be great to see you at my poster at the DLDE workshop on Friday at 05:10 PT / 13:10 UTC or 09:05 PT / 18:05 UTC.




MathChief t1_izehk9o wrote

Judging by the Van der Pol benchmark, your methods are faster than other neural ODE solvers but still several orders of magnitude slower than traditional solvers such as Runge-Kutta. Why are many people excited about this?


martenlienen OP t1_izejs85 wrote

In these benchmarks we compare the same Runge-Kutta solver (5th-order Dormand-Prince) implemented in each of these libraries. None of these libraries actually proposes new stepping methods. The point is to make ODE solvers available in popular deep learning frameworks to enable deep continuous models such as neural ODEs and continuous normalizing flows. The particular appeal of torchode is its optimized implementation and the fact that it runs multiple independent instances of an ODE solver in parallel when you train on batches, i.e. each instance is solved with its own step size and its own step accept/reject decisions. This avoids a performance pitfall where the usual batching approach can lead to many unnecessary solver steps in batched training of models with varying stiffness, as we show in the Van der Pol experiment.
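The pitfall can be reproduced on a toy problem. The sketch below is plain Python, not torchode code: it uses an embedded Euler/Heun pair on y' = -k*y with a non-stiff sample (k = 1) and a stiffer sample (k = 50). The joint solver treats the batch as one system and accepts or rejects each step based on the worst sample's error (a max norm, chosen here for simplicity; libraries often use an RMS norm instead), while the independent solver gives each sample its own step size. All function names, tolerances, and controller constants are hypothetical.

```python
def heun_step(k, y, h):
    """One Heun (RK2) step for y' = -k*y with an embedded Euler error estimate."""
    f0 = -k * y
    y_euler = y + h * f0                  # 1st-order solution
    f1 = -k * y_euler
    y_heun = y + 0.5 * h * (f0 + f1)      # 2nd-order solution
    return y_heun, abs(y_heun - y_euler)  # new state, local error estimate


def scaled_error(err, y, rtol, atol):
    return err / (atol + rtol * abs(y))


def new_step_size(h, err_norm):
    # Standard controller with exponent 1/(q+1), q = 1 for the Euler estimate
    factor = 0.9 * (1.0 / max(err_norm, 1e-10)) ** 0.5
    return h * min(5.0, max(0.2, factor))


def solve_joint(ks, t_end=1.0, h=0.01, rtol=1e-3, atol=1e-6):
    """Treat the batch as ONE system: one t, one h, one accept/reject decision."""
    ys = [1.0] * len(ks)
    t, steps = 0.0, 0
    while t_end - t > 1e-12:
        h = min(h, t_end - t)
        results = [heun_step(k, y, h) for k, y in zip(ks, ys)]
        # The batch-wide error is the worst sample's scaled error (max norm)
        err_norm = max(scaled_error(e, y, rtol, atol)
                       for (_, e), y in zip(results, ys))
        if err_norm <= 1.0:
            ys = [y_new for y_new, _ in results]
            t += h
            steps += 1
        h = new_step_size(h, err_norm)
    return ys, steps


def solve_independent(ks, t_end=1.0, h=0.01, rtol=1e-3, atol=1e-6):
    """Solve each sample with its own t, h, and accept/reject decisions."""
    return [solve_joint([k], t_end, h, rtol, atol)[1] for k in ks]


_, joint_steps = solve_joint([1.0, 50.0])
print("joint:", joint_steps, "independent:", solve_independent([1.0, 50.0]))
```

On this toy problem the k = 1 sample should need far fewer steps on its own than in the joint solve, where it is dragged along at the step size the stiffer sample requires. The exact counts depend on the tolerances and controller constants.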


MathChief t1_izgdz7x wrote

One more question about the Van der Pol benchmark:

I used the old faithful ode45 in MATLAB (Runge-Kutta) to run the test problem you listed in the poster:

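% Van der Pol oscillator (mu = 25) written as a first-order system
vdp1 = @(t, y) [y(2); 25*(1-y(1)^2)*y(2)-y(1)];
% Integrate over t in [0, 5] from y0 = (1.2, 0.1) with default tolerances
[t,y] = ode45(vdp1,[0 5],[1.2, 0.1]);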

It takes only 0.007329 seconds to march 389 steps using FP64 on a CPU. What is the unit of your loop time? What is the bottleneck of the implemented algorithm?


martenlienen OP t1_izi8k3g wrote

First, the difference in steps is probably due to different tolerances in the step size controller.
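The tolerances enter through the error norm of the embedded Runge-Kutta pair. As a sketch of the standard textbook controller (the exact norm, safety factor s, and clipping bounds f_min and f_max vary between implementations, so this is not necessarily what torchode or ode45 do), a step from y_0 to y_1 with embedded lower-order solution \hat{y}_1 is accepted if err <= 1, and the next step size is chosen as:

```latex
\mathrm{err} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(\frac{y_{1,i}-\hat{y}_{1,i}}{\mathrm{atol}+\mathrm{rtol}\cdot\max\left(|y_{0,i}|,\,|y_{1,i}|\right)}\right)^{2}},
\qquad
h_{\mathrm{new}} = h\cdot\min\left(f_{\max},\ \max\left(f_{\min},\ s\cdot\mathrm{err}^{-1/(q+1)}\right)\right)
```

Here q = 4 for the 5(4) Dormand-Prince pair. Tighter rtol and atol inflate err, which shrinks h and increases the number of steps. For comparison, MATLAB's ode45 defaults to RelTol = 1e-3 and AbsTol = 1e-6.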

The loop time is measured in milliseconds. Of course, that is much slower than what you got in MATLAB. The difference is that we ran all benchmarks on a GPU, because that is the usual setting in deep learning, even though a GPU is certainly inappropriate for the VdP equation if you are interested in it for anything other than benchmarking the inner loop of an ODE solver on a GPU. I think you can get numbers similar to your MATLAB code with diffrax and JIT compilation on a CPU. However, you won't get them with torchode, because PyTorch's JIT is not as good as JAX's, and specifically this line is really slow on CPUs. Nonetheless, after comparing several alternatives we chose this approach because, as I said, in practice only GPU performance matters for most of deep learning.
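To put the MATLAB timing on the same per-step footing as a loop time in milliseconds, the arithmetic is sketched below. One caveat, which is an inference and not something stated in the thread: ode45's default Refine value of 4 adds interpolated output points, so if 389 is length(t), the solver may have taken only (389 - 1) / 4 = 97 actual steps.

```python
total_ms = 0.007329 * 1000  # wall time reported in the thread, in milliseconds
reported_points = 389       # "389 steps" as reported

# Assumption: if 389 is length(t), ode45's default Refine = 4 yields 4 output
# points per solver step plus the initial point, so steps = (389 - 1) / 4.
refine = 4
inferred_steps = (reported_points - 1) // refine

print(f"{total_ms / reported_points:.4f} ms per reported step")      # 0.0188
print(f"{total_ms / inferred_steps:.4f} ms per inferred solver step")  # 0.0756
```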


Rodot t1_izhk7o5 wrote

Also, people have used neural networks to solve the equations of radiative transfer and have used them in scientific papers.


cheecheepong t1_izfuznn wrote

This is cool but couldn't they have picked a better name 😂.

Cannot unsee: "torChode" sounds like some wordplay on the Tor project.


fhchl t1_izi5ecp wrote

Nice work! Though diffrax is mentioned, it would be interesting to see a direct comparison between diffrax and torchode. Can you give some more details on how they differ in features and performance, apart from the library in which they are implemented?


martenlienen OP t1_izi7n9z wrote

Diffrax is an excellent project and a superset of torchode. torchode solves only ODEs, while diffrax combines ODEs, CDEs and SDEs (maybe more?) in the same framework. We created torchode because we wanted to bring diffrax's structure and flexibility into the PyTorch ecosystem, which is still more popular than JAX. We also set out to create an optimized implementation, and as far as I am concerned we succeeded. Even though tracking multiple ODE solvers at once is inherently more complex than solving a batch of ODEs jointly, torchode is as fast as or faster than the other PyTorch ODE solvers in our experiments.


fhchl t1_izif55h wrote

Is this feature of torchode of solving multiple ODEs at once over some batch dimension comparable to jax.vmapping over that dimension in diffrax?


martenlienen OP t1_iziq1xp wrote

Yes, it is the same thing. Unfortunately, functorch is not advanced enough yet to just translate diffrax to PyTorch directly. Instead, we had to take care of batching everywhere explicitly to decide how long to loop etc.
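Here is a plain-Python sketch of what taking care of batching explicitly can look like. It assumes explicit Euler with a fixed step size per sample instead of an adaptive Dormand-Prince solver and uses lists where a real implementation would use tensors; `batched_euler` is a hypothetical name, not torchode's API. The key point is the loop condition: the batch keeps iterating while any sample is unfinished, and finished samples are masked out.

```python
def batched_euler(f, y0, t_end, dt):
    """Advance a batch of scalar ODEs y' = f(t, y) with per-sample state.

    y0, t_end and dt are lists with one entry per sample. Each sample keeps its
    own time and step size, and the loop runs until every sample is done.
    """
    n = len(y0)
    t = [0.0] * n
    y = list(y0)
    steps = [0] * n
    iterations = 0
    while True:
        # Per-sample "still running" mask (a boolean tensor in a batched solver)
        active = [t_end[i] - t[i] > 1e-12 for i in range(n)]
        if not any(active):
            break  # loop length is decided by the slowest sample
        iterations += 1
        for i in range(n):  # in PyTorch this would be one vectorized, masked update
            if active[i]:
                h = min(dt[i], t_end[i] - t[i])  # do not step past this sample's end
                y[i] = y[i] + h * f(t[i], y[i])
                t[i] += h
                steps[i] += 1
    return y, steps, iterations


y, steps, iterations = batched_euler(lambda t, y: -y, [1.0, 1.0], [1.0, 2.0], [0.1, 0.5])
print(steps, iterations)  # [10, 4] 10
```

Here the batch runs for 10 iterations, the maximum over samples, while the second sample stops updating after 4. As far as I know, jax.vmap treats a batched while loop in a similar way, running until every element's condition is false and freezing finished elements, which is the bookkeeping that has to be written by hand here.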


fhchl t1_izjb0tf wrote

Aight! Thanks for the nice answers! Have a good conference :)


dopadelic t1_izh5hi6 wrote

It's been years since I've implemented ODE solvers in my engineering courses, but I recall that ODEs are inherently not parallelizable, since each subsequent time step requires the current one to be solved first. I've only worked with Euler's and Newton's methods, though.


martenlienen OP t1_izi71ti wrote

You are correct, except for parallel-in-time integration methods, which we also mention in the paper. But the "parallel" in the title refers to solving multiple ODEs independently in parallel, which is contrary to what is currently done in ML. At the moment, training on a batch of ODEs means that you treat the batch as one large ODE that is solved jointly. torchode solves them independently of each other but still in parallel by tracking a separate current state, step size, etc. for each sample.


fhchl t1_izjd2vy wrote

+1 for the paper on parallel in time! Are there any implementations of those algorithms for torchode or diffrax out there?


martenlienen OP t1_izk7bk3 wrote

I am not aware of any, but I would be very interested if you find one.