
SaltyStackSmasher OP t1_jagisv7 wrote

Thanks for the response. My main concern with beam sampling and backprop is that the context for the 2nd token includes the 1st token. I believe in the RNN case this wouldn't necessarily matter, since only the hidden state is propagated forward. In transformers, we have to completely redo the forward pass from the 2nd token onwards, and these subsequent forward passes don't share anything, so I'm a bit confused about how exactly the gradients will flow.

Please let me know if I wasn't clear in explaining my problem. Thanks again for your response :)
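
(Added illustration, not part of the original comment: one way a gradient path can exist here, assuming the 1st token is fed back into the 2nd forward pass as a Gumbel-softmax relaxation $y_1$ rather than a hard index. With $\ell_t$ the step-$t$ logits, $E$ the embedding matrix, and $e_1 = y_1 E$ the relaxed embedding, the total gradient picks up a term of the form)

$$
\frac{\partial \mathcal{L}}{\partial \ell_2}\,
\frac{\partial \ell_2}{\partial e_1}\,
\frac{\partial e_1}{\partial y_1}\,
\frac{\partial y_1}{\partial \ell_1}\,
\frac{\partial \ell_1}{\partial \theta},
\qquad y_1 = \mathrm{GumbelSoftmax}(\ell_1),
$$

(so even though the 2nd forward pass is a fresh pass over the sequence, it consumes $e_1$, and that is the edge the gradient flows back through.)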

2

cnapun t1_jai24sf wrote

What I was trying to say is that this sampling approach (in a transformer) seems like it would have similar issues to an RNN, in that your computational graph will be repeated N times, where N is the rollout size. This makes me suspect you'll get a lot of noise in your gradient estimates if N is large (also, iirc, Gumbel-softmax gradients are biased, which might cause further issues when chaining them).
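
(Added illustration, not from the original reply: a minimal PyTorch sketch of the chained rollout being described, assuming a toy decoder-only model built from `nn.TransformerEncoder` with a causal mask and straight-through Gumbel-softmax; all names and sizes are made up for the example.)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab, d_model, N = 100, 32, 5          # N = rollout size

embed = nn.Embedding(vocab, d_model)
layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
decoder = nn.TransformerEncoder(layer, num_layers=2)   # used with a causal mask
head = nn.Linear(d_model, vocab)

def causal_mask(T):
    # standard upper-triangular -inf mask so position t only attends to positions <= t
    return torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

# keep tokens as (straight-through) one-hots so the graph stays connected
x = F.one_hot(torch.tensor([[0]]), vocab).float()       # prompt, shape (1, 1, vocab)

step_logits = []
for _ in range(N):
    h = x @ embed.weight                                 # differentiable embedding lookup
    out = decoder(h, mask=causal_mask(h.size(1)))
    logits = head(out[:, -1])                            # next-token logits
    step_logits.append(logits)
    # straight-through Gumbel-softmax: hard one-hot forward, relaxed gradient backward
    y = F.gumbel_softmax(logits, tau=1.0, hard=True)
    # the next forward pass re-runs over every earlier sample, so the graph
    # is rebuilt (and stacked) N times over the rollout
    x = torch.cat([x, y.unsqueeze(1)], dim=1)

# dummy scalar loss just to show backprop through all N chained samples;
# each step's gradient passes through every earlier Gumbel-softmax, which is
# where the noise (and the bias of the relaxed estimator) can compound
loss = torch.stack(step_logits).pow(2).mean()
loss.backward()
```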

1