Submitted by SaltyStackSmasher t3_11euzja in MachineLearning
Kaleidophon t1_jah1xwe wrote
Reply to comment by RaeudigerRaffi in [D] backprop through beam sampling ? by SaltyStackSmasher
You can backpropagate through samples of a categorical distribution using Gumbel softmax, and as far as I remember, you can apply a reparameterization trick to all distributions of the exponential family.
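For anyone unfamiliar with the trick, here is a minimal standard-library sketch of drawing one Gumbel softmax sample. The function name `gumbel_softmax_sample` and the `tau` and `rng` parameters are illustrative choices, not taken from any particular library. The idea is that the randomness lives in the added Gumbel noise, so the output is a smooth function of the logits.

```python
import math
import random


def gumbel_softmax_sample(logits, tau=1.0, rng=random):
    """Return one relaxed (soft) categorical sample as a list of probabilities.

    Hypothetical helper for illustration; tau is the softmax temperature.
    """
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise is -log(-log(U)) with U ~ Uniform(0, 1).
        u = max(rng.random(), 1e-12)  # guard against log(0)
        gumbel = -math.log(-math.log(u))
        noisy.append((logit + gumbel) / tau)

    # Numerically stable softmax over the perturbed logits.
    peak = max(noisy)
    exps = [math.exp(v - peak) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]


sample = gumbel_softmax_sample([1.0, 2.0, 0.5], tau=0.5, rng=random.Random(0))
```

Lower values of `tau` push the sample closer to a one-hot vector, at the cost of higher-variance gradients.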
RaeudigerRaffi t1_jah39t7 wrote
You are right, Gumbel softmax is one way to backprop. But given that he is trying to do beam sampling and backprop through it, at some point you need to take the argmax of your Gumbel softmax vector to actually pick the token (assuming there is no way to work with the vector representations down the line; correct me if I am wrong), and that step is not differentiable.
RaeudigerRaffi t1_jahpbod wrote
To add to this: I thought about it a bit, and technically this should be possible in PyTorch with some trickery using custom autograd functions. You can probably sample with Gumbel softmax and return the argmax. In the custom backward, you can just skip the argmax part and backprop as if the Gumbel softmax output had been returned rather than its argmax.
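A rough sketch of what that could look like, assuming PyTorch is installed. The class name `GumbelArgmaxST` is made up for illustration, and the forward returns the argmax as a one-hot vector (rather than an integer index) so that the gradient shapes line up. The backward applies the softmax vector-Jacobian product, i.e. it treats the soft sample as if it had been the output.

```python
import torch


class GumbelArgmaxST(torch.autograd.Function):
    """Hypothetical straight-through op: hard argmax forward, soft backward."""

    @staticmethod
    def forward(ctx, logits, tau):
        # Gumbel(0, 1) noise, drawn as -log(E) with E ~ Exponential(1).
        gumbel = -torch.empty_like(logits).exponential_().log()
        soft = torch.softmax((logits + gumbel) / tau, dim=-1)
        ctx.save_for_backward(soft)
        ctx.tau = tau
        # Pick the token: one-hot encoding of the argmax.
        index = soft.argmax(dim=-1, keepdim=True)
        return torch.zeros_like(soft).scatter_(-1, index, 1.0)

    @staticmethod
    def backward(ctx, grad_output):
        (soft,) = ctx.saved_tensors
        # Skip the argmax: backprop as if `soft` had been returned.
        inner = (grad_output * soft).sum(dim=-1, keepdim=True)
        grad_logits = soft * (grad_output - inner) / ctx.tau
        return grad_logits, None  # no gradient for tau


torch.manual_seed(0)
logits = torch.randn(2, 5, requires_grad=True)
embedding = torch.randn(5, 3)  # stand-in for a token embedding table

one_hot = GumbelArgmaxST.apply(logits, 1.0)
token_vectors = one_hot @ embedding  # selects one embedding row per sample
loss = token_vectors.pow(2).sum()
loss.backward()  # logits.grad is populated despite the argmax
```

PyTorch's built-in `torch.nn.functional.gumbel_softmax(logits, tau=1.0, hard=True)` implements the same straight-through idea, so a custom function is mainly useful if you want to change what the backward pass does. Note that the resulting gradient is a biased estimate.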
jamesvoltage t1_jajjsh3 wrote
Here is the nanoChatGPT repository, extended with Gumbel softmax: https://github.com/sanjeevanahilan/nanoChatGPT