FlexAttention now supports a FlashAttention-4 (FA4) backend on Hopper and Blackwell GPUs, delivering 1.2×–3.2× speedups over the existing Triton implementation on compute-bound workloads. The integration required extending FA4 with score-modification hooks and block-sparse iteration in both the forward and backward passes, and updating PyTorch's Inductor compiler to auto-generate CuTeDSL code from user-defined score and mask modification functions. On Blackwell GB200, the new backend achieves 1.6×–3.2× forward and 1.85×–2.3× backward speedups over Triton for standard patterns, with similar gains for FlexAttention-only patterns such as ALiBi, document masking, and sliding window. The post details the low-level kernel architecture challenges on Blackwell (warp specialization, TMEM, async pipelines, ping-pong tiling), explains the Inductor→CuTeDSL lowering pipeline, and documents current limitations, including block size constraints, dynamic scalar recompilation, and no backward support for captured buffers that require gradients.
Table of contents
Blackwell: bigger tensor cores, bigger problems
FlashAttention-4 as the foundation
Inductor → CuTeDSL: the glue layer
Flexifying FlashAttention-4
Results
Correctness and benchmark methodology
Future work
Thanks
Further reading / links