Dependent_Ad5120 t1_jdec7kx wrote
Reply to comment by oathbreakerkeeper in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
I don't know. I was using pure fp16 with no autocast, and it works.
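For concreteness, here is a minimal sketch of the "pure fp16, no autocast" distinction; the layer and shapes are purely illustrative, not from the original post:

```
import torch
from torch import nn

layer = nn.TransformerEncoderLayer(d_model=512, nhead=8).cuda()
x = torch.rand(10, 32, 512, device="cuda")

# pure fp16: cast both the weights and the inputs, no autocast context
y_fp16 = layer.half()(x.half())

# mixed precision via autocast (the thing *not* being used above)
layer.float()
with torch.autocast(device_type="cuda", dtype=torch.float16):
    y_amp = layer(x)
```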
Dependent_Ad5120 t1_jd3m0ce wrote
Reply to comment by oathbreakerkeeper in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
Try fp16; apparently that doesn't require training=False.
Dependent_Ad5120 t1_jd3knio wrote
Reply to comment by Dependent_Ad5120 in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
OK, I found out why. To use flash attention, I had to use fp16. It is a bit faster than using memory_efficient attention in my test.
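A rough sketch of how such a comparison could be set up with the PyTorch 2.0 sdp_kernel context manager and F.scaled_dot_product_attention; the shapes and iteration count here are made up for illustration, so absolute numbers will differ:

```
import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

# (batch, heads, seq_len, head_dim) in fp16 on GPU
q = k = v = torch.rand(32, 16, 1024, 64, device="cuda", dtype=torch.float16)

def time_kernel(**flags):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    with sdp_kernel(**flags):
        F.scaled_dot_product_attention(q, k, v)  # warm-up
        torch.cuda.synchronize()
        start.record()
        for _ in range(100):
            F.scaled_dot_product_attention(q, k, v)
        end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end)  # milliseconds

print("flash:        ", time_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False))
print("mem_efficient:", time_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True))
```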
Dependent_Ad5120 t1_jd1d00j wrote
Reply to comment by mike94025 in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
It seems to me that I have to call model.eval() to use the memory_efficient attention. Otherwise, it throws a "no available kernel" error.
I tried on both an RTX 3090 and an A100; in both cases, enabling only enable_flash=True resulted in the same "no available kernel" error, even with model.eval().
So my questions are:
- With model.eval(), does it mean dropout is not enabled during training? (See the quick check sketched below.)
- Am I doing something wrong with flash attention? How do I actually enable it?
Thanks a lot!
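On the dropout part of the question: model.eval() flips the training flag on every submodule, which turns nn.Dropout layers into no-ops. A quick check, with the model name just illustrative:

```
import torch
from torch import nn

model = nn.Transformer()

model.eval()                  # sets .training = False on every submodule
assert model.training is False

# in eval mode, dropout layers pass inputs through unchanged
drop = next(m for m in model.modules() if isinstance(m, nn.Dropout))
x = torch.rand(4, 4)
assert torch.equal(drop(x), x)
```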
Dependent_Ad5120 t1_je5qfmp wrote
Reply to comment by oathbreakerkeeper in [D] PyTorch 2.0 Native Flash Attention 32k Context Window by super_deap
I don't have a github repo for this, but it is pretty simple:
```
import torch
from torch import nn
from torch.backends.cuda import sdp_kernel

model = nn.Transformer().cuda().half()             # pure fp16, no autocast
src = tgt = torch.rand(10, 32, 512).cuda().half()  # shape is just illustrative
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    output = model(src, tgt)                       # nn.Transformer takes (src, tgt)
```
These few lines should be enough.