patrickkidger t1_j8fdtx2 wrote

Heads-up that my newer jaxtyping project now exists.

Despite the name, it supports both PyTorch and JAX; it is also substantially less hackish than TorchTyping! As such I recommend jaxtyping over TorchTyping regardless of your framework.

(jaxtyping is now widely used internally.)
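For example, the same annotation syntax works on PyTorch tensors. (A minimal sketch I'm writing for illustration -- the function and dimension names are made up, not from jaxtyping itself.)

```python
# Minimal illustration: jaxtyping annotations on PyTorch tensors.
import torch
from jaxtyping import Float


def linear(
    x: Float[torch.Tensor, "batch in_features"],
    w: Float[torch.Tensor, "in_features out_features"],
) -> Float[torch.Tensor, "batch out_features"]:
    # The string annotations document the expected shapes; pairing them with a
    # runtime type checker (e.g. beartype) enforces them at call time.
    return x @ w
```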

3

patrickkidger t1_j8fde35 wrote

On static shape checking: have a look at jaxtyping, which provides shape and dtype annotations for JAX/PyTorch/etc. (checked at runtime via a type checker such as beartype).

(Why "JAX"typing? Because it originally only supported JAX. But it now supports other frameworks too! In particular, I now recommend jaxtyping over my older "TorchTyping" project, which is undesirably hacky.)
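As a rough sketch (illustrative only -- the exact decorator API may differ between jaxtyping versions), runtime shape checking with jaxtyping plus beartype looks something like:

```python
# Hedged sketch: runtime shape/dtype checking with jaxtyping + beartype.
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped


@jaxtyped(typechecker=beartype)
def outer(
    x: Float[Array, "n"],
    y: Float[Array, "m"],
) -> Float[Array, "n m"]:
    # The return annotation ties the output shape to the input dimensions.
    return x[:, None] * y[None, :]


outer(jnp.ones(3), jnp.ones(4))         # OK: (3, 4) result
# outer(jnp.ones(3), jnp.ones((4, 4)))  # fails: second argument is not rank 1
```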

In terms of fitting this kind of stuff into a proper language: that'd be lovely. I completely agree that the extent to which we have retrofitted Python is pretty crazy!

1

patrickkidger t1_ivg6fsg wrote

Reply to comment by xl0 in [P] Lovely Tensors library by xl0

There's no JAX-specific information worth including, I don't think. A JAX array basically holds the same information as a PyTorch tensor: shape, dtype, device, NaNs/infs, and the array of values itself.

The implementation would need to respect how JAX works, though. JAX works by substituting arrays for duck-typed "tracer" objects, passing them into your Python function, recording everything that happens to them, and then compiling the resulting computation graph. (= no Python interpreter during execution, and the possibility of op fusion, which often means improved performance. It's also what makes its function transformations possible, like autodiff, autoparallelism, autobatching, etc.)

This means that you don't usually have the value of the array when you evaluate your Python function -- just some metadata like its shape and dtype. Instead you'd have to create an op that delays doing the printing until runtime, i.e. not doing it during trace time.

...which sounds like a lot, but is probably very easy. Just wrap jax.debug.print.
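Something along these lines (a minimal sketch, not how Lovely Tensors actually does it; `lovely` is a made-up name):

```python
# Minimal sketch: defer printing of array statistics to runtime with
# jax.debug.print, so it works inside jit/traced functions.
import jax
import jax.numpy as jnp


def lovely(x, name="array"):
    # Shape and dtype are static, so they are known at trace time; the mean and
    # NaN count are runtime values, formatted when the compiled code executes.
    jax.debug.print(
        f"{name}: shape={x.shape} dtype={x.dtype} mean={{mean}} nans={{nans}}",
        mean=jnp.mean(x),
        nans=jnp.isnan(x).sum(),
    )
    return x


@jax.jit
def f(x):
    x = lovely(x, name="input")
    return jnp.tanh(x)


f(jnp.arange(6.0).reshape(2, 3))
```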

1

patrickkidger t1_ivb3t7e wrote

See the conclusion of my thesis (linked above ;) )

TL;DR: everything to do with neural PDEs, stable training of neural SDEs, applications of neural ODEs to ~all of science~, adaptive/implicit/rough numerical SDE solvers (although that one's very specialised), current work connecting NDEs with state space models (S4D, MEGA, etc.), ... etc. etc.!

1

patrickkidger t1_iv5vb05 wrote

Neural differential equations! The continuous-time limit of a lot of deep learning models can be thought of as a differential equation with a neural network as its vector field.

A survey is On Neural Differential Equations.

Also +1 for /u/betelgeuse3e08's recommendations, which are primarily neural ODEs encoding particular kinds of physical structure; cf. Section 2.2.2 of the above.

You can find a lot of code examples of neural ODEs/SDEs/etc. in JAX in the Diffrax documentation.
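For a flavour, a neural ODE in Diffrax + Equinox looks roughly like the following. (A stripped-down sketch with my own made-up class/parameter names, not one of the documented examples -- see the Diffrax docs for complete, trainable versions.)

```python
# Rough sketch of a neural ODE: an MLP as the vector field, integrated with
# Diffrax. See the Diffrax documentation for complete, trainable examples.
import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp


class NeuralODE(eqx.Module):
    mlp: eqx.nn.MLP  # the learned vector field f_theta(y)

    def __init__(self, data_size, width, depth, *, key):
        self.mlp = eqx.nn.MLP(data_size, data_size, width, depth, key=key)

    def __call__(self, y0, t0=0.0, t1=1.0):
        term = diffrax.ODETerm(lambda t, y, args: self.mlp(y))
        sol = diffrax.diffeqsolve(term, diffrax.Tsit5(), t0, t1, dt0=0.1, y0=y0)
        return sol.ys[-1]  # solution at time t1


model = NeuralODE(data_size=2, width=64, depth=2, key=jax.random.PRNGKey(0))
y1 = model(jnp.array([1.0, 0.0]))
```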

This topic is kind of my thing :) DM me if you end up going down this route, I can try to point you at the open problems.

14