patrickkidger t1_ivb4la8 wrote

This is really nice!
...I might shamelessly steal this idea for my JAX work. :D

6

xl0 OP t1_ivd0rca wrote

Haha, thank you! You are not the first person to mention JAX, so I guess I'll do a JAX version next. :)

I have a rough idea of what JAX is, and as I understand it, it's mostly about transforming functions. Do you have ideas about anything JAX-specific that should be included in the ndarray summary?

1

patrickkidger t1_ivg6fsg wrote

There's no JAX-specific information worth including, I don't think. A JAX array basically holds the same information as a PyTorch tensor, i.e. shape/dtype/device/nans/infs/an array of values.

The implementation would need to respect how JAX works, though. JAX works by replacing your arrays with duck-typed "tracer" objects, passing them into your Python function, recording everything that happens to them, and then compiling the resulting computation graph. (= no Python interpreter during execution, and the possibility of op fusion, which often means improved performance. It's also what makes its function transformations possible, like autodiff, autoparallelism, autobatching, etc.)

This means that you don't usually have the value of the array when your Python function is evaluated -- just some metadata like its shape and dtype. Instead you'd have to create an op that delays the printing until runtime, i.e. not doing it at trace time.
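To make the tracing behaviour concrete, here's a minimal sketch (not from lovely-jax) showing that inside a jitted function a plain print only ever sees a tracer with shape/dtype, never the concrete values:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # At trace time x is a tracer: shape and dtype are known,
    # but the concrete values are not available yet.
    print(x)  # e.g. Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace...>
    return x * 2

f(jnp.arange(3.0))
```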

...which sounds like a lot, but is probably very easy: just wrap jax.debug.print.
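A rough sketch of that idea (debug_summary is a made-up name, not lovely-jax's actual API): shape and dtype are static, so they can go straight into the format string at trace time, while the value-dependent stats are deferred to runtime via jax.debug.print:

```python
import jax
import jax.numpy as jnp

def debug_summary(x, name="array"):
    # Hypothetical helper: shape/dtype are known at trace time, so they are
    # baked into the format string; the value-dependent stats are computed
    # and printed at runtime by jax.debug.print, so this also works under jit.
    fmt = f"{name}: shape={x.shape} dtype={x.dtype} mean={{mean}} nans={{nans}}"
    jax.debug.print(fmt, mean=jnp.mean(x), nans=jnp.isnan(x).sum())

@jax.jit
def f(x):
    debug_summary(x, "x")
    return x * 2

f(jnp.arange(3.0))
```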

1

xl0 OP t1_ivj5slm wrote

I've started working on it. I'll make sure repr works inside jit and parallel execution before moving on to other things.

https://github.com/xl0/lovely-jax

Please let me know if you have any thoughts, I'm very new to JAX.

1