This appendix dives into how block size configurations affect instruction patterns and performance in Flash Attention. If you've been wondering why larger blocks tend to be more efficient, this is for you.
We'll examine how different block configurations affect instruction patterns by analyzing kernel 16 (our final optimized kernel). Since kernel 16 is heavily optimized, it has less noise from unoptimized code generation, giving us cleaner insights into the fundamental tradeoffs.
We'll discuss how the following design choices impact instruction mix and resulting kernel metrics:
- Reloading vs. persisting $Q$ in the RF (register file)
- The size of $B_r$ (query tile size)
- The number of query rows per warp ($B_r / n_{\text{warps}}$)
- The size of $B_c$ (key/value tile size)
Note: unless noted otherwise, instruction count formulas are per warp, per K/V tile.
Profiling and Benchmarking
All profiles and benchmarks were run on the A100 with a fixed sequence length and head dimension $d$. For the profiles, each configuration was restricted to an occupancy of 2 CTAs per SM. For the benchmarks, this restriction was lifted.
Here's the general principle: larger block sizes ($B_r$ and $B_c$) reduce the number of non-mma instructions without changing the number of mma instructions. We're computing the same output regardless of block size - the block size just affects how we organize that computation.
Larger block sizes also decrease the total number of iterations executed. This typically results in fewer branches, barriers, and potentially address calculation instructions. However, I'll forgo these, as their impact is typically smaller, harder to measure, or highly variable.
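To make the "same work, different organization" point concrete, here's a rough count (assuming all tensor core work is done with fp16 `mma.m16n8k16` tiles - an assumption of mine, not something stated above): the total mma count depends only on the problem shape, not on the tiling.

$$
N_{\text{mma}} \;\approx\; \frac{\text{total FLOPs}}{\text{FLOPs per mma}} \;=\; \frac{4\, L_q\, L_{kv}\, d}{2 \cdot 16 \cdot 8 \cdot 16}
$$

per head, where $L_q$ and $L_{kv}$ are the query and key/value sequence lengths. $B_r$ and $B_c$ only change how these instructions are distributed across CTAs and loop iterations.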
Cascading Effects From Changing Block Sizes
Changing block sizes triggers cascading effects across multiple dimensions. Occupancy, register pressure, memory access patterns, instruction scheduling, and data dependencies all shift simultaneously, making it challenging to isolate the impact of any single factor. The impacts may also differ across different devices. With that caveat in mind, let's examine how these changes manifest.
The Persist vs. Reload Tradeoff for $Q$
One key design choice is whether to persist $Q$ in the register file for the entire kernel or to reload it from SMEM on every iteration of the K/V loop.

Persisting $Q$ avoids repeated SMEM reads, but it ties up registers for the whole kernel, which limits how far we can push the block sizes before spilling.

Reloading $Q$ adds extra ldmatrix instructions but can avoid register spills. The overhead varies with block size - a larger $B_r$ means more rows of $Q$ to reload each iteration, while a larger $B_c$ amortizes the reload over more K/V work.
So what's the actual cost of reloading? Let's quantify this overhead. With 4 warps per CTA, every warp loads the full K and V tiles from SMEM each iteration ($2 B_c d$ elements), and when reloading it also loads its own $B_r / 4$ rows of $Q$ (another $\frac{B_r d}{4}$ elements). The corresponding ldmatrix counts are these element counts divided by the fixed number of elements each ldmatrix instruction loads - a constant that cancels in the ratio.

Reloading $Q$ therefore executes

$$
\frac{2 B_c d + \frac{B_r d}{4}}{2 B_c d} \;=\; 1 + \frac{B_r}{8 B_c}
$$

times more ldmatrix instructions (excluding the initial load when persisting $Q$):
| Reload multiplier | $B_c = 32$ | $B_c = 64$ |
|---|---|---|
| $B_r = 64$ | 1.25x | 1.125x |
| $B_r = 128$ | 1.5x | 1.25x |
These multipliers show that reload overhead is most significant for large $B_r$ combined with small $B_c$.
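Here's a tiny Python sketch (my own arithmetic, not profiler output) that reproduces these multipliers from the element counts used above:

```python
# A quick sanity check of the reload multiplier, 1 + B_r / (8 * B_c).
# Counts are per warp, per K/V tile, in elements; the head dimension d and the
# elements-per-ldmatrix constant cancel in the ratio, so neither appears here.
def reload_multiplier(b_r: int, b_c: int) -> float:
    kv_elems = 2 * b_c        # K tile + V tile (per unit of d)
    q_elems = b_r / 4         # each of the 4 warps reloads its own B_r / 4 query rows
    return (kv_elems + q_elems) / kv_elems

for b_r in (64, 128):
    for b_c in (32, 64):
        print(f"B_r={b_r:3d}, B_c={b_c:2d}: {reload_multiplier(b_r, b_c):.3f}x")
# B_r= 64, B_c=32: 1.250x
# B_r= 64, B_c=64: 1.125x
# B_r=128, B_c=32: 1.500x
# B_r=128, B_c=64: 1.250x
```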
But how does this play out in practice? The impact of this choice becomes clear when we examine SMEM to register bandwidth utilization on the A100:
| $B_r$ | $B_c$ | Persist $Q$ | Total ldmatrix | SMEM → RF % Peak[^1] |
|---|---|---|---|---|
| 64 | 32 | True → False | 2.7e+08 → 3.4e+08 | 60.01 → 69.19 |
| 64 | 64 | True → False | 2.7e+08 → 3.0e+08 | 62.87 → 72.82 |
| 128 | 32 | True → False | 1.3e+08 → 2.0e+08 | 30.50 → 49.89 |
Note: % peak metrics can increase even when doing less total work. Since these percentages are calculated as (work done / cycles elapsed), fewer cycles may result in higher peak percentages.
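A tiny worked example of that effect (with made-up numbers, purely illustrative): say one configuration moves $W$ bytes of SMEM traffic in 1000 cycles and another moves $0.9W$ bytes in 700 cycles. Then

$$
\frac{0.9W / 700}{W / 1000} \approx 1.29,
$$

so the configuration doing 10% less work still reports roughly 29% higher utilization, simply because it finishes in fewer cycles.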
Reloading $Q$ clearly increases the number of ldmatrix instructions. But this isn't necessarily a bad tradeoff. Persisting $Q$ occupies registers for the entire kernel, capping how many query rows each warp can own before it starts spilling.

The key insight is that reloading $Q$ frees enough registers to support a larger $B_r$ (more query rows per warp), which reduces the total number of ldmatrix instructions for K/V tiles across all CTAs - more than offsetting the reload overhead for $Q$.
How $B_r$ and Query Rows Per Warp Affect Arithmetic Intensity
$B_r$ and L2 Arithmetic Intensity
Doubling $B_r$ halves the number of CTAs needed to cover all query rows (and since we fix n_warps to 4, this also halves the total number of warps).

Why does this matter? Each CTA must copy the entire K/V sequence from GMEM to SMEM. Halving the number of CTAs therefore halves the total number of cp.async instructions for K/V, while the total number of cp.async instructions for $Q$ stays the same:
| $B_r$ | $B_c$ | Total cp.async | L2: Load % Peak[^2] |
|---|---|---|---|
| 64 → 128 | 32 | 6.8e+07 → 3.4e+07 | 44.07% → 23.80% |
| 64 → 128 | 64 | 6.8e+07 → 3.4e+07 | 48.81% → 26.58% |
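Here's a rough model of why the count halves. The problem sizes and the bytes-per-cp.async width below are placeholders of mine (not the benchmark's actual values); only the 2x ratio matters:

```python
import math

def kv_cp_async_total(seq_len: int, d: int, b_r: int, n_heads: int, batch: int,
                      bytes_per_cp_async: int = 16) -> int:
    """Total cp.async instructions spent copying K and V from GMEM to SMEM.

    Every CTA walks the entire K/V sequence, so the count scales with the
    number of CTAs, which scales with 1 / B_r.
    """
    n_ctas = math.ceil(seq_len / b_r) * n_heads * batch
    kv_bytes_per_cta = 2 * seq_len * d * 2   # K and V, fp16 (2 bytes per element)
    return n_ctas * kv_bytes_per_cta // bytes_per_cp_async

problem = dict(seq_len=4096, d=128, n_heads=16, batch=8)   # placeholder problem size
for b_r in (64, 128):
    print(f"B_r={b_r:3d}: {kv_cp_async_total(b_r=b_r, **problem):.2e} K/V cp.async")
# Doubling B_r halves the count - the same 2x drop as 6.8e+07 -> 3.4e+07 above,
# even though the absolute numbers here depend on the made-up problem size.
```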
Query Rows per Warp and SMEM Arithmetic Intensity
At the warp level, we see a similar pattern. Each warp must copy the entire K/V tile from SMEM into its registers. Since each warp issues its mma instructions independently, fewer warps means fewer redundant ldmatrix instructions for K/V.
When the number of warps per CTA is fixed to 4:
- With $B_r = 64$, each warp handles 16 query rows
- With $B_r = 128$, each warp handles 32 query rows
To double the number of rows per warp while avoiding overwhelming register spills, we need to reload $Q$. This adds ldmatrix instructions, but the reduction from increasing query rows per warp more than compensates:
| $B_r$ | Query Rows per Warp | $B_c$ | Persist $Q$ | Total ldmatrix | SMEM → RF % Peak |
|---|---|---|---|---|---|
| 64 → 128 | 16 → 32 | 32 | False | 3.4e+08 → 2.0e+08 | 69.19% → 49.89% |
| 64 → 128 | 16 → 32 | 64 | False | 3.0e+08 → 1.7e+08 | 72.82% → 46.46% |
The instruction count reduction is substantial - nearly halving ldmatrix instructions when doubling the number of rows per warp. This demonstrates why packing more query rows into each warp is generally more efficient, provided we have enough SMEM and can manage the register pressure.
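The near-halving falls out of a simple per-warp load model (again my own sketch; the ldmatrix width and the head dimension cancel in the ratio):

```python
def relative_total_ldmatrix(rows_per_warp: int, b_c: int, reload_q: bool,
                            total_query_rows: int = 4096) -> float:
    """Total ldmatrix count across all warps per K/V tile, up to a constant factor."""
    n_warps = total_query_rows / rows_per_warp   # fewer, fatter warps as rows_per_warp grows
    per_warp = 2 * b_c                           # K tile + V tile elements (per unit of d)
    if reload_q:
        per_warp += rows_per_warp                # this warp's Q rows, reloaded every tile
    return n_warps * per_warp

for b_c in (32, 64):
    old = relative_total_ldmatrix(16, b_c, reload_q=True)
    new = relative_total_ldmatrix(32, b_c, reload_q=True)
    print(f"B_c={b_c}: {old / new:.2f}x fewer ldmatrix going from 16 to 32 rows/warp")
# B_c=32: 1.67x fewer   (profile: 3.4e+08 / 2.0e+08 ~ 1.7x)
# B_c=64: 1.80x fewer   (profile: 3.0e+08 / 1.7e+08 ~ 1.76x)
```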
It's worth stepping back to understand what's happening here:
Arithmetic Intensity
Fewer memory instructions means higher data reuse, which translates to higher arithmetic intensity.
- Increasing $B_r$ increases L2 (GMEM) arithmetic intensity: each CTA loads the entire K/V sequence once but processes more query rows, improving the compute-to-GMEM-load ratio.
- Increasing query rows per warp increases SMEM arithmetic intensity: each warp loads K/V tiles from SMEM once but computes more query rows with them, improving the compute-to-SMEM-load ratio.
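To put rough numbers on both ratios, here's a back-of-the-envelope model of mine (it ignores the softmax, the output write-back, and other lower-order terms): with $L$ the K/V sequence length and $r$ the query rows per warp,

$$
\text{AI}_{\text{L2}} \approx \frac{4\, B_r L d}{(B_r + 2L)\, d} \approx 2 B_r \quad (L \gg B_r),
\qquad
\text{AI}_{\text{SMEM}} \approx \frac{4\, r B_c d}{2 B_c d} = 2 r.
$$

Both are in FLOPs per element loaded: the first per CTA (GMEM traffic), the second per warp per K/V tile (SMEM traffic, persisting $Q$). The scaling is the point - L2 arithmetic intensity grows linearly with $B_r$, and SMEM arithmetic intensity grows linearly with query rows per warp.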
Increasing Query Rows per Warp by Reloading $Q$
When we change the kernel to reload $Q$ and simultaneously double the query rows per warp from 16 to 32 ($B_r$: 64 → 128), the total number of ldmatrix instructions still drops despite the added reloads.

For this transition, the relative decrease works out to $\frac{1}{2} - \frac{8}{B_c}$: 25% fewer ldmatrix instructions at $B_c = 32$ and 37.5% fewer at $B_c = 64$. Equivalently, the new-to-old multiplier is $\frac{1}{2} + \frac{8}{B_c}$:
| $B_r$ | Query Rows per Warp | $B_c$ | Persist $Q$ | Total ldmatrix | Decrease | SMEM → RF % Peak |
|---|---|---|---|---|---|---|
| 64 → 128 | 16 → 32 | 32 | True → False | 2.7e+08 → 2.0e+08 | 25% | 60.01% → 49.89% |
| 64 → 128 | 16 → 32 | 64 | True → False | 2.7e+08 → 1.7e+08 | 37.5% | 62.87% → 46.46% |
Warning
Depending on the ratio of $B_r$ to $B_c$, this change can actually increase the number of ldmatrix instructions. Per K/V tile, halving the number of warps eliminates K/V loads proportional to $B_c$, while the reloads add $Q$ loads proportional to $\frac{B_r}{4}$, for a net change proportional to $\frac{B_r}{4} - B_c$ (with $B_r$ the original, pre-doubling tile size). This results in fewer instructions when $B_r < 4 B_c$.

So when doubling $B_r$, we decrease the number of ldmatrix instructions when the original $B_r < 4 B_c$ (equivalently, the doubled $B_r < 8 B_c$). This is typically the case.
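Plugging in the configurations used here (my arithmetic, using the proportional counts from the warning above), with the original $B_r = 64$:

$$
B_c = 32:\ \frac{B_r}{4} - B_c = 16 - 32 = -16,
\qquad
B_c = 64:\ \frac{B_r}{4} - B_c = 16 - 64 = -48.
$$

Both come out ahead, and dividing by the old per-tile count (proportional to $2 B_c$) recovers the decreases from the table: $16/64 = 25\%$ and $48/128 = 37.5\%$.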
Performance Impact
| $B_r$ | $B_c$ | Persist $Q$ | Relative TFLOPs % |
|---|---|---|---|
| 64 → 128 | 64 | True → False | 84.1 → 100 |
| 64 → 128 | 32 | False | 83.4 → 89.3 |
How $B_c$ Affects Softmax Overhead and Memory Access
The key/value tile size $B_c$ affects the softmax computation overhead. Recall that Flash Attention uses a blocked softmax algorithm that computes the attention incrementally across tiles. From Final Algorithm (where we derived the detailed instruction counts), the per-tile softmax instruction count splits into a part that grows with $B_c$ (the per-element work) and a part that doesn't (the per-tile bookkeeping).
A significant portion of these instructions comes from the blocked softmax overhead - operations needed to correctly combine results across tiles. In that breakdown, the terms highlighted in orange are this overhead; line (5) has no overhead because the scaling is fused with the first sum.
Because the overhead doesn't grow with $B_c$ while the per-element work does, its share shrinks as $B_c$ increases:

- $B_c = 32$: 52.6% of softmax operations are overhead
- $B_c = 64$: 35.7% are overhead
- $B_c = 128$: 21.7% are overhead
Let's put these numbers in perspective. Compared to $B_c = 128$, $B_c = 32$ executes ~65% more softmax instructions and $B_c = 64$ executes ~22% more.

Since softmax instructions account for 65-75% of all non-mma instructions, this translates into a substantially larger non-mma instruction count for the smaller $B_c$ configurations.
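Those two percentages follow directly from the overhead fractions listed above. Assuming the non-overhead softmax work per element is the same for every $B_c$ (consistent with the breakdown above), softmax instructions per element scale as $1 / (1 - f)$, where $f$ is the overhead fraction. A quick check (my arithmetic, not profiler data):

```python
# Overhead fraction of softmax instructions for each B_c, from the list above.
overhead = {32: 0.526, 64: 0.357, 128: 0.217}

# Softmax instructions per attention element, up to a constant factor:
# the non-overhead work per element is constant, so per-element total ~ 1 / (1 - f).
per_element = {b_c: 1.0 / (1.0 - f) for b_c, f in overhead.items()}

base = per_element[128]
for b_c in (32, 64):
    print(f"B_c={b_c}: {per_element[b_c] / base - 1:+.0%} softmax instructions vs. B_c=128")
# B_c=32: +65% softmax instructions vs. B_c=128
# B_c=64: +22% softmax instructions vs. B_c=128
```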
We can see this impact in the profiler data:
| $B_r$ | $B_c$ | Persist $Q$ | Total Softmax Instructions | Pipe Utilization: FMA | Pipe Utilization: ALU + FMA |
|---|---|---|---|---|---|
| 64 | 32 → 64 | True | 1.3e+09 → 9.4e+08 | 27.24% → 16.84% | 39.54% → 26.54% |
| 128 | 32 → 64 | False | 1.3e+09 → 9.4e+08 | 26.58% → 19.50% | 36.93% → 30.27% |
Whether this overhead impacts tensor throughput depends on the device and kernel design. With the right kernel design, this overhead can be made to have minimal to no impact[^3].
Impact on Reloading $Q$
When reloading $Q$, a larger $B_c$ also reduces the total number of ldmatrix instructions: with fewer K/V tiles to iterate over, $Q$ gets reloaded fewer times overall.

In general, the relative decrease from doubling $B_c$ is $\frac{B_r}{2\,(8 B_c + B_r)}$ (with $B_c$ the original, pre-doubling tile size).

So doubling $B_c$ from 32 to 64 gives us 10% ($B_r = 64$) to 16.6% ($B_r = 128$) fewer ldmatrix instructions:
| $B_r$ | $B_c$ | Total ldmatrix | Decrease | SMEM → RF % Peak | SMEM: LDS Wavefronts |
|---|---|---|---|---|---|
| 64 | 32 → 64 | 3.4e+08 → 3.0e+08 | 10% | 69.19% → 72.82% | 1.3e+09 → 1.2e+09 |
| 128 | 32 → 64 | 2.0e+08 → 1.7e+08 | 16.6% (1/6) | 49.89% → 46.46% | 8.1e+08 → 6.7e+08 |
For both values of $B_r$, the drop in ldmatrix instructions also shows up as a drop in SMEM LDS wavefronts.
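Plugging the two configurations into the decrease formula above (my arithmetic):

$$
B_r = 64:\ \frac{64}{2\,(8 \cdot 32 + 64)} = \frac{64}{640} = 10\%,
\qquad
B_r = 128:\ \frac{128}{2\,(8 \cdot 32 + 128)} = \frac{128}{768} = \frac{1}{6} \approx 16.6\%,
$$

matching the Decrease column in the table.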
Performance
| Kernel Iteration | $B_r$ | $B_c$ | Persist $Q$ | TFLOPs Relative to Best |
|---|---|---|---|---|
| 16 | 128 | 32 → 64 | False | 89.3% → 100% |
| 16 | 64 | 32 → 64 | True → False | 83.4% → 84.1% |
Performance Comparison Across Block Sizes
Here's how the different block size configurations stack up against each other, all relative to the best-performing configuration:
| Kernel Iteration | $B_r$ | $B_c$ | Persist $Q$ | TFLOPs Relative to Best |
|---|---|---|---|---|
| 16 | 128 | 64 | False | 100% |
| 16 | 128 | 32 | False | 89.3% |
| 16 | 64 | 64 | True | 84.1% |
| 16 | 64 | 32 | False | 83.4% |
Applicability to Hopper and Blackwell
Ampere is now a couple of generations old, and NVIDIA has made significant architectural improvements in Hopper and Blackwell. Some aspects of the performance impacts we've discussed have been improved upon through new hardware features. Let's examine how these newer architectures change the optimization landscape and which principles still apply.
Hopper and Blackwell Datacenter GPUs
For Hopper and Blackwell DC:
- The discussion on $B_c$ is still fully relevant - the number of FP instructions remains unchanged between architectures. Flash Attention 4, however, improves upon this by avoiding softmax rescaling in cases that preserve numeric stability for 16-bit input tensors.
- For $B_r$ and the number of query rows per warp, we'll look into how Hopper and Blackwell provide architectural improvements.
$B_r$ and Number of Query Rows per Warp
Query Rows per Warp and SMEM Arithmetic Intensity
New features from Hopper and Blackwell DC have made the discussion on query rows per warp and SMEM arithmetic intensity irrelevant.
- Hopper has warpgroup-level `mma`s and Blackwell DC has single-thread-initiated `mma` APIs that operate directly on matrix operands in SMEM
- This means we don't need to load the K/V sequence from SMEM for every warp - only once per warpgroup `mma`, giving us 4-8x higher arithmetic intensity.
To avoid re-reading $Q$ from SMEM on every iteration:

- Hopper `mma`s can load the $A$ matrix (here, $Q$) from the RF
- Blackwell DC has a special buffer that can cache the $A$ (or $B$)[^4] operand instead of re-reading it
In addition, Blackwell DC expanded `mma`s to be able to execute across 2 CTAs at once. In this mode, each CTA independently loads its own portion of the operands while a single `mma` consumes data from both CTAs, further increasing arithmetic intensity.
$B_r$ and L2 Arithmetic Intensity
Our discussion on $B_r$ and L2 arithmetic intensity still applies, though Hopper and Blackwell DC reduce its impact with CTA clusters.
CTA clusters are a new addition to the thread hierarchy, fitted in between grids and CTAs.
- Clusters support memory transfer multicasting between SMs in the same cluster
- This uses inter-SM locality to accelerate broadcast loads from L2 to any number of SMs, likely via a crossbar situated between the SMs and L2
- This datapath was measured to have roughly 32% lower latency than L2, but somewhat lower aggregate throughput across all SMs (Luo et al., 2025)
- This reduces the number of reads from L2/GMEM
Larger Block Sizes
Two resources constrain block sizes: SMEM and RMEM (register memory). Hopper and Blackwell improve on these in different ways to enable larger block sizes.
For both Hopper and Blackwell DC:
- In these matmuls, the $B$ operand doesn't need to be (nor can it be) loaded into the RF, which reduces register pressure
- They both increase the amount of SMEM per SM, from 163KB (Ampere) to 227KB
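To make the SMEM side concrete, here's a rough budget calculator (everything here is an assumption of mine: fp16 tiles, one $Q$ tile plus double-buffered K and V tiles, $d = 128$, and nothing else resident in SMEM):

```python
def smem_kb(b_r: int, b_c: int, d: int = 128, stages: int = 2, elem_bytes: int = 2) -> float:
    """Approximate SMEM per CTA: one Q tile plus `stages` buffered K and V tiles."""
    q_tile = b_r * d * elem_bytes
    kv_tiles = stages * 2 * b_c * d * elem_bytes
    return (q_tile + kv_tiles) / 1024

for b_r, b_c in [(128, 64), (128, 128), (256, 128)]:
    kb = smem_kb(b_r, b_c)
    print(f"B_r={b_r:3d}, B_c={b_c:3d}: {kb:4.0f} KB  "
          f"(fits 163KB: {kb <= 163}, fits 227KB: {kb <= 227})")
# The extra SMEM on Hopper/Blackwell DC leaves headroom for larger or more deeply
# buffered tiles that would not fit in Ampere's 163KB.
```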
Blackwell DC goes even further with memory improvements:
Memory Expansion: Blackwell adds 256KB of TMEM (tensor memory) on top of the 227KB of SMEM and 256KB of RMEM, bringing the total to roughly 739KB of on-chip memory per SM.
Key Benefits for Flash Attention:
- TMEM stores the accumulator matrix during `mma` operations (on Hopper, this must live in RMEM), freeing up register file space
- The $A$ matrix can optionally be cached in a collector buffer for `mma` instructions, avoiding SMEM reloads
  - On Hopper, the $A$ matrix would instead have to be kept in the RF, increasing register pressure
- This extra memory enables even larger block sizes without spilling
Access Restrictions: While powerful, TMEM has more restrictive access patterns than traditional memory:
- Warps can only access specific TMEM rows[^5] based on their warp ID
- Threads are limited to particular rows within TMEM based on their thread ID
- Despite these limitations, this isn't a problem for `mma` workloads since they naturally follow these access patterns
As usual, larger block sizes can lead to tile or wave quantization - performance penalties when problem dimensions don't evenly divide by block sizes, leaving some SMs underutilized.
RTX Lineup
This entire discussion remains fully relevant for non-datacenter GPUs (i.e., Ada and Blackwell GeForce). These GPUs:

- are limited to Ampere-style `mma` instructions - so no warpgroup `mma`s
- have no hardware support for CTA clusters - so no multicasting
It's entirely possible that neither of these features will be added to the GeForce line, as they serve as a form of market segmentation.
Summary
Block size configurations create complex tradeoffs that cascade through multiple aspects of kernel performance, making it challenging to isolate their individual impacts. Nevertheless, we can precisely quantify how different configurations affect instruction mix, giving us solid intuition about their effects.
Takeaways:
- $B_r$ (query tile size): Scales inversely with GMEM → SMEM instructions. Larger $B_r$ reduces the number of CTAs, which reduces total cp.async and ldmatrix instructions for K/V. This leads to better data reuse and higher L2 arithmetic intensity.
- Query rows per warp: Fewer warps means less redundant copying of K/V from SMEM to registers, resulting in higher SMEM arithmetic intensity.
- Reloading $Q$: Adds redundant SMEM → RF loads but enables larger $B_r$ values without excessive register spills. The net effect is fewer overall loads, since the larger $B_r$ dramatically reduces K/V load instructions.
- $B_c$ (key/value tile size): Scales inversely with softmax instruction overhead and (when reloading $Q$) with the number of redundant SMEM → RF copies. Local softmax introduces overhead compared to vanilla softmax due to the need for block scaling; larger $B_c$ decreases this overhead.
Footnotes
[^1]: This metric is `l1tex__data_pipe_lsu_wavefronts_mem_shared_op_ld.sum.pct_of_peak_sustained_elapsed`. It includes both `ldmatrix` and the SMEM → RF portion of loading the output before writing it to GMEM. The latter, however, makes up a negligible fraction of overall loads.
[^2]: This metric is `lts__t_sectors_srcunit_tex_op_read.avg.pct_of_peak_sustained_elapsed`.
[^3]: I'll have a post up soon that goes into more detail about this, separate from this series.
[^4]: With the weight-stationary API.
[^5]: These rows are referred to as lanes in CUDA terminology.