Submitted by ChrisRackauckas t3_y74w8j in MachineLearning

Sharing our paper accepted at NeurIPS: Automatic Differentiation of Programs with Discrete Randomness

A summary is given in a Twitter thread, which gives a high-level overview of how the method works. The core idea behind the paper is the following: if we treat a program as a random variable X(p), can we come up with an extended number type (in the spirit of the dual numbers of forward-mode AD) such that we get a pair of random variables (X(p), Y(p)) with E[Y(p)] = dE[X(p)]/dp? For instance, if X(p) is a Bernoulli(p) draw, then E[X(p)] = p, so we want E[Y(p)] = 1, even though any single realization of X(p) is piecewise constant in p and has zero derivative almost everywhere. We give a detailed derivation of how to do this properly.

The result is an unbiased, low-variance, and fully automatic method for differentiating such programs.

A fairly optimized implementation of the method is available as an open-source package:

https://github.com/gaurav-arya/StochasticAD.jl
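
For a sense of the workflow, here is a minimal usage sketch along the lines of the package README, assuming the `derivative_estimate` API documented there:

```julia
using StochasticAD, Distributions, Statistics

# A toy "program": the output is an integer draw, and a single realization
# is piecewise constant in p, so pathwise AD through one run would give zero.
X(p) = rand(Binomial(10, p))

# derivative_estimate produces one sample of Y(p); averaging many samples
# gives a Monte Carlo estimate of dE[X(p)]/dp.
samples = [derivative_estimate(X, 0.5) for _ in 1:10_000]
mean(samples)  # ≈ 10, since E[X(p)] = 10p
```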

There are still many things to do in this area. For example, building on this it should be easy to train neural networks that generate differential equations for the mean behavior over time, matching the statistics of an agent-based model, but we can't say we've tried all of the applications. There's also a lot more to do in terms of compiler optimizations for this new AD system.

But, even if you don't use it, the idea is fun and cool and you should check it out!

81

Comments


EmmyNoetherRing t1_isswkjx wrote

This seems really neat! Just to make sure I understand correctly, do you have a link to an example of what you’re thinking about when you say “agent based model”? I tend to associate that with classical AI where there’s often a discrete and generally finite set of states and transitions, and the agent model is a policy on which transitions to use in which states.

4

ChrisRackauckas OP t1_issyai3 wrote

We mean the standard "agent-based model": https://www.pnas.org/doi/10.1073/pnas.082080899, https://en.wikipedia.org/wiki/Agent-based_model . The kind of thing you'd use Agents.jl for. For example, look at agent-based infection models. In these kinds of models you create many individuals (agents) with rules. Each agent moves around, and if one is standing near an agent that is infected, there's a probability of infecting the nearby agent. What is the average percentage of infected people at time t?
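
To make that concrete, here is a toy sketch of such a model (my own illustration, not from the paper or the Agents.jl tutorials): agents wander on a line, and an infected agent infects a nearby susceptible one with probability p. Note that this only illustrates the model class; it is not written in the form the package would need to differentiate it.

```julia
using Distributions

# Toy agent-based infection model: each agent takes a random step, and an
# infected agent infects a susceptible neighbor with probability p.
function fraction_infected(p; n_agents = 100, n_steps = 50, radius = 0.5)
    positions = 10 .* rand(n_agents)           # agents scattered on [0, 10]
    infected = falses(n_agents)
    infected[1] = true                         # one initial infection
    for _ in 1:n_steps
        positions .+= 0.1 .* randn(n_agents)   # random movement
        for i in 1:n_agents, j in 1:n_agents
            if infected[i] && !infected[j] && abs(positions[i] - positions[j]) < radius
                infected[j] = rand(Bernoulli(p))   # discrete, parameter-dependent event
            end
        end
    end
    return count(infected) / n_agents          # fraction infected at the final time
end

fraction_infected(0.05)   # one stochastic realization of the quantity of interest
```

The quantity the method targets would then be d E[fraction_infected(p)] / dp.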

2

PolygonAndPixel2 t1_ist25ao wrote

That sounds interesting. I didn’t get to read it completely yet but I have a couple of questions:

- You say "a mathematical program". What do you mean by that, or rather, what is a program that is not mathematical? Any computer program is just a composition of basic functions which can be differentiated. Throw in the chain rule and you can use AD for any program, where the gradients are exact for the execution path.

- If I understand that right, then p is a random variable (or rather the probability of a random variable taking a given value) that changes the outcome in a discontinuous way. Is it correct to say that the execution path of the program changes with different outcomes for p, e.g., if(random_event(p)) {return 1;} else {return 0;}? Or is this a different problem?

- I haven't taken a look at your code (and Julia isn't my first language), but can you estimate how much work it would be to incorporate this kind of AD into existing AD tools like CoDiPack?

1

EmmyNoetherRing t1_ist2jei wrote

Thanks! Sounds very related, but from a different angle than the one I've worked with before (which is closer to multi-agent systems, by the Wikipedia article's nomenclature). Still discrete underneath though, if I'm reading it correctly? Or I guess looking at real-valued parameters for the agents' programs/states/locations?

1

ChrisRackauckas OP t1_ist3bu8 wrote

> What do you mean by that or rather what is a program that is not mathematical?

If it outputs strings or code it may not work with this method. It should output numbers in a way that has a well-defined (differentiable) expectation.
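As a trivial illustration of that distinction (my own example, not from the paper): the first program below has a numeric output whose expectation, E[ok(p)] = p, is a differentiable function of p; the second outputs strings, so there is no expectation to differentiate.

```julia
using Distributions

ok(p)  = rand(Bernoulli(p)) ? 1.0 : 0.0          # numeric output, E[ok(p)] = p
bad(p) = rand(Bernoulli(p)) ? "heads" : "tails"  # no numeric expectation to differentiate
```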

> Is it correct to say that the execution path of the program changes with different outcomes for p, i.e., if(random_event(p)) {return 1;} else {return 0;}? Or is this a different problem?

It can. One of the examples is differentiation of an inhomogeneous random walk, which is a stress test of doing this kind of branch handling.
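For a feel of that, here is a toy inhomogeneous walk in that spirit (my own sketch, not the paper's exact benchmark), where each discrete outcome changes the success probability of the later draws, so the realized path depends on earlier random events. It assumes the `derivative_estimate` API from the package.

```julia
using StochasticAD, Distributions, Statistics

# Toy inhomogeneous walk: the success probability shrinks as the walker advances,
# so later draws depend on earlier discrete outcomes.
function walk(p; n = 50)
    x = 0
    for _ in 1:n
        x += rand(Bernoulli(p * (1 - x / n)))   # probability depends on the current state
    end
    return x
end

# A single realization is piecewise constant in p (zero pathwise gradient);
# derivative_estimate targets d E[walk(p)] / dp instead.
mean(derivative_estimate(walk, 0.5) for _ in 1:10_000)
```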

> but can you estimate how much work it is to incorporate this kind of AD in existing AD tools like CoDiPack?

That's hard to say. I would say it wouldn't be "too hard", though it may be hard to add without introducing overhead in the "normal" cases. It would make the code more complicated, but the standard cases are just a special case here, so it should be fine.

1

schwagggg t1_iswqh92 wrote

cool stuff!

2 things:

  1. I am still trying to wrap my head around how to do this: say we have a 2-layer NN with Bernoulli neurons, how do you take the derivative with respect to the first layer's weights in this case?

  2. It seems to me that this approach needs many function evaluations. Does it scale well with the number of stochastic variables? If I use it for a VAE with an expensive decoder and, say, 1024 stochastic latents, would it be bad?

1

ChrisRackauckas OP t1_iswr0wc wrote

(1) While running your primal program, you run another program that propagates infinitesimal probabilities of certain pieces changing, and it chooses the flips in the right proportion (as derived in the paper) to give two correlated but different runs whose difference gives Y(p). This Y(p) is defined to have the property that E[Y(p)] = dE[X(p)]/dp with low variance, so you do this a few times, average, and that is your gradient estimate.

(2) Unlike previous algorithms with known exponential cost scaling (for example, see https://openreview.net/forum?id=KAFyFabsK88 for a deep discussion of previous work's performance), this scales linearly, so 1024 should be fine. Note that this is related to forward-mode AD, so "really big" needs more work, but that size is fine.
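
To see what gets propagated alongside the primal run, here is a hedged sketch using the package's `stochastic_triple` function; this is the "infinitesimal probability of a piece changing" mentioned in (1).

```julia
using StochasticAD, Distributions

X(p) = rand(Binomial(10, p))

# The propagated object carries the primal value, a continuous (dual) part,
# and an alternative discrete change weighted by an infinitesimal probability.
st = stochastic_triple(X, 0.5)
```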

2

rikkajounin t1_isxk23o wrote

Very interesting work!

What about doing the same with reverse mode AD? Are there some issues in this case?

1

ChrisRackauckas OP t1_isxmv6s wrote

We mention at the end of the paper that reverse mode requires smoothing in a way that works but induces a bias (except in some cases, like the particle filter). This is something we will be looking into more deeply.

2

schwagggg t1_isxz4cw wrote

Then this sounds a bit like measure-valued derivatives? You perturb, then calculate the derivative. But wouldn't this be at least O(D) expensive for one layer, and O(LD) for L layers of D-dimensional RVs?

1