
tripple13 t1_jck9593 wrote

Does anyone know why they didn't add FlashAttention directly into the MultiheadAttention modules? Seems to be integrated, awesome!


programmerChilli t1_jcnydmw wrote

I think it is used in PyTorch's nn.TransformerEncoder, but a lot of people like implementing their own.


mike94025 t1_jcv94un wrote

SDPA (F.scaled_dot_product_attention) is used by F.multi_head_attention_forward (if need_weights=False), which in turn is used by nn.MultiheadAttention and nn.Transformer*, as well as by other libraries. (source)
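For anyone unsure what SDPA actually computes: a minimal sketch (assuming PyTorch 2.x, CPU, float32) comparing the fused call against the plain softmax(QK^T / sqrt(d)) V formula. The shapes and tolerance are arbitrary choices for illustration. The reference path builds the full attention-weight matrix, which is the thing the fused kernels try to avoid.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# SDPA takes tensors shaped (batch, heads, seq_len, head_dim).
batch, heads, seq_len, head_dim = 2, 4, 16, 32
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Fused call: PyTorch chooses a backend (flash, memory-efficient, or math).
fused = F.scaled_dot_product_attention(q, k, v)

# Reference "math" path: explicitly materializes the
# (seq_len x seq_len) attention weights per head.
weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(head_dim), dim=-1)
reference = weights @ v

print(torch.allclose(fused, reference, atol=1e-5))
```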

Public service announcement: need_weights defaults to True, and that guts performance (allocating and writing the attention-weight tensor defeats the memory-bandwidth advantage of flash attention).
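To make the PSA concrete, here is a rough sketch (assuming PyTorch 2.x; the sizes are arbitrary) of calling nn.MultiheadAttention with and without need_weights. The outputs match, but only the default call returns a (batch, seq_len, seq_len) weight tensor. With need_weights=False the second return value is None.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_dim, num_heads = 64, 4
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

x = torch.randn(2, 10, embed_dim)  # (batch, seq_len, embed_dim)

# Default need_weights=True: attention weights (averaged over heads
# by default) are computed and returned.
out_slow, weights = mha(x, x, x)

# need_weights=False: no weight tensor is produced, which lets the
# functional path hand the work to scaled_dot_product_attention.
out_fast, no_weights = mha(x, x, x, need_weights=False)

print(weights.shape)
print(no_weights)
print(torch.allclose(out_slow, out_fast, atol=1e-5))
```

Dropout defaults to 0.0, so both calls compute the same thing and only differ in whether the weight tensor gets built.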

Also, if `key_padding_mask is not None`, performance will suffer (because it is converted into an attention mask, and Flash Attention only supports the causal attention mask). Use Nested Tensors for variable-sequence-length batches.
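A sketch of what "converted into an attention mask" means, written directly against F.scaled_dot_product_attention. The mask construction here is my own illustration, not PyTorch's internal code, and the lengths are made up. Note the conventions flip: in key_padding_mask True marks a padded key, while in SDPA's boolean attn_mask True means "may attend". The last part shows the causal case expressed as is_causal=True instead of a mask tensor.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch, heads, seq_len, head_dim = 2, 4, 6, 8
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# nn.MultiheadAttention convention: True in key_padding_mask = padded key.
lengths = torch.tensor([6, 4])  # hypothetical: second sequence has 2 pads
key_padding_mask = torch.arange(seq_len)[None, :] >= lengths[:, None]

# SDPA convention is inverted (True = may attend). Shape
# (batch, 1, 1, seq_len) broadcasts over heads and query positions.
attn_mask = ~key_padding_mask[:, None, None, :]
padded_out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# Reference: explicit masked softmax over the keys.
scores = q @ k.transpose(-2, -1) / head_dim**0.5
scores = scores.masked_fill(~attn_mask, float("-inf"))
padded_ref = torch.softmax(scores, dim=-1) @ v

# Causal attention as a flag vs. an explicit lower-triangular mask.
causal_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
tril = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
causal_ref = F.scaled_dot_product_attention(q, k, v, attn_mask=tril)

print(torch.allclose(padded_out, padded_ref, atol=1e-5))
print(torch.allclose(causal_out, causal_ref, atol=1e-5))
```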


mike94025 t1_je5ojaw wrote

It is. Follow the call tree into F.multi_head_attention_forward.
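One way to follow that call tree without stepping through a debugger is to search the installed source with inspect. This is a sketch that assumes PyTorch 2.x. The exact source text can change between releases, so a substring check like this is only a quick sanity check, not a guarantee about which kernel runs.

```python
import inspect

import torch.nn as nn
import torch.nn.functional as F

# Step 1: nn.MultiheadAttention.forward delegates to the functional version.
mha_src = inspect.getsource(nn.MultiheadAttention.forward)
print("F.multi_head_attention_forward" in mha_src)

# Step 2: the functional version calls scaled_dot_product_attention.
func_src = inspect.getsource(F.multi_head_attention_forward)
print("scaled_dot_product_attention" in func_src)
```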


tripple13 t1_je5seed wrote

Is that right? I somehow end up here when trying to work out what the F.multi_head_attention_forward call does in the class definition.

But I trust you're right; it would only make sense. I just couldn't identify the calls myself.