Submitted by SaltyStackSmasher t3_11euzja in MachineLearning

So I was just going through the VAE reparameterization trick and wondered whether it can be extended to beam sampling. Is this possible at all? I think if we can backprop through beam sampling, we can directly optimise for BLEU?

Please correct me if I'm wrong. I'm happy to explore a bit as well; I just don't know where to start.
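For context, the VAE reparameterization trick writes a Gaussian sample as a deterministic function of its parameters plus independent noise, z = mu + sigma * eps with eps ~ N(0, 1), so gradients can pass from z back to mu and sigma. Below is a minimal pure-Python sketch under illustrative assumptions: the function name `reparam_grad_estimate`, the toy objective f(z) = z², and hand-written derivatives in place of an autograd library are choices made for this example, not anything from the thread.

```python
import random


def reparam_grad_estimate(mu, sigma, n_samples=100_000, seed=0):
    """Monte Carlo estimate of the gradient of E[z**2] w.r.t. (mu, sigma), z ~ N(mu, sigma**2).

    The sample is rewritten as z = mu + sigma * eps with eps ~ N(0, 1), so all the
    randomness is in eps and z is a differentiable function of mu and sigma.
    """
    rng = random.Random(seed)
    g_mu = 0.0
    g_sigma = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)
        z = mu + sigma * eps     # reparameterized sample
        df_dz = 2.0 * z          # derivative of the toy objective f(z) = z**2
        g_mu += df_dz            # chain rule with dz/dmu = 1
        g_sigma += df_dz * eps   # chain rule with dz/dsigma = eps
    return g_mu / n_samples, g_sigma / n_samples


print(reparam_grad_estimate(mu=1.5, sigma=0.5))
```

Since E[z²] = mu² + sigma², the estimates should land near (2 * mu, 2 * sigma). Tokens, by contrast, are categorical, so there is no equally smooth map from noise to a discrete token; that gap is what the Gumbel-softmax replies in this thread address.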


Comments


cnapun t1_jage50a wrote

I'm not an expert on this topic, but I've discussed it with coworkers. I do believe you should be able to backprop through sampling, mathematically at least. My suspicion is that you'll run into the same problem as you have with RNNs, where backpropping through many steps leads to high variance in gradients. I'd search for some papers that have explored this; I assume they exist.


SaltyStackSmasher OP t1_jagisv7 wrote

Thanks for the response. My main concern with beam sampling and backprop is that the context for the 2nd token will include the 1st token. I believe in the RNN case this wouldn't necessarily matter, since only the hidden state is propagated forward. In transformers, we have to completely redo the forward pass from the 2nd token onwards, and these subsequent forward passes don't have anything in common, so I'm a bit confused about how exactly the gradients will flow.

Please let me know if I wasn't clear in explaining my problem. Thanks again for your response :)
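One way to see where the gradient goes: in a transformer, the forward pass for token 2 reuses the same weights and takes token 1's embedding as part of its input. If token 1 is represented by a relaxed probability vector (for example a Gumbel-softmax sample) instead of a hard index, that embedding becomes a weighted sum of embedding rows, which is differentiable in the vector. A minimal sketch, assuming a hypothetical `soft_embedding` helper and a made-up 3-token embedding table:

```python
def soft_embedding(weights, embedding_table):
    """Mix embedding rows by a probability vector over the vocabulary.

    With a one-hot `weights` this is an ordinary embedding lookup; with a
    relaxed vector it is a smooth function of the weights, and
    d(output)/d(weights[k]) is simply embedding_table[k].
    """
    dim = len(embedding_table[0])
    out = [0.0] * dim
    for w, row in zip(weights, embedding_table):
        for j in range(dim):
            out[j] += w * row[j]
    return out


# Hypothetical vocabulary of 3 tokens with 2-dimensional embeddings.
E = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
print(soft_embedding([0.0, 1.0, 0.0], E))  # hard token 1: plain lookup of row 1
print(soft_embedding([0.2, 0.5, 0.3], E))  # relaxed token: weighted mix of rows
```

Because the output depends linearly on the weights, the loss on later tokens can send gradient back into whatever produced token 1's distribution, in addition to the weights shared across all the forward passes.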


cnapun t1_jai24sf wrote

What I was trying to say was that this sampling approach (in a transformer) seems like it would have issues similar to an RNN's, in that your computational graph will be repeated N times, where N is the rollout length. This makes me suspect that you'll get a lot of noise in your gradient estimates if N is large (also, IIRC, Gumbel-softmax gradients are biased, which might cause more issues if you chain them).


RaeudigerRaffi t1_jagjc74 wrote

In general, it is not possible to backpropagate directly through any operation that involves sampling from a probability distribution. There are, however, techniques (policy learning, the reparameterization trick) from reinforcement learning that try to circumvent this problem.
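For reference, a common form of the "policy learning" route is the score-function (REINFORCE) estimator: it never differentiates through the sample itself, only through log p(x), so the reward can be any black-box score. A minimal pure-Python sketch for a single categorical choice, assuming made-up logits and a hypothetical per-token reward; the exact gradient is included only to check the estimate:

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def score_function_grad(logits, reward, n_samples=100_000, seed=0):
    """REINFORCE estimate of d/dlogits of E_{x ~ softmax(logits)}[reward(x)]."""
    rng = random.Random(seed)
    probs = softmax(logits)
    k = len(logits)
    grad = [0.0] * k
    for _ in range(n_samples):
        x = rng.choices(range(k), weights=probs)[0]  # discrete sample: no gradient path through x
        r = reward(x)                                # black-box reward (could be BLEU-like)
        for i in range(k):
            # d log softmax(logits)[x] / d logits[i] = onehot(x)[i] - probs[i]
            grad[i] += r * ((1.0 if i == x else 0.0) - probs[i])
    return [g / n_samples for g in grad]


def exact_grad(logits, reward):
    """Closed form for comparison: p_i * (reward(i) - E[reward])."""
    probs = softmax(logits)
    mean_r = sum(p * reward(i) for i, p in enumerate(probs))
    return [p * (reward(i) - mean_r) for i, p in enumerate(probs)]


logits = [0.5, -0.2, 1.0]
rewards = [1.0, 0.0, 0.3]  # hypothetical per-token rewards
est = score_function_grad(logits, lambda x: rewards[x])
print(est)
print(exact_grad(logits, lambda x: rewards[x]))
```

The estimate is unbiased but noisy, and the noise tends to grow when the reward depends on a long sequence of samples, which is the variance concern raised earlier in the thread.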


Kaleidophon t1_jah1xwe wrote

You can backpropagate through samples of a categorical distribution using the Gumbel-softmax, and as far as I remember you can apply a reparameterization trick to all distributions of the exponential family.
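A minimal sketch of drawing a Gumbel-softmax sample in pure Python (the function name, logits, and temperature are illustrative assumptions). The noise is g = -log(-log(u)) with u ~ Uniform(0, 1); the relaxed sample is softmax((logits + g) / tau), and the argmax of logits + g is an exact categorical sample (the Gumbel-max trick), which the frequency check at the end illustrates:

```python
import math
import random


def gumbel_softmax_sample(logits, tau, rng):
    """Relaxed one-hot sample y = softmax((logits + g) / tau), g ~ Gumbel(0, 1)."""
    noisy = []
    for logit in logits:
        u = rng.random()
        while u == 0.0:                   # avoid log(0)
            u = rng.random()
        gumbel = -math.log(-math.log(u))  # Gumbel(0, 1) noise
        noisy.append((logit + gumbel) / tau)
    m = max(noisy)
    exps = [math.exp(v - m) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
logits = [math.log(0.6), math.log(0.3), math.log(0.1)]
n = 50_000
counts = [0, 0, 0]
for _ in range(n):
    y = gumbel_softmax_sample(logits, tau=0.5, rng=rng)
    counts[y.index(max(y))] += 1  # argmax of y equals argmax of logits + noise
print([c / n for c in counts])     # should be close to [0.6, 0.3, 0.1]
```

Lower tau pushes y closer to one-hot; higher tau makes it smoother, with lower-variance gradients but a larger gap from a real discrete token.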


RaeudigerRaffi t1_jah39t7 wrote

You are right, Gumbel-softmax is one way to backprop through the sampling step. But given that he is trying to do beam sampling and backprop through it, at some point you need to take the argmax of your Gumbel-softmax vector in order to actually pick the token (assuming there is no way to work with the vector representations downstream; correct me if I am wrong), and that step is not differentiable.


RaeudigerRaffi t1_jahpbod wrote

To add to this: I thought about it a bit more, and technically in PyTorch this should be possible with some trickery using a custom autograd function. You can sample with Gumbel-softmax and return the argmax. In the custom backward, you just skip the argmax and backprop as if the Gumbel-softmax output had been returned instead of its argmax.
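This idea is usually called the straight-through Gumbel-softmax estimator, and PyTorch already ships it: `torch.nn.functional.gumbel_softmax(logits, tau=..., hard=True)` returns a one-hot sample in the forward pass but is differentiated as if the soft sample had been returned, so a custom `torch.autograd.Function` may not be needed. To show the mechanics without a framework, here is a pure-Python sketch with the forward and backward passes written out by hand; the function names, logits, and upstream gradient are hypothetical:

```python
import math
import random


def softmax(values):
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def sample_gumbel_noise(n, rng):
    """i.i.d. Gumbel(0, 1) noise, one value per vocabulary entry."""
    noise = []
    for _ in range(n):
        u = rng.random()
        while u == 0.0:
            u = rng.random()
        noise.append(-math.log(-math.log(u)))
    return noise


def st_forward(logits, noise, tau):
    """Forward: hard one-hot token (what the decoder consumes) plus the soft sample."""
    y_soft = softmax([(logit + g) / tau for logit, g in zip(logits, noise)])
    idx = y_soft.index(max(y_soft))
    y_hard = [1.0 if i == idx else 0.0 for i in range(len(y_soft))]
    return y_hard, y_soft


def st_backward(grad_out, y_soft, tau):
    """Backward: skip the argmax and differentiate y_soft instead.

    For y = softmax((logits + noise) / tau):
    dL/dlogits_i = y_i * (grad_out_i - sum_j grad_out_j * y_j) / tau
    """
    dot = sum(g * y for g, y in zip(grad_out, y_soft))
    return [y * (g - dot) / tau for g, y in zip(grad_out, y_soft)]


rng = random.Random(0)
logits, tau = [2.0, 0.5, -1.0], 1.0
noise = sample_gumbel_noise(len(logits), rng)
y_hard, y_soft = st_forward(logits, noise, tau)
upstream = [0.0, 1.0, 0.0]  # hypothetical dL/dy coming from later layers
print(y_hard, st_backward(upstream, y_soft, tau))
```

As mentioned earlier in the thread, this gradient is biased: the forward pass uses y_hard while the backward pass differentiates y_soft.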


Kaleidophon t1_jah1ke1 wrote

I think what you are looking for is the Gumbel-softmax trick, which is basically differentiable sampling for categorical distributions. In your case, though, the problem will be that BLEU is not differentiable, and in MT you often find that when you try to directly optimize some translation quality metric, the actual quality as assessed by human judges decreases.
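To make the non-differentiability concrete: BLEU is assembled from clipped n-gram counts over the decoded tokens. A sketch of one ingredient, the modified (clipped) n-gram precision, in pure Python; this is not a full BLEU implementation (no brevity penalty, no geometric mean over n-gram orders, no smoothing), and the helper name is hypothetical:

```python
from collections import Counter


def clipped_ngram_precision(candidate, reference, n=1):
    """Modified n-gram precision: candidate n-gram counts clipped by reference counts."""
    def ngrams(tokens):
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    cand = ngrams(candidate)
    ref = ngrams(reference)
    total = sum(cand.values())
    if total == 0:
        return 0.0
    # Counter returns 0 for n-grams missing from the reference.
    clipped = sum(min(count, ref[gram]) for gram, count in cand.items())
    return clipped / total


cand = "the the the the the the the".split()
ref = "the cat is on the mat".split()
print(clipped_ngram_precision(cand, ref))  # 2/7: "the" is clipped to its 2 reference occurrences
```

Because the score only sees discrete tokens, small changes to the logits leave it unchanged until a decoded token flips, so its gradient with respect to the logits is zero almost everywhere. That is why it is typically plugged in as a black-box reward (e.g. with a score-function estimator) rather than differentiated directly.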


Emergency_Apricot_77 t1_jah9rb7 wrote

Why go with BLEU, though? OP didn't particularly mention optimizing sequence-level metrics. Can't we still use cross-entropy? Something like the following:

Sample the first token, then calculate the cross-entropy against the first gold token.

Sample the second token, then calculate the cross-entropy against the second gold token.

Sample the third token, then calculate the cross-entropy against the third gold token.

... and so on?


This way we still have a differentiable objective, but much better alignment between the training and inference scenarios (as opposed to the current teacher-forcing training and sampling-based inference), which I thought was what OP was going for.
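A minimal sketch of the loop described above, assuming a toy stand-in for a decoder step (`next_logits`, the bigram-style table, and all token ids are hypothetical). Each step's loss is the cross-entropy of the gold token given the sampled prefix, and the model's own sample is fed back as the next input:

```python
import math
import random


def softmax(values):
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def sampled_rollout_loss(next_logits, gold, start_token, rng):
    """Mean per-step cross-entropy against gold while feeding back sampled tokens.

    next_logits(prev_token) -> logits over the vocabulary (stand-in for a decoder step).
    """
    prev = start_token
    total = 0.0
    sampled = []
    for target in gold:
        probs = softmax(next_logits(prev))
        total += -math.log(probs[target])                        # cross-entropy vs. the gold token
        prev = rng.choices(range(len(probs)), weights=probs)[0]  # feed back the model's own sample
        sampled.append(prev)
    return total / len(gold), sampled


# Hypothetical "decoder": next-token logits depend only on the previous token.
TABLE = {0: [0.0, 2.0, 0.0], 1: [0.0, 0.0, 2.0], 2: [2.0, 0.0, 0.0]}
loss, sampled = sampled_rollout_loss(lambda t: TABLE[t], gold=[1, 2, 0],
                                     start_token=0, rng=random.Random(0))
print(loss, sampled)
```

In an autograd framework, gradients would flow through each step's log-probability, but the sampled prefix itself is a constant unless you relax it (for example with a Gumbel-softmax sample). Feeding the model's own tokens back during training is closely related to scheduled sampling, which mixes gold and sampled inputs.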


Kaleidophon t1_jalv0qs wrote

>Why go with BLEU, though? OP didn't particularly mention optimizing sequence-level metrics.

From OP's post above:

>Is this possible at all? I think if we can backprop through beam sampling, we can directly optimise for BLEU?
