Sad-Comedian-711 t1_jcpvahm wrote

So there is flash attention and then there is block sparse flash attention.

Flash attention by itself only got them to 16k on an A100 for their model; to go further they needed to use windowed attention... You could already have gone to 16k with windowed attention before this paper without much issue.

The special thing about this windowed attention is that it is done in blocks that fit into SRAM. From what I can tell, PyTorch's implementation of Flash Attention doesn't look like it supports block sparse flash attention:

https://github.com/pytorch/pytorch/blob/eb32bb2ca6811ea21002699f4be884d3012dc362/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_1xN.h

While Triton's looks like it does:
https://github.com/openai/triton/blob/c9740f0870f6ae2480acd2a76a5fb4c920bc5ce5/python/triton/ops/flash_attention.py
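For what it's worth, here's a rough sketch of how you'd call the op in that Triton file. The import path, the `attention(q, k, v, sm_scale)` signature, and the fp16 / head_dim=64 / long-power-of-two sequence constraints are assumptions on my part and shift between Triton commits, so treat it as a sketch, not the official API:

```python
# Sketch only: module path, export name, signature, and shape/dtype
# constraints are assumptions and vary across Triton versions.
import torch
from triton.ops.flash_attention import attention  # assumed export of the linked file

batch, heads, seq_len, head_dim = 2, 16, 4096, 64  # head_dim=64 assumed supported
q = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

sm_scale = head_dim ** -0.5         # 1/sqrt(d) softmax scaling
out = attention(q, k, v, sm_scale)  # fused flash-attention kernel
print(out.shape)                    # (batch, heads, seq_len, head_dim)
```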

I think the windowing must be done in blocks that align with the tiles that fit in SRAM, so it kind of has to be part of the Flash Attention implementation itself. You might be able to throw normal Big Bird block sparse attention on top...
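To make that block sparse / windowed idea concrete, here's a plain PyTorch sketch of block-local attention: each query block only attends to the key/value blocks within a fixed window around it (a stripped-down Big Bird-style local pattern, no global or random blocks). The name `block_local_attention` and the block_size/window values are just illustrative; this is a reference for the masking pattern, not an SRAM-aware kernel:

```python
import math
import torch
import torch.nn.functional as F

def block_local_attention(q, k, v, block_size=64, window=1):
    """Each query block attends only to the key/value blocks within
    `window` blocks on either side of it (plus itself).

    Shapes are (batch, heads, seq_len, head_dim), with seq_len divisible
    by block_size. Illustrative reference, not an optimized kernel.
    """
    b, h, s, d = q.shape
    nb = s // block_size
    # Split the sequence into (batch, heads, n_blocks, block_size, head_dim).
    qb = q.view(b, h, nb, block_size, d)
    kb = k.view(b, h, nb, block_size, d)
    vb = v.view(b, h, nb, block_size, d)

    out_blocks = []
    for i in range(nb):
        lo, hi = max(0, i - window), min(nb, i + window + 1)
        # Keys/values from the neighbouring blocks only.
        k_ctx = kb[:, :, lo:hi].reshape(b, h, -1, d)
        v_ctx = vb[:, :, lo:hi].reshape(b, h, -1, d)
        scores = qb[:, :, i] @ k_ctx.transpose(-2, -1) / math.sqrt(d)
        out_blocks.append(F.softmax(scores, dim=-1) @ v_ctx)
    # Each block's score matrix is block_size x (window blocks), never seq_len^2.
    return torch.cat(out_blocks, dim=2)

q = k = v = torch.randn(1, 8, 1024, 64)
out = block_local_attention(q, k, v)  # (1, 8, 1024, 64)
```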

You may also be able to call out to Triton's implementation:
https://github.com/violethaze74/pumpkin-py/blob/d9250933bec045e6add61b3930ff3dbbe08f6501/aten/src/ATen/native/transformers/attention.cpp#L726
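Related (and just a sketch of the stock PyTorch 2.0 dispatch that the linked attention.cpp sits behind, not a Triton call): you can push scaled_dot_product_attention onto the fused flash kernel with the backend flags below. Whether that path ever accepts a block sparse / windowed mask depends on the PyTorch version; as far as I know the flash backend in 2.0 only takes the is_causal flag, no arbitrary masks:

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 16, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Restrict SDPA to the fused flash-attention backend (PyTorch 2.0-era
# context manager; newer releases expose torch.nn.attention.sdpa_kernel).
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # (2, 16, 4096, 64)
```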
