mike94025 t1_jcx5xvg wrote

Data type?

SDPA currently has 3 kernels, dispatched by a kernel picker (see the sketch after this list):

  • sdpa_math
  • sdpa_flash
  • sdpa_mem_eff
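For reference, here is a minimal sketch of hitting the picker through the public API; the shapes, dtype, and device are just illustrative, not requirements:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes: batch=8, heads=16, seq_len=1024, head_dim=64.
q = torch.randn(8, 16, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(8, 16, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(8, 16, 1024, 64, device="cuda", dtype=torch.float16)

# The picker dispatches to sdpa_math, sdpa_flash, or sdpa_mem_eff
# based on dtype, device, shapes, dropout, etc.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
```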

The kernel picker chooses the best kernel your constraints allow:

  • Math is the trusted reference kernel, implementing the attention equation from the paper.
  • Flash only works with FP16 and BF16, and only on SM80 (e.g., A100).
  • The mem_efficient kernel works on older architecture levels and supports FP32, but the upside is limited by the lack of FP32 compute capacity; FP16 or BF16 should help. There are also requirements on alignment, dropout values, etc. to qualify for the high-performance SDPA implementations; in particular, dropout is required to be 0 in PT 2.0. (A sketch of how to surface these constraints follows this list.)
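If you want to check which kernels your inputs qualify for, one approach (a sketch, using PT 2.0's torch.backends.cuda.sdp_kernel context manager) is to disable the fallbacks so the call errors out instead of silently falling back to math:

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(8, 16, 1024, 64, device="cuda", dtype=torch.float16)

# Allow only the flash kernel; math and mem_efficient are disabled.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    # Raises a RuntimeError if the inputs don't meet flash's
    # requirements (FP16/BF16, SM80, dropout_p == 0, alignment, ...).
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
```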

Also, different kernels parallelize across different dimensions, so a batch size of 1 (B=1) will not work with all of those kernels.

In a nutshell, performance comes at the price of generality, and GPUs are finicky about reaching that performance, so inputs must adhere to these constraints, and the right parallelization strategy depends on the combination of dimensions.
