Submitted by wangyi_fudan t3_y2w87i in MachineLearning

The proof is simple:

attention=softmax(QKt)V

=softmax(XWq (XWk)t)XWv

=softmax(XWqWktXt)XWv

let Wk'=WkWqt, so that WqWkt=(WkWqt)t=Wk't

attention=softmax(XWk'tXt)XWv

=softmax(X(XWk')t)XWv

=softmax(XK't)V, where K'=XWk'

Now we see that Q=XWq is replaced by X itself, so the attention module needs 1/4 fewer parameters (Wq is folded into Wk).
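
Here is a minimal NumPy sketch (not from the post) that checks this algebra numerically; the shapes, random matrices, and the omitted 1/sqrt(d) scaling are illustrative assumptions:

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n, d_model = 8, 16                            # sequence length, model width
X  = rng.normal(size=(n, d_model))
Wq = rng.normal(size=(d_model, d_model))
Wk = rng.normal(size=(d_model, d_model))
Wv = rng.normal(size=(d_model, d_model))

# original single-head attention: softmax((XWq)(XWk)t) XWv
original = softmax((X @ Wq) @ (X @ Wk).T) @ (X @ Wv)

# folded version: Wk' = Wk Wq^t, then softmax(X (XWk')t) XWv
Wk_folded = Wk @ Wq.T
folded = softmax(X @ (X @ Wk_folded).T) @ (X @ Wv)

print(np.allclose(original, folded))          # True, up to floating-point error
```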

I ran a real experiment and found that, with 3/4 of the original attention parameters, the difference in loss was about 0.01 throughout training and did not grow. So although Wq is not strictly necessary, the extra 1/4 of parameters does seem to make things slightly better.

But in multi-head attention, Wq is necessary. However, research has shown that stacking many small single-head attention modules into a very deep model works better than wider multi-head attention (a single head is enough).

Comments


StellaAthena t1_is7iss2 wrote

The proof is even simpler: (xW_q)(xW_k)^T = x(W_qW_k^T)x^T = xWx^T

The problem is that W_q and W_k are not square matrices. They are d_model by d_head, and so their product is d_model x d_model. In practice d_model >> d_head (e.g., they’re 4096 and 256 respectively in GPT-J). Doing it your way uses a lot more memory and compute
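
A quick back-of-the-envelope sketch of this point, using the GPT-J-like per-head sizes mentioned above (d_model = 4096, d_head = 256); the exact numbers are only illustrative:

```python
d_model, d_head = 4096, 256

factored = 2 * d_model * d_head   # separate W_q and W_k, each d_model x d_head
fused    = d_model * d_model      # W = W_q W_k^T, a full d_model x d_model matrix

print(f"factored (W_q, W_k): {factored:,}")   # 2,097,152
print(f"fused (W_q W_k^T):   {fused:,}")      # 16,777,216
print(f"ratio: {fused / factored:.1f}x")      # 8.0x
```

So per head, fusing the two projections multiplies the score-projection parameters (and the matching matmul cost) by d_model / (2 * d_head).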

maizeq t1_is57j81 wrote

Transformers aren’t my field of expertise so I don’t know if this has been done before but hah, neat derivation!

Though I would expect there to be no difference in loss in that case. Was the difference positive or negative? And do you think the difference can be chalked up to numerical precision errors that accumulate with the double vs. single matrix multiplication? An easy test would be to compare X(XWk')t and XWq(XWk)t and see how close they are throughout training for a particular sample.

mrfox321 t1_is7pudf wrote

Sure, but using W_q allows for low-rank representations of

W := W_k @ W_q^T
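
A small NumPy sketch of the low-rank point (sizes are illustrative assumptions): the fused matrix is d_model x d_model, but its rank is capped at d_head.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head = 64, 8
W_q = rng.normal(size=(d_model, d_head))
W_k = rng.normal(size=(d_model, d_head))

W = W_k @ W_q.T                        # d_model x d_model, but rank <= d_head
print(W.shape)                         # (64, 64)
print(np.linalg.matrix_rank(W))        # 8, i.e. d_head
```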

UltimateGPower t1_is6adxb wrote

Why is it necessary for multiple heads? What this proof shows is that it is enough to transform either the keys or queries.

Co0k1eGal3xy t1_is6xpmg wrote

How does the loss of converged models compare?

Removing parameters is similar to decreasing the learning rate as far as I remember, so you can't compare them during early training stages.

Reasonable_Boss2750 t1_is97cn3 wrote

A possible reason why the author uses attention with both Wq and Wk is to fuse information between the encoder and the decoder (cross-attention). In that case the formula is (XenWq)(XdeWk)t.
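
For reference, a minimal sketch of the cross-attention scores this comment describes; the names X_en and X_de and all shapes are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 16
X_en = rng.normal(size=(10, d_model))   # encoder states, length 10
X_de = rng.normal(size=(7, d_model))    # decoder states, length 7
W_q  = rng.normal(size=(d_model, d_model))
W_k  = rng.normal(size=(d_model, d_model))

scores = (X_en @ W_q) @ (X_de @ W_k).T  # (XenWq)(XdeWk)t -> 10 x 7 score matrix
print(scores.shape)                     # (10, 7)
```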
