Intro

In this 10-part series, we're going to implement Flash Attention 2 from scratch on Ampere GPUs. We'll build an initial implementation and optimize it over 16 kernel iterations, all without importing any external libraries. By the final kernel, we'll reach 99.2% of the performance of the official implementation on the A100 and 102.9% on the RTX 3090 (at sequence length 4096).

You can find the code here.

Prerequisites

What you'll need:

  • Solid CUDA experience (memory hierarchy, occupancy, tiling, bank conflicts, etc.)
  • Familiarity with attention mechanisms and Flash Attention basics

New to Flash Attention? Start with the Flash Attention 2 paper or this ELI5 explanation.

Kernel Specification

To keep this series focused and manageable, we'll restrict ourselves to a well-defined slice of features (sketched in code right after the list):

  • forward pass only
  • non-causal attention
  • head dimension = 128
  • no dropout or KV caching
  • equal query/key/value sequence lengths
  • sequence lengths divisible by block sizes (typically 64-128 in our implementation, as defined in the paper)
  • 16-bit (bf16/fp16) input and output tensors, softmax calculation in fp32
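To make the spec concrete, here is a rough sketch of how these constraints might look as compile-time constants and a kernel signature. The names, block sizes, and launch mapping are illustrative assumptions for this post, not the repository's actual code:

```cuda
#include <cuda_bf16.h>

// Hypothetical compile-time configuration reflecting the spec above.
constexpr int D_HEAD = 128;   // fixed head dimension
constexpr int B_R    = 64;    // query rows per block (tuned later in the series)
constexpr int B_C    = 64;    // key/value rows per block

// One CTA handles one (batch, head, query-block) tile of the forward pass.
// Inputs and outputs are 16-bit (bf16 shown; fp16 is analogous), while the
// softmax statistics are kept in fp32 registers.
__global__ void flash_attention_fwd(
    const __nv_bfloat16* __restrict__ Q,   // [batch, heads, seq_len, D_HEAD]
    const __nv_bfloat16* __restrict__ K,   // [batch, heads, seq_len, D_HEAD]
    const __nv_bfloat16* __restrict__ V,   // [batch, heads, seq_len, D_HEAD]
    __nv_bfloat16* __restrict__ O,         // [batch, heads, seq_len, D_HEAD]
    int seq_len,                           // assumed divisible by B_R and B_C
    float softmax_scale)                   // typically 1 / sqrt(D_HEAD)
{
    // Body developed over the rest of the series.
}
```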

Full Overview

Over ten blog posts, we'll implement Flash Attention 2 from scratch. We'll start with a basic kernel that achieves just 49.5% of reference performance on the RTX 3090 and systematically optimize it through 16 iterations.

Here's where it gets interesting: our optimization journey has a twist. Kernels 1-7 will be developed on RTX 3090, reaching 101% of reference performance. But when we run that same kernel on A100, performance drops to ~80%. This hardware-specific behavior will drive our exploration into advanced assembly-level optimizations in kernels 8-16, where we'll close the gap to within 1% of reference performance.

I've currently published three parts of the series, with the rest coming soon. Subscribe to the RSS feed to get notified when they're out.

  1. Intro (this post)
  2. Building Blocks
    • We'll cover the key CUDA operations and primitives we'll use, including tensor core operations (mma) and memory transfers (cp.async & ldmatrix); there's a small preview in the code sketch after this list.
    • These will be the most performant instructions available on the Ampere architecture.
  3. RTX 3090: Kernel 1
    • We'll implement a basic kernel using the instructions we learned about in the previous part and get a working version running on an RTX 3090.
  4. RTX 3090: Bank Conflicts & Swizzling (Kernel 2)
    • We'll profile our kernel and determine that bank conflicts are the largest bottleneck.
    • We'll look at the consequences of bank conflicts and implement swizzling to eliminate them.
  5. RTX 3090: CUTLASS GEMM Optimizations (Kernels 3-5)
    • We'll build on the significantly improved kernel 2 from the previous part by implementing techniques used in CUTLASS, NVIDIA's C++ and Python library for building CUDA kernels with heavy GEMM components.
    • These techniques overlap memory transfers with computation, increasing the number of in-flight instructions and improving our latency hiding.
  6. RTX 3090: FP Instruction Fusion & Auto-Tuning (Kernels 6 & 7)
    • We'll find that we're hitting the tensor throughput rooflines on the RTX 3090, so we'll look elsewhere to improve our performance.
    • The strategies we'll implement are FP instruction fusion and auto-tuning.
    • By the end, we'll surpass reference performance by ~1%. However, when we compare our kernel with the reference on an A100, we'll find significant performance gaps.
  7. A100: Profile Analysis and Block Size Configurations
    • We'll profile kernel 7 on the A100 and find the causes for the significant performance gap between the RTX 3090 and the A100.
    • We'll also look at the impact of different block sizes on the kernel.
  8. A100: Instruction Reduction (Kernels 8 to 11)
    • We'll optimize our kernel for the A100, primarily by reducing instruction count. This will involve a lot of low-level assembly analysis, so you'll want to get comfortable reading SASS code.
    • We'll also start taking a closer look at the Ampere microarchitecture.
  9. A100: Final Optimizations (Kernels 12-16)
    • We'll make final optimizations to reach within 1% of reference performance.
  10. A100: Kernel Analysis
    • We'll do a deep dive analysis of kernels 10, 13, and 16, guided by previous research on Ampere microarchitecture.
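As a preview of Part 2, here is roughly what those three primitives (cp.async, ldmatrix, mma) look like when wrapped in inline PTX. The helper names and wrappers are my own sketch, assuming sm_80-class hardware, 16-byte copies, and shared-memory addresses obtained via __cvta_generic_to_shared; Part 2 walks through the real details:

```cuda
#include <cstdint>

// Async copy of 16 bytes from global to shared memory (Ampere cp.async).
// smem_addr is a 32-bit shared-state-space address, e.g. from
// static_cast<uint32_t>(__cvta_generic_to_shared(ptr)).
__device__ inline void cp_async_16B(uint32_t smem_addr, const void* gmem_ptr) {
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n"
                 :: "r"(smem_addr), "l"(gmem_ptr));
}

// Commit the outstanding cp.async operations and wait for all of them to land.
__device__ inline void cp_async_wait_all() {
    asm volatile("cp.async.commit_group;\n"
                 "cp.async.wait_group 0;\n" ::: "memory");
}

// Load four 8x8 tiles of 16-bit elements from shared memory into registers,
// in the layout the tensor core mma expects (ldmatrix.x4).
__device__ inline void ldmatrix_x4(uint32_t (&r)[4], uint32_t smem_addr) {
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
                 "{%0, %1, %2, %3}, [%4];\n"
                 : "=r"(r[0]), "=r"(r[1]), "=r"(r[2]), "=r"(r[3])
                 : "r"(smem_addr));
}

// One m16n8k16 tensor core MMA: D = A * B + D, bf16 inputs, fp32 accumulators.
__device__ inline void mma_m16n8k16_bf16(float (&d)[4], const uint32_t (&a)[4],
                                         const uint32_t (&b)[2]) {
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, "
                 "{%0, %1, %2, %3};\n"
                 : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
                 : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
                   "r"(b[0]), "r"(b[1]));
}
```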

Kernels

You can find the code for all kernels on GitHub.

Foundation, CUTLASS Optimizations, and FP Fusion (Kernels 1-7)

  1. Base Implementation
  2. Swizzling
  3. Eagerly Loading K & V Blocks
  4. Interleaving On-Chip LD/ST with Computation
  5. Double Buffering Shared Memory to Register File Loads
  6. Improving FP32 Throughput *
  7. Auto-Tuning

A100-Focused Instruction-Level Optimizations (Kernels 8-11)

  1. Reducing Logic and Bit-Shift Instructions
  2. Reducing IMAD.MOV.U32 and MOV instructions
  3. Removing CSRZ Instructions + Optimizing Initial Softmax Iteration
  4. Strided Swizzling from the RF (register file) to SMEM

A100 Final Tuning (Kernels 12-16)

  1. Miscellaneous Code Changes
  2. Iterating Backwards *
  3. Cache Configuration
  4. Tiling along d_head *
  5. Static GMEM Stride

*Optimizations inspired by the official implementation

Performance

For GEMM kernels, the standard reference for comparison is cuBLAS. For Flash Attention, we'll benchmark against the official implementation.

We'll focus on sequence length 4096 for most of our benchmarking. While this is longer than typical for bidirectional transformers, it strikes a good balance: shorter sequences don't fully utilize the GPU and achieve lower TFLOPs, while longer sequences take too long to benchmark iteratively but achieve higher TFLOPs.

You can find more details on benchmarking setup and methodology in the appendix here.
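For context on the TFLOPs numbers in the table below: the convention typically used for attention forward FLOPs (and, I assume, the one behind the reference numbers, though the appendix is authoritative) counts only the two GEMMs, Q·Kᵀ and P·V, and ignores the softmax, giving 4·batch·heads·seq_len²·d_head. A minimal host-side sketch with made-up batch/head values:

```cuda
#include <cstdio>

// Attention forward FLOPs: 2*N^2*d for Q @ K^T plus 2*N^2*d for P @ V,
// per head and per batch element. Softmax FLOPs are ignored by convention.
double attention_fwd_tflops(int batch, int heads, long long seq_len,
                            int d_head, double runtime_ms) {
    double flops = 4.0 * batch * heads * (double)seq_len * seq_len * d_head;
    return flops / (runtime_ms * 1e-3) / 1e12;
}

int main() {
    // Hypothetical example: a 3.0 ms forward pass at seq_len = 4096 with
    // batch = 4, heads = 16, d_head = 128 works out to roughly 183 TFLOPs.
    std::printf("%.1f TFLOPs\n", attention_fwd_tflops(4, 16, 4096, 128, 3.0));
}
```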

Performance Table

Each percentage represents performance relative to the official Flash Attention implementation. Notice how the RTX 3090 reaches competitive performance early (kernel 2), while the A100 requires much deeper optimization.

| Kernel Iteration | A100 (seq_len = 4096) | A100 (harm. mean) | RTX 3090 (seq_len = 4096) | RTX 3090 (harm. mean) |
|---|---|---|---|---|
| 1. Base Implementation | 15.8% | 16.6% | 49.5% | 49.8% |
| 2. Swizzling | 72.6% | 72.4% | 98.3% | 98.6% |
| 3. Eagerly Loading K & V Blocks | 77.6% | 79.9% | 99.4% | 100.0% |
| 4. Interleaving On-Chip LD/ST with Computation | 77.6% | 80.0% | 100.0% | 100.4% |
| 5. Double Buffering Shared Memory to Register File Loads | 76.8% | 79.1% | 99.7% | 100.3% |
| 6. Improving FP32 Throughput | 78.1% | 80.4% | 99.9% | 100.4% |
| 7. Auto-Tuning | 80.3% | 82.3% | 101.5% | 101.8% |
| 8. Reducing Logic and Bit-Shift Instructions | 87.8% | 88.9% | 101.7% | 101.2% |
| 9. Reducing Copy Instructions | 95.3% | 96.3% | 97.5% | 97.4% |
| 10. Removing CSRZ Instructions + Optimizing Initial Softmax Iteration | 93.9% | 95.0% | 102.9% | 102.3% |
| 11. Strided Swizzling from the RF → SMEM | 95.2% | 96.7% | 102.8% | 102.3% |
| 12. Miscellaneous Code Changes | 95.3% | 97.0% | 102.8% | 102.3% |
| 13. Iterating Backwards | 97.6% | 98.8% | 101.5% | 101.2% |
| 14. Cache Configuration | 97.7% | 99.1% | 101.5% | 101.2% |
| 15. Tiling along d_head dimension | 97.9% | 99.5% | 101.5% | 101.3% |
| 16. Static GMEM Stride | 99.2% | 100.4% | 100.9% | 100.7% |
| 0. Reference (TFLOPs) | 186.4 | 174.0 | 67.29 | 66.2 |

The harmonic mean is taken over sequence lengths 512, 1024, 2048, 4096, 8192, 16384.
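A quick sketch of that harmonic mean, applied to the per-sequence-length TFLOPs (this is my reading of the table layout; the appendix has the exact methodology):

```cuda
#include <vector>

// Harmonic mean of per-sequence-length throughput numbers (e.g. TFLOPs at
// 512, 1024, 2048, 4096, 8192, 16384). Low outliers drag it down more than an
// arithmetic mean would, so slow configurations are not averaged away.
double harmonic_mean(const std::vector<double>& tflops) {
    double inv_sum = 0.0;
    for (double x : tflops) inv_sum += 1.0 / x;
    return static_cast<double>(tflops.size()) / inv_sum;
}
```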

Why Flash Attention?

Flash Attention is worth implementing from scratch for two compelling reasons:

  1. It's one of the most impactful innovations in ML engineering. Attention scales quadratically in compute and memory with sequence length, and this bottleneck becomes increasingly critical as demand for longer sequences grows. Flash Attention made huge leaps towards removing this bottleneck (see the math sketch after this list) by

    • Virtually solving the memory scaling issue by reducing memory complexity from O(N²) to O(N)
    • Mitigating the compute issue by significantly increasing data reuse in fast, on-chip memory
  2. Flash Attention is also a complex algorithm with unique optimization challenges beyond standard GEMM kernels. It combines two back-to-back GEMMs with additional state management and a non-trivial FP workload (softmax). This algorithmic complexity makes it an excellent advanced GPU programming exercise.
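To make the memory claim concrete, here is the rough accounting (my summary of the papers' argument, with N the sequence length and d the head dimension):

```latex
\[
\text{standard attention: } S = QK^{\top} \in \mathbb{R}^{N \times N}
  \;\Rightarrow\; O(N^2) \text{ extra memory,}
\]
\[
\text{Flash Attention: } m,\ \ell \in \mathbb{R}^{N},\quad O \in \mathbb{R}^{N \times d}
  \;\Rightarrow\; O(Nd) = O(N) \text{ extra memory for fixed } d.
\]
```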

Our Implementation Target: Ampere GPUs

We're targeting Ampere because it's the sweet spot for this exercise: the last generation where consumer and HPC cards share identical CUDA APIs. Newer generations have diverged significantly (Hopper on the data center side vs. Ada on the consumer side, and now Blackwell): HPC accelerators get exclusive tensor pipeline features that aren't available on consumer cards.

This makes Ampere an interesting target for GPU optimization because we can write identical code and compare performance across very different hardware configurations on level ground. Ampere's HPC and consumer lineups have dramatically different performance characteristics:

  • A100 (HPC): 312 tensor TFLOPs / 19.5 FP32 TFLOPs peak throughput
  • RTX 3090 (Consumer): 71 tensor TFLOPs / 35.6 FP32 TFLOPs peak throughput

Note that the ratio of tensor to FP32 throughput is roughly 16:1 on the A100 but only about 2:1 on the RTX 3090. We'll see how this performance disparity plays out as we make progress through the different kernels.

Since we're targeting Ampere, we'll exclude newer Hopper/Blackwell features like:

  • TMA related instructions (cp.async.bulk)
  • async warp group mmas (wgmma) (Hopper)
  • 5th gen tensor operations (Blackwell)

Notation & Terminology

I'll mostly stick to standard CUDA terminology and the notation from the Flash Attention papers. I'll include some of the terms and acronyms I use most frequently here, but you can find a full reference in the Glossary.

  • GMEM: global memory
  • SMEM: shared memory
  • LMEM: local memory
  • RF: register file / register memory
  • Q_i, O_i: query and output tensors handled by the current CTA
  • K_j, V_j: the j-th key/value tiles
  • B_r: the number of query rows in the Q_i block
  • B_c: the number of key/value rows in the K_j and V_j blocks

Up Next

In Part 2, we'll dive into the CUDA building blocks that enable high-performance Flash Attention on Ampere architecture: tensor core operations, asynchronous memory transfers, shared memory banking strategies, and warp-level programming patterns. These primitives will form the foundation for all our subsequent kernel optimizations.