Viewing a single comment thread. View all comments

MrAcurite t1_j2h2ei1 wrote

You can teach a neural network to solve, say, mazes in a 10x10 grid, but then you'd need to train it again to solve them in a 20x20 grid, and there would be a size at which the same model would simply cease to work. Whereas Dijkstra's, even if it slows down, would never fail to find the exit if the exit exists.
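To make the contrast concrete, here's a minimal sketch of a classical grid-maze search (plain breadth-first search, since the edges are unweighted — Dijkstra's reduces to BFS in that case). The grid format and function name are just illustrative assumptions:

```python
from collections import deque

def solve_maze(grid, start, goal):
    """Breadth-first search on a grid maze; '#' cells are walls.
    Works for a 10x10 grid or a 10000x10000 one -- it only gets
    slower, it never stops being correct."""
    rows, cols = len(grid), len(grid[0])
    queue = deque([start])
    visited = {start}
    while queue:
        r, c = queue.popleft()
        if (r, c) == goal:
            return True
        for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nr, nc = r + dr, c + dc
            if (0 <= nr < rows and 0 <= nc < cols
                    and grid[nr][nc] != '#' and (nr, nc) not in visited):
                visited.add((nr, nc))
                queue.append((nr, nc))
    return False  # exhausted the reachable cells: no exit
```

The guarantee comes from exhaustively visiting every reachable cell, which is exactly the property a fixed-size learned model can't promise on inputs outside its training distribution.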

You might be able to train a model to find new strategies in a specific case, analyze it, and then code your understanding of it yourself, kinda like using a Monte Carlo approach to find a numerical answer to a problem before trying an analytic one. But you're not going to be able to pull an algorithm out of the parameters directly.

6

currentscurrents OP t1_j2hdsvv wrote

Someone else posted this example, which is kind of what I was interested in. They trained a neural network to do a toy problem, addition mod 113, and then were able to determine the algorithm it used to compute it.

>The algorithm learned to do modular addition can be fully reverse engineered. The algorithm is roughly:

>Map inputs x, y → cos(wx), sin(wx), cos(wy), sin(wy) with a Discrete Fourier Transform, for some frequency w.

>Multiply and rearrange to get cos(w(x+y))=cos(wx)cos(wy)−sin(wx)sin(wy) and sin(w(x+y))=cos(wx)sin(wy)+sin(wx)cos(wy)

>By choosing a frequency w = 2πk/n we get period dividing n, so this is a function of x + y (mod n)

>Map to the output logits z with cos(w(x+y))cos(wz)+sin(w(x+y))sin(wz)=cos(w(x+y−z)) - this has the highest logit at z≡x+y(mod n), so softmax gives the right answer.

>To emphasise, this algorithm was purely learned by gradient descent! I did not predict or understand this algorithm in advance and did nothing to encourage the model to learn this way of doing modular addition. I only discovered it by reverse engineering the weights.

This is a very different way to do modular addition, but it makes sense for the network. Sine and cosine are periodic waves, so if you choose a frequency whose period divides the modulus, you can implement the discrete "wrap-around" of modular addition using only differentiable functions.
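The quoted steps can be sketched directly in code. This is my paraphrase of the reverse-engineered algorithm, not the network's actual weights; the function name and the choice k=1 are assumptions (any k coprime to n gives the same answer):

```python
import math

def fourier_mod_add(x, y, n, k=1):
    """Compute (x + y) mod n via the trig-identity trick."""
    # Step 1: embed each input as a wave at frequency w = 2*pi*k/n
    w = 2 * math.pi * k / n
    cx, sx = math.cos(w * x), math.sin(w * x)
    cy, sy = math.cos(w * y), math.sin(w * y)
    # Step 2: angle-addition identities give cos/sin of w(x + y)
    c_sum = cx * cy - sx * sy  # cos(w(x + y))
    s_sum = cx * sy + sx * cy  # sin(w(x + y))
    # Step 3: the "logit" for each candidate z is cos(w(x + y - z)),
    # which is maximized exactly when z ≡ x + y (mod n)
    logits = [c_sum * math.cos(w * z) + s_sum * math.sin(w * z)
              for z in range(n)]
    return max(range(n), key=lambda z: logits[z])
```

Note that every step before the final argmax is a smooth function of the inputs, which is presumably why gradient descent can find this solution at all.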

Extracting this algorithm is useful for generalization; while the original network only worked mod 113, with the algorithm we can plug in any modulus n just by changing the frequency. Of course this is a toy example and there are much faster ways to do modular addition, but maybe the same reverse-engineering approach could work for more complex problems too.

6