FlexAttention now supports a FlashAttention-4 (FA4) backend on Hopper and Blackwell GPUs, delivering 1.2×–3.2× speedups over the existing Triton implementation on compute-bound workloads. The integration required extending FA4 with score-modification hooks and block-sparse iteration in both the forward and backward passes. PyTorch's Inductor compiler was also updated to auto-generate CuTeDSL code from user-defined score/mask modification functions. On Blackwell GB200, the new backend achieves 1.6–3.2× forward and 1.85–2.3× backward speedups over Triton for standard attention patterns, with similar gains for patterns that only FlexAttention supports, such as ALiBi, document masking, and sliding window. The post details the low-level kernel architecture challenges on Blackwell (warp specialization, TMEM, async pipelines, ping-pong tiling) and explains the Inductor→CuTeDSL lowering pipeline. It also documents current limitations: block size constraints, recompilation on dynamic scalars, and no backward support for captured buffers that require gradients.
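The summary mentions user-defined score/mask modification functions and block-sparse iteration. Below is a minimal pure-Python sketch of both ideas, assuming scalar inputs. It is not PyTorch's or FA4's implementation. The names `make_alibi`, `make_sliding_window`, `build_block_mask`, `dense_attention`, and `block_sparse_attention` are hypothetical and exist only for illustration.

The callbacks mirror FlexAttention's `score_mod(score, b, h, q_idx, kv_idx)` and `mask_mod(b, h, q_idx, kv_idx)` signatures, but they take Python scalars instead of tensors. The block mask sorts each tile into one of three cases:

* "full": no per-element mask check is needed.
* "partial": the mask is evaluated element by element.
* "empty": the tile is skipped entirely.

```python
import math

# Hypothetical, illustrative helpers: scalar versions of FlexAttention-style callbacks.

def make_alibi(num_heads):
    """score_mod adding an ALiBi bias: a per-head linear penalty on query/key distance."""
    def score_mod(score, b, h, q_idx, kv_idx):
        slope = 2.0 ** (-8.0 * (h + 1) / num_heads)  # geometric per-head slopes
        return score + slope * (kv_idx - q_idx)      # <= 0 for causal positions
    return score_mod

def make_sliding_window(window):
    """mask_mod for causal attention limited to the most recent `window` keys."""
    def mask_mod(b, h, q_idx, kv_idx):
        return kv_idx <= q_idx and q_idx - kv_idx < window
    return mask_mod

def _softmax_weighted_sum(scores, kv_ids, v):
    # Stable softmax over surviving scores, then a weighted sum of the matching V rows.
    # Assumes each query keeps at least one key (true for causal masks: the diagonal).
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wi * v[j][d] for wi, j in zip(w, kv_ids)) / total
            for d in range(len(v[0]))]

def dense_attention(q, k, v, score_mod, mask_mod, b=0, h=0):
    """Reference: evaluate mask_mod (and score_mod if kept) for every (q, kv) pair."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qrow in enumerate(q):
        scores, ids = [], []
        for j, krow in enumerate(k):
            if mask_mod(b, h, i, j):
                s = scale * sum(a * c for a, c in zip(qrow, krow))
                scores.append(score_mod(s, b, h, i, j))
                ids.append(j)
        out.append(_softmax_weighted_sum(scores, ids, v))
    return out

def build_block_mask(mask_mod, q_len, kv_len, block, b=0, h=0):
    """Label each (q_block, kv_block) tile 'full' or 'partial'; empty tiles are omitted."""
    tiles = {}
    for qb in range(0, q_len, block):
        for kb in range(0, kv_len, block):
            bits = [mask_mod(b, h, i, j)
                    for i in range(qb, min(qb + block, q_len))
                    for j in range(kb, min(kb + block, kv_len))]
            if all(bits):
                tiles[(qb // block, kb // block)] = "full"
            elif any(bits):
                tiles[(qb // block, kb // block)] = "partial"
    return tiles

def block_sparse_attention(q, k, v, score_mod, mask_mod, block, b=0, h=0):
    """Visit only non-empty tiles, and call mask_mod only inside 'partial' tiles."""
    tiles = build_block_mask(mask_mod, len(q), len(k), block, b, h)
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qrow in enumerate(q):
        scores, ids = [], []
        for kb in range(0, len(k), block):
            kind = tiles.get((i // block, kb // block))
            if kind is None:
                continue  # fully masked tile: no dot products, no mask calls
            for j in range(kb, min(kb + block, len(k))):
                if kind == "partial" and not mask_mod(b, h, i, j):
                    continue
                s = scale * sum(a * c for a, c in zip(qrow, k[j]))
                scores.append(score_mod(s, b, h, i, j))
                ids.append(j)
        out.append(_softmax_weighted_sum(scores, ids, v))
    return out
```

Given the same inputs, both attention functions should return matching outputs. The sparse version never touches the tiles that the sliding window rules out. For an 8-token sequence with tile size 2 and a window of 4, that skips 7 of 16 tiles. The relative savings grow with sequence length.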

18 min read · From pytorch.org
Table of contents
Blackwell: bigger tensor cores, bigger problems
FlashAttention-4 as the foundation
Inductor → CuTeDSL: the glue layer
Flexifying FlashAttention-4
Results
Correctness and benchmark methodology
Future work
Thanks
Further reading / links