Submitted by super_deap t3_11tmpc5 in MachineLearning
mike94025 t1_jcv94un wrote
Reply to comment by programmerChilli in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
SDPA is used by F.multi_head_attention_forward (if need_weights=False), which is used by nn.MHA and nn.Transformer* as well as other libraries. (source)
Public service announcement: need_weights defaults to True, and it guts performance, because allocating and writing the attention weight tensor defeats the memory-bandwidth advantage of flash attention.
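Rough sketch of what that looks like in practice (sizes are made up, and flash only kicks in when the usual eligibility conditions are met, e.g. fp16/bf16 on a supported GPU):

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Hypothetical sizes, just for illustration.
embed_dim, num_heads, seq_len, batch = 256, 8, 1024, 4

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True,
                            device=device, dtype=dtype).eval()
x = torch.randn(batch, seq_len, embed_dim, device=device, dtype=dtype)

with torch.no_grad():
    # need_weights=False lets F.multi_head_attention_forward dispatch to
    # scaled_dot_product_attention (and the flash kernel when eligible)
    # instead of materializing the full (batch * heads, seq, seq) weights.
    out, attn_weights = mha(x, x, x, need_weights=False)

print(out.shape)     # torch.Size([4, 1024, 256])
print(attn_weights)  # None on the fast path
```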
Also, if `key_padding_mask is not None` performance will suffer, because it gets converted into an attention mask, and only the causal attention mask is supported by Flash Attention. Use Nested Tensors for variable sequence length batches.
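A rough sketch of the nested-tensor route for ragged batches (nested tensors were still a prototype API, so exact support depends on your PyTorch version; sequence lengths here are made up):

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
embed_dim, num_heads = 256, 8

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True,
                            device=device, dtype=dtype).eval()

# A batch of sequences with different lengths, packed as a nested tensor
# instead of padding everything out and passing key_padding_mask.
seqs = [torch.randn(n, embed_dim, device=device, dtype=dtype)
        for n in (100, 57, 180)]  # hypothetical lengths
nt = torch.nested.nested_tensor(seqs)

with torch.no_grad():
    # Nested tensor input requires the fast path: eval mode, no grad,
    # batch_first=True, need_weights=False, no explicit masks.
    out, _ = mha(nt, nt, nt, need_weights=False)

# Output is also nested; each entry keeps its original sequence length.
print([t.shape for t in out.unbind()])
```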