
currentscurrents OP t1_j2hdsvv wrote

Someone else posted this example, which is kind of what I was interested in. They trained a neural network to do a toy problem, addition mod 113, and then were able to determine the algorithm it used to compute it.

>The algorithm learned to do modular addition can be fully reverse engineered. The algorithm is roughly:

>Map inputs x,y→ cos(wx),cos(wy),sin(wx),sin(wy) with a Discrete Fourier Transform, for some frequency w.

>Multiply and rearrange to get cos(w(x+y))=cos(wx)cos(wy)−sin(wx)sin(wy) and sin(w(x+y))=cos(wx)sin(wy)+sin(wx)cos(wy)

>By choosing a frequency w = 2πk/n we get period dividing n, so this is a function of x + y (mod n)

>Map to the output logits z with cos(w(x+y))cos(wz)+sin(w(x+y))sin(wz)=cos(w(x+y−z)) - this has the highest logit at z≡x+y(mod n), so softmax gives the right answer.

>To emphasise, this algorithm was purely learned by gradient descent! I did not predict or understand this algorithm in advance and did nothing to encourage the model to learn this way of doing modular addition. I only discovered it by reverse engineering the weights.

This is a very different way to do modular addition, but it makes sense for the network. Sine and cosine are periodic, so if you choose a frequency whose period divides the modulus, you can implement the non-differentiable modular addition function using only differentiable functions.
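To make the quoted steps concrete, here is a minimal sketch of them in NumPy. It is my own illustration, not the trained network's weights: the function name `mod_add_fourier` and the choice of a single frequency with k = 1 are assumptions, and it takes an argmax over the logits instead of a softmax, which picks the same answer.

```python
import numpy as np

def mod_add_fourier(x, y, n=113, k=1):
    """Compute (x + y) mod n with the trig-identity trick from the quote.

    Hypothetical helper for illustration; k should be coprime to n so that
    the largest logit is unique.
    """
    w = 2 * np.pi * k / n  # frequency whose period divides n

    # Step 1: embed each input as a (cos, sin) pair.
    cx, sx = np.cos(w * x), np.sin(w * x)
    cy, sy = np.cos(w * y), np.sin(w * y)

    # Step 2: angle-addition identities give cos(w(x+y)) and sin(w(x+y)).
    cos_sum = cx * cy - sx * sy
    sin_sum = cx * sy + sx * cy

    # Step 3: logit for each candidate z equals cos(w(x + y - z)).
    z = np.arange(n)
    logits = cos_sum * np.cos(w * z) + sin_sum * np.sin(w * z)

    # The logit peaks where z ≡ x + y (mod n).
    return int(np.argmax(logits))

print(mod_add_fourier(100, 50))  # same result as (100 + 50) % 113
```

Everything here is a smooth function of the cos/sin features until the final argmax, which is the point: the wrap-around comes for free from the periodicity.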

Extracting this algorithm is useful for generalization: while the original network only worked for mod 113, with the extracted algorithm we can handle any modulus n by choosing the frequency w = 2πk/n to match. Of course this is a toy example and there are much faster ways to do modular addition, but maybe the same approach could work for more complex problems too.
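As a rough check of that generalization claim, the sketch below builds the whole addition table for a few moduli at once. The name `fourier_mod_add_table` and the moduli tested are my own choices for illustration; it again assumes a single frequency with k coprime to n.

```python
import numpy as np

def fourier_mod_add_table(n, k=1):
    """Return an n x n array whose [x, y] entry is the argmax-logit answer.

    Hypothetical helper; assumes gcd(k, n) == 1 so the maximum is unique.
    """
    w = 2 * np.pi * k / n
    v = np.arange(n)
    c, s = np.cos(w * v), np.sin(w * v)

    # [x, y] entries hold cos(w(x+y)) and sin(w(x+y)) via the angle-addition identities.
    cos_sum = np.outer(c, c) - np.outer(s, s)
    sin_sum = np.outer(c, s) + np.outer(s, c)

    # Broadcast over candidate outputs z: logits[x, y, z] = cos(w(x + y - z)).
    logits = cos_sum[:, :, None] * c + sin_sum[:, :, None] * s
    return logits.argmax(axis=-1)

for n in (5, 12, 113):
    expected = (np.arange(n)[:, None] + np.arange(n)[None, :]) % n
    print(n, np.array_equal(fourier_mod_add_table(n), expected))
```

Only the frequency changes between moduli; nothing has to be retrained.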
