This appendix dives into how block size configurations affect instruction patterns and performance in Flash Attention. If you've been wondering why larger blocks tend to be more efficient, this is for you.

We'll examine how different block configurations affect instruction patterns by analyzing kernel 16 (our final optimized kernel). Since kernel 16 is heavily optimized, it has less noise from unoptimized code generation, giving us cleaner insights into the fundamental tradeoffs.

We'll discuss how the following design choices impact instruction mix and resulting kernel metrics:

  • Reloading vs persisting $Q$ in the RF (register file)
  • The size of $B_q$ (query tile size)
  • The number of query rows per warp ($r$)
  • The size of $B_{kv}$ (key/value tile size)

Note: unless noted otherwise, instruction count formulas are per warp per $K$/$V$ tile.

Profiling and Benchmarking

All profiles and benchmarks were run on the A100 with the same problem size (sequence length and head dimension) across configurations.

For the profiles, each configuration was restricted to an occupancy of 2 CTAs per SM. For the benchmarks, this restriction was lifted.

Here's the general principle: larger block sizes (a larger $B_q$, more query rows per warp, or a larger $B_{kv}$) reduce total redundant operations through better data reuse, but require more SMEM and registers. What stays constant, though, is the total number of mma instructions. We're computing the same output regardless of block size - the block size just affects how we organize that computation.

Larger block sizes also decrease the total number of iterations executed. This typically results in fewer branches, barriers, and potentially fewer address-calculation instructions. However, I'll set these aside, as their impact is typically smaller, harder to measure, or highly variable.
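To make the "mma count is invariant" point concrete, here is a tiny back-of-the-envelope script (Python, with an illustrative problem size - the actual sequence length and head dimension used in the profiles aren't reproduced here), tallying tile-level MAC work and iteration counts for the configurations covered in this appendix:

```python
# Rough sanity check: the total tile-level MAC work is fixed by the problem
# size; block sizes only change how many tiles/iterations it is split into.
seq_len, d_head = 4096, 128   # illustrative values only

def tile_stats(B_q, B_kv):
    n_cta = seq_len // B_q               # one CTA per query tile
    iters_per_cta = seq_len // B_kv      # K/V tiles each CTA iterates over
    # MACs for S = Q K^T plus O += P V, summed over all tiles
    macs = n_cta * iters_per_cta * (2 * B_q * B_kv * d_head)
    return n_cta, iters_per_cta, macs

for B_q, B_kv in [(64, 32), (64, 64), (128, 32), (128, 64)]:
    n_cta, iters, macs = tile_stats(B_q, B_kv)
    print(f"B_q={B_q:3d}, B_kv={B_kv:2d}: CTAs={n_cta:3d}, iters/CTA={iters:4d}, MACs={macs:.3e}")
```

The MAC total comes out identical for every configuration; only the CTA and iteration counts change.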

Cascading Effects From Changing Block Sizes

Changing block sizes triggers cascading effects across multiple dimensions. Occupancy, register pressure, memory access patterns, instruction scheduling, and data dependencies all shift simultaneously, making it challenging to isolate the impact of any single factor. The impacts may also differ across different devices. With that caveat in mind, let's examine how these changes manifest.

The Persist vs. Reload Tradeoff for $Q$

One key design choice is whether to persist $Q$ in registers throughout the main loop or reload it from SMEM for each $K$/$V$ tile. This decision involves a fundamental tradeoff:

Persisting $Q$: We keep the query tiles in registers for the $QK^T$ multiplication across all iterations. This creates higher register pressure and can lead to costly register spills, especially with larger $B_q$ values.

Reloading $Q$: We free registers by reloading query tiles from SMEM when needed. This adds ldmatrix instructions but can avoid register spills. The overhead varies with block size - a larger $B_q$ means more data to reload per CTA.

So what's the actual cost of reloading $Q$? Let's quantify this overhead. With 4 warps per CTA, each warp loads the two $B_{kv} \times d$ tiles of $K$ and $V$ plus its own $\tfrac{B_q}{4} \times d$ slice of $Q$, so the total number of ldmatrix instructions per $K$/$V$ tile per warp is

$$\frac{\left(2\,B_{kv} + \tfrac{B_q}{4}\right) d}{E}$$

where $d$ is the head dimension. The denominator comes from each instruction loading $E$ elements (256 for ldmatrix.x4 with 16-bit operands).

Reloading $Q$ compared to persisting it results in

$$1 + \frac{B_q/4}{2\,B_{kv}} = 1 + \frac{B_q}{8\,B_{kv}}$$

times more ldmatrix instructions (excluding the initial load when persisting $Q$). The impact scales proportionally with $B_q$ and inversely with $B_{kv}$.

| $B_q$ \ $B_{kv}$ | 32 | 64 |
|---|---|---|
| 64 | 1.25x | 1.125x |
| 128 | 1.5x | 1.25x |

These multipliers show that reload overhead is most significant for large-$B_q$, small-$B_{kv}$ configurations - exactly where we'd expect the most register pressure from persisting $Q$.
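As a sanity check, the multiplier above can be evaluated directly. This small Python helper (an illustration of the formula, not kernel code) reproduces the table:

```python
# ldmatrix multiplier for reloading Q vs persisting it, per K/V tile per warp.
# Each warp loads 2*B_kv*d elements of K/V either way; reloading adds its
# (B_q/4)*d slice of Q (4 warps per CTA). The elements-per-instruction factor
# cancels in the ratio.
def reload_multiplier(B_q, B_kv):
    return 1 + (B_q / 4) / (2 * B_kv)

for B_q in (64, 128):
    for B_kv in (32, 64):
        print(f"B_q={B_q:3d}, B_kv={B_kv:2d}: {reload_multiplier(B_q, B_kv):.3f}x")
```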

But how does this play out in practice? The impact of this choice becomes clear when we examine SMEM to register bandwidth utilization on the A100:

| $B_q$ | $B_{kv}$ | Persist $Q$ | Total ldmatrix | SMEM → RF % Peak¹ |
|---|---|---|---|---|
| 64 | 32 | True → False | 2.7e+08 → 3.4e+08 | 60.01 → 69.19 |
| 64 | 64 | True → False | 2.7e+08 → 3.0e+08 | 62.87 → 72.82 |
| 128 | 32 | True → False | 1.3e+08 → 2.0e+08 | 30.50 → 49.89 |

Note: % peak metrics can increase even when doing less total work. Since these percentages are calculated as (work done / cycles elapsed), fewer cycles may result in higher peak percentages.

Reloading $Q$ clearly increases ldmatrix instructions. But this isn't necessarily a bad tradeoff. Persisting $Q$ creates high register pressure, limiting our options: we can either use a smaller $B_q$ (which avoids reload overhead) or a larger $B_q$ (which causes register spills if we persist $Q$).

The key insight is that reloading enables larger $B_q$ values with manageable register pressure. For instance, with $B_q = 128$, persisting $Q$ causes 336 register spills, while reloading reduces this to just 56. More importantly, as we'll see next, the larger $B_q$ value significantly reduces the total number of ldmatrix instructions for $K$/$V$ tiles across all CTAs - more than offsetting the reload overhead for $Q$.

How $B_q$ and Query Rows per Warp Affect Arithmetic Intensity

$B_q$ and L2 Arithmetic Intensity

$B_q$ scales inversely with the number of CTAs: doubling $B_q$ halves the total number of CTAs (since we fix n_warps to 4, this also halves the total number of warps).

Why does this matter? Each CTA must copy the entire $K$/$V$ sequences to SMEM regardless of how many query rows it processes. The cp.async instructions for $Q$ remain unchanged - we still load the same total amount of query data. Therefore, doubling $B_q$ halves the number of CTAs, which halves the total cp.async instructions for $K$/$V$. We can observe this reduction in the L2 load bandwidth utilization:

| $B_q$ | $B_{kv}$ | Total cp.async | L2: Load % Peak² |
|---|---|---|---|
| 64 → 128 | 32 | 6.8e+07 → 3.4e+07 | 44.07% → 23.80% |
| 64 → 128 | 64 | 6.8e+07 → 3.4e+07 | 48.81% → 26.58% |
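The cp.async numbers follow directly from the CTA count. A minimal sketch (Python; the sequence length is an assumed placeholder, so only the ratio is meaningful):

```python
# K/V cp.async traffic is proportional to the number of CTAs (seq_len / B_q):
# every CTA streams the full K/V sequence into SMEM.
seq_len = 4096  # placeholder; only the ratio below matters

def n_ctas(B_q):
    return seq_len // B_q

old_bq, new_bq = 64, 128
print(f"B_q {old_bq} -> {new_bq}: CTAs {n_ctas(old_bq)} -> {n_ctas(new_bq)}, "
      f"K/V cp.async ratio = {n_ctas(new_bq) / n_ctas(old_bq):.2f}x")
```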

Query Rows per Warp and SMEM Arithmetic Intensity

At the warp level, we see a similar pattern. Each warp must copy the entire $K$/$V$ sequences from SMEM to registers, regardless of how many query rows it processes. The more rows a warp handles, the fewer total warps we need. Since warps execute mma instructions independently, fewer warps means fewer redundant ldmatrix instructions for $K$/$V$ tiles.

When the number of warps per CTA is fixed to 4:

  • With $B_q = 64$, each warp handles $r = 16$ query rows
  • With $B_q = 128$, each warp handles $r = 32$ query rows

To double the number of rows per warp without overwhelming register spills, we need to reload $Q$ each iteration. This adds some ldmatrix instructions, but the reduction from increasing query rows per warp more than compensates:

| $B_q$ | Rows per warp | $B_{kv}$ | Persist $Q$ | Total ldmatrix | SMEM → RF % Peak |
|---|---|---|---|---|---|
| 64 → 128 | 16 → 32 | 32 | False | 3.4e+08 → 2.0e+08 | 69.19% → 49.89% |
| 64 → 128 | 16 → 32 | 64 | False | 3.0e+08 → 1.7e+08 | 72.82% → 46.46% |

The instruction count reduction is substantial - nearly halving ldmatrix instructions when doubling the number of rows per warp. This demonstrates why having more rows per warp is generally more efficient, provided we have enough SMEM and can manage register pressure.
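The ratios in the table can be approximated with the same element counting used earlier. A quick sketch (Python; both configurations reload $Q$, and the total number of query rows is an arbitrary placeholder since only ratios matter):

```python
# Total per-K/V-tile ldmatrix traffic summed over all warps, in relative units
# (head dim and elements-per-instruction cancel). Both configs reload Q, so a
# warp loads K, V, and its own query rows from SMEM every iteration.
def total_ldmatrix_units(rows_per_warp, B_kv, n_query_rows=4096):
    n_warps = n_query_rows // rows_per_warp
    return n_warps * (2 * B_kv + rows_per_warp)

for B_kv in (32, 64):
    old = total_ldmatrix_units(16, B_kv)
    new = total_ldmatrix_units(32, B_kv)
    print(f"B_kv={B_kv}: 16 -> 32 rows/warp keeps {new / old:.1%} of ldmatrix instructions")
```

The predicted 60% / 56% roughly match the measured 2.0e+08 / 3.4e+08 ≈ 59% and 1.7e+08 / 3.0e+08 ≈ 57%.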

It's worth stepping back to understand what's happening here:

Arithmetic Intensity

Fewer memory instructions means higher data reuse, which translates to higher arithmetic intensity.

  • Increasing $B_q$ increases L2 (GMEM) arithmetic intensity: Each CTA loads the entire K/V sequence once but processes more query rows, improving the compute-to-GMEM-load ratio.
  • Increasing query rows per warp increases SMEM arithmetic intensity: Each warp loads K/V tiles from SMEM once but computes more query rows with them, improving the compute-to-SMEM-load ratio. (A rough accounting of both follows below.)
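As a rough accounting (my own back-of-the-envelope, counting only the two matmuls per tile and ignoring softmax, masking, and the output write; $N$ is the sequence length, $d$ the head dimension, $r$ the query rows per warp, and tensors are 16-bit):

$$\mathrm{AI}_{\mathrm{L2}} \approx \frac{4\,B_q N d\ \text{FLOPs}}{2\,(B_q + 2N)\,d\ \text{bytes}} = \frac{2\,B_q N}{B_q + 2N} \approx B_q \quad (N \gg B_q)$$

$$\mathrm{AI}_{\mathrm{SMEM}} \approx \frac{4\,r\,B_{kv}\,d\ \text{FLOPs}}{2\,(2 B_{kv} + r)\,d\ \text{bytes}} = \frac{2\,r\,B_{kv}}{2 B_{kv} + r} \approx r \quad (B_{kv} \gg r)$$

Both intensities saturate at roughly the tile dimension being reused, which is why growing $B_q$ helps at the L2 level and growing the rows per warp helps at the SMEM level.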

Increasing Query Rows per Warp by Reloading $Q$

When we change the kernel to reload $Q$ in order to increase the number of query rows per warp by a factor of $k$, the ldmatrix count per $K$/$V$ tile (summed over all warps) changes as follows:

$$W \cdot \frac{2\,B_{kv}\,d}{E} \;\longrightarrow\; \frac{W}{k} \cdot \frac{(2\,B_{kv} + k\,r)\,d}{E}$$

where $W$ is the original number of warps and $r$ the original number of query rows per warp. The relative decrease is

$$1 - \frac{1}{k} - \frac{r}{2\,B_{kv}}$$

ldmatrix instructions.

Equivalently, the new-to-old multiplier is

$$\frac{1}{k} + \frac{r}{2\,B_{kv}}$$
| $B_q$ | Rows per warp | $B_{kv}$ | Persist $Q$ | Total ldmatrix | Decrease | SMEM → RF % Peak |
|---|---|---|---|---|---|---|
| 64 → 128 | 16 → 32 | 32 | True → False | 2.7e+08 → 2.0e+08 | 25% | 60.01% → 49.89% |
| 64 → 128 | 16 → 32 | 64 | True → False | 2.7e+08 → 1.7e+08 | 37.5% | 62.87% → 46.46% |
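Evaluating the multiplier for our configurations (Python sketch; $r = 16$ is the starting rows per warp):

```python
# New-to-old ldmatrix multiplier when scaling query rows per warp by k while
# switching from persisting Q to reloading it. r is the rows per warp before
# scaling (16 for B_q = 64 with 4 warps).
def persist_to_reload_multiplier(k, r, B_kv):
    return 1 / k + r / (2 * B_kv)

for B_kv in (32, 64):
    m = persist_to_reload_multiplier(k=2, r=16, B_kv=B_kv)
    print(f"B_kv={B_kv}: multiplier {m:.3f} -> {1 - m:.1%} fewer ldmatrix instructions")
```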

Warning

Depending on the ratio of query rows per warp to $B_{kv}$, we can actually increase the number of ldmatrix instructions.

Per original warp, we eliminate $2\,B_{kv}\left(1 - \tfrac{1}{k}\right)\tfrac{d}{E}$ instructions for $K$/$V$ but add $r\,\tfrac{d}{E}$ for reloading $Q$, with a net change of $\left(r - 2\,B_{kv}\left(1 - \tfrac{1}{k}\right)\right)\tfrac{d}{E}$.

This results in fewer instructions when

$$2\,B_{kv}\left(1 - \tfrac{1}{k}\right) > r$$

So when doubling the rows per warp ($k = 2$), we decrease the amount of ldmatrix instructions when $B_{kv} > r$ - here, $B_{kv} > 16$. This is typically the case.

Performance Impact

| $B_q$ | $B_{kv}$ | Persist $Q$ | Relative TFLOPs % |
|---|---|---|---|
| 64 → 128 | 64 | True → False | 84.1 → 100 |
| 64 → 128 | 32 | False | 83.4 → 89.3 |

How $B_{kv}$ Affects Softmax Overhead and Memory Access

The key/value tile size $B_{kv}$ affects the softmax computation overhead. Recall that Flash Attention uses a blocked softmax algorithm that computes the attention incrementally across tiles. From Final Algorithm (where we derived the detailed instruction counts), the softmax computation for each $K$/$V$ tile executes a total instruction count made up of a fixed per-tile overhead plus a term that grows with $B_{kv}$.

A significant portion of these instructions comes from the blocked softmax overhead - operations needed to correctly combine results across tiles. In the annotated listing from Final Algorithm, the values in orange are this local-softmax overhead; line (5) has no overhead because the scaling is fused with the first sum.

The overhead adds up to a fixed number of instructions per $K$/$V$ tile, while the remaining softmax work grows with $B_{kv}$. The fraction of softmax instructions spent on this overhead is the fixed per-tile overhead divided by the total per-tile count, so it scales inversely with $B_{kv}$:

  • $B_{kv} = 32$: 52.6% of softmax operations are overhead
  • $B_{kv} = 64$: 35.7% are overhead
  • $B_{kv} = 128$: 21.7% are overhead

Let's put these numbers in perspective. Compared to $B_{kv} = 128$,

  • $B_{kv} = 32$ executes ~65% more softmax instructions and
  • $B_{kv} = 64$ executes ~22% more

Since softmax instructions account for 65-75% of all non-mma instructions, this translates to $B_{kv} = 32$ executing roughly 30% more non-mma instructions than $B_{kv} = 64$.
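A small Python sketch ties these numbers together. It models per-tile softmax instructions as a fixed overhead plus a term proportional to $B_{kv}$ (the structure described above), and it backs the overhead-to-per-element ratio out of the stated 52.6% at $B_{kv} = 32$ rather than re-deriving the exact counts:

```python
# overhead fraction f(B_kv) = c / (c + B_kv), where c is the fixed per-tile
# overhead in "per-element instruction" units. Recover c from the stated
# 52.6% at B_kv = 32, then check the remaining figures against it.
f32 = 0.526
c = f32 / (1 - f32) * 32          # ~35.5

def overhead_fraction(B_kv):
    return c / (c + B_kv)

def total_softmax(B_kv):
    # softmax instructions over the whole sequence, up to a constant factor
    return (c + B_kv) / B_kv

for B_kv in (32, 64, 128):
    print(f"B_kv={B_kv:3d}: overhead {overhead_fraction(B_kv):.1%}, "
          f"total vs B_kv=128: {total_softmax(B_kv) / total_softmax(128):.2f}x")
```

This reproduces the 35.7% / 21.7% overhead figures and the ~1.65x / ~1.22x instruction totals quoted above.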

We can see this impact in the profiler data:

| $B_q$ | $B_{kv}$ | Persist $Q$ | Total Softmax Instructions | Pipe Utilization: FMA | Pipe Utilization: ALU + FMA |
|---|---|---|---|---|---|
| 64 | 32 → 64 | True | 1.3e+09 → 9.4e+08 | 27.24% → 16.84% | 39.54% → 26.54% |
| 128 | 32 → 64 | False | 1.3e+09 → 9.4e+08 | 26.58% → 19.50% | 36.93% → 30.27% |

Whether this overhead impacts tensor throughput depends on the device and kernel design. With the right kernel design, this overhead can be made to have minimal to no impact³.

Impact on Reloading $Q$

When reloading $Q$, the number of redundant $Q$ copies from SMEM to registers also scales inversely with $B_{kv}$: a larger $B_{kv}$ means fewer $K$/$V$ tiles, and therefore fewer iterations in which $Q$ has to be reloaded. If we scale $B_{kv}$ by a factor of $k$, the per-warp ldmatrix count over the main loop changes as follows:

$$\frac{N}{B_{kv}} \cdot \frac{(2\,B_{kv} + r)\,d}{E} \;\longrightarrow\; \frac{N}{k\,B_{kv}} \cdot \frac{(2\,k\,B_{kv} + r)\,d}{E}$$

where $N$ is the sequence length. In general, the relative decrease is

$$\frac{r\left(1 - \tfrac{1}{k}\right)}{2\,B_{kv} + r}$$

So doubling $B_{kv}$ ($k = 2$) results in

$$\frac{r}{2\,(2\,B_{kv} + r)}$$

fewer instructions.

| $B_q$ | $B_{kv}$ | Total ldmatrix | Decrease | SMEM → RF % Peak | SMEM: LDS Wavefronts |
|---|---|---|---|---|---|
| 64 | 32 → 64 | 3.4e+08 → 3.0e+08 | 10% | 69.19% → 72.82% | 1.3e+09 → 1.2e+09 |
| 128 | 32 → 64 | 2.0e+08 → 1.7e+08 | 16.6% (1/6) | 49.89% → 46.46% | 8.1e+08 → 6.7e+08 |

For $B_q = 64$, increasing $B_{kv}$ from 32 → 64 increases the SMEM → RF % peak, but it still represents an overall decrease in work done, since the kernel runs for fewer cycles.
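Evaluating the relative-decrease formula for both $r$ values (Python sketch):

```python
# With Q reloaded, per-warp ldmatrix traffic over the main loop is
# proportional to (2*B_kv + r) / B_kv, so scaling B_kv by k changes it by:
def relative_decrease(r, B_kv, k):
    old = (2 * B_kv + r) / B_kv
    new = (2 * k * B_kv + r) / (k * B_kv)
    return 1 - new / old

for r in (16, 32):
    print(f"r={r}: B_kv 32 -> 64 cuts ldmatrix instructions by {relative_decrease(r, 32, 2):.1%}")
```

This matches the 10% and 1/6 decreases in the table.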

Performance

| Kernel Iteration | $B_q$ | $B_{kv}$ | Persist $Q$ | TFLOPs relative to best |
|---|---|---|---|---|
| 16 | 128 | 32 → 64 | False | 89.3% → 100% |
| 16 | 64 | 32 → 64 | True → False | 83.4% → 84.1% |

Performance Comparison Across Block Sizes

Here's how the different block size configurations stack up against each other, all relative to the best-performing configuration:

| Kernel Iteration | $B_q$ | $B_{kv}$ | Persist $Q$ | TFLOPs relative to best |
|---|---|---|---|---|
| 16 | 128 | 64 | False | 100% |
| 16 | 128 | 32 | False | 89.3% |
| 16 | 64 | 64 | True | 84.1% |
| 16 | 64 | 32 | False | 83.4% |

Applicability to Hopper and Blackwell

Ampere is now a couple of generations old, and NVIDIA has made significant architectural improvements in Hopper and Blackwell. Some aspects of the performance impacts we've discussed have been improved upon through new hardware features. Let's examine how these newer architectures change the optimization landscape and which principles still apply.

Hopper and Blackwell Datacenter GPUs

For Hopper and Blackwell DC:

  • The discussion on $B_{kv}$ is still fully relevant - the number of FP instructions remains unchanged between architectures. Flash Attention 4, however, improves on this by skipping the softmax rescaling in cases where numeric stability is preserved for 16-bit input tensors.
  • For $B_q$ and the number of query rows per warp, we'll look at how Hopper and Blackwell provide architectural improvements.

$B_q$ and Number of Query Rows per Warp

Query Rows per Warp and SMEM Arithmetic Intensity

New features from Hopper and Blackwell DC have made the discussion on query rows per warp and SMEM arithmetic intensity irrelevant.

  • Hopper has warpgroup-level mmas and Blackwell DC has single-thread-initiated mma APIs that operate directly on the $K$/$V$ matrices in SMEM
  • This means we don't need to load the $K$/$V$ sequence from SMEM for every warp - only once per warpgroup/mma, giving us 4-8x higher arithmetic intensity.

To avoid reading the $Q$ matrix multiple times:

  • Hopper mmas can load the $A$ operand (here, $Q$) from the RF
  • Blackwell DC has a special buffer that can cache the $A$ (or $B$)⁴ operand instead of re-reading it

In addition, Blackwell DC expanded mmas to be able to execute across 2 CTAs at once. In this mode, each CTA independently loads its own $Q$ tile, but only has to transfer half of the $K$ or $V$ tile. This increases both SMEM and L2 arithmetic intensity.

$B_q$ and L2 Arithmetic Intensity

Our discussion of $B_q$ and L2 arithmetic intensity is still relevant on these newer architectures. No matter what, we still need to read the data from L2/GMEM, and the larger $B_q$ is, the fewer CTAs we have, the fewer GMEM copies we make, and the higher the L2 arithmetic intensity. Hopper and Blackwell DC, however, improve on this with CTA clusters and multicasting.

CTA clusters are a new addition to the thread hierarchy, fitted in between grids and CTAs.

  • Clusters support memory transfer multicasting between SMs in the same cluster
  • This uses inter-SM locality to accelerate broadcast loads from L2 to any number of SMs, likely via a crossbar situated between the SMs and L2
  • This datapath was measured to have roughly 32% lower latency than L2, but somewhat lower aggregate throughput across all SMs (Luo et al., 2025)
  • This reduces the number of reads from L2/GMEM

Larger Block Sizes

Two resources constrain block sizes: SMEM and RMEM (register memory). Hopper and Blackwell improve on these in different ways to enable larger block sizes.

For both Hopper and Blackwell DC:

  • In these matmuls, the $K$/$V$ operand doesn't need to be (nor can it be) loaded into the RF, which reduces register pressure
  • They both increase the amount of SMEM per SM, from 163KB (Ampere) to 227KB; a rough budget sketch follows below
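To see why the extra SMEM matters for block sizes, here is a very rough budget sketch (Python). It assumes fp16 tiles for $Q$, $K$, and $V$, a single-buffered layout, an assumed head dimension of 128, and the 2-CTAs-per-SM occupancy used for the profiles earlier; real kernels add staging/double buffers, so treat it as illustrative only:

```python
# Rough per-SM SMEM footprint for a (B_q, B_kv) configuration with fp16 tiles.
BYTES = 2              # fp16
d_head = 128           # assumed head dimension
CTAS_PER_SM = 2        # occupancy used in the profiles above
CAPACITY_KB = {"Ampere": 163, "Hopper/Blackwell DC": 227}

def smem_per_cta_kb(B_q, B_kv):
    # Q tile + K tile + V tile resident in SMEM, single-buffered
    return (B_q + 2 * B_kv) * d_head * BYTES / 1024

for B_q, B_kv in [(128, 64), (256, 64), (256, 128)]:
    per_sm = smem_per_cta_kb(B_q, B_kv) * CTAS_PER_SM
    verdict = ", ".join(f"{name}: {'fits' if per_sm <= cap else 'over'}"
                        for name, cap in CAPACITY_KB.items())
    print(f"B_q={B_q:3d}, B_kv={B_kv:3d}: {per_sm:6.1f} KB/SM  ({verdict})")
```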

Blackwell DC goes even further with memory improvements:

Memory Expansion: Blackwell adds 256KB of TMEM (tensor memory) on top of the 227KB of SMEM and 256KB of RMEM, bringing the total to roughly 739KB of on-chip memory per SM.

Key Benefits for Flash Attention:

  • TMEM stores the accumulator matrix during mma operations (on Hopper, this must use RMEM), freeing up register file space
  • The $Q$ matrix can optionally be cached in a collector buffer for mma instructions, avoiding SMEM reloads
    • On Hopper, the $Q$ matrix would instead need to be stored in the RF, increasing register pressure
  • This extra memory enables even larger block sizes without spilling

Access Restrictions: While powerful, TMEM has more restrictive access patterns than traditional memory:

  • Warps can only access specific TMEM rows⁵ based on their warp ID
  • Threads are limited to particular rows within TMEM based on their thread ID
  • Despite these limitations, this isn't a problem for mma workloads since they naturally follow these access patterns

As usual, larger block sizes can lead to tile or wave quantization - performance penalties when problem dimensions don't evenly divide by block sizes, leaving some SMs underutilized.

RTX Lineup

This entire discussion remains fully relevant for non-datacenter GPUs (i.e., Ada and Blackwell GeForce). These GPUs

  • are limited to Ampere style mma instructions - so no warpgroup mmas
  • have no hardware support for CTA clusters - so no multicasting

It's entirely possible that neither of these features will be added to the GeForce line, as they serve as a form of market segmentation.

Summary

Block size configurations create complex tradeoffs that cascade through multiple aspects of kernel performance, making it challenging to isolate their individual impacts. Nevertheless, we can precisely quantify how different configurations affect instruction mix, giving us solid intuition about their effects.

Takeaways:

$B_q$ (query tile size): The number of GMEM → SMEM copy instructions scales inversely with $B_q$. A larger $B_q$ reduces the number of CTAs, which reduces the total cp.async and ldmatrix instructions for $K$/$V$. This leads to better data reuse and higher L2 arithmetic intensity.

Query rows per warp: Fewer warps means less redundant copying of $K$/$V$ from SMEM to registers, resulting in higher SMEM arithmetic intensity.

Reloading $Q$: Adds redundant SMEM → RF loads for $Q$ but enables larger $B_q$ values without excessive register spills. The net effect is fewer loads overall, since the larger $B_q$ dramatically reduces the $K$/$V$ load instructions.

$B_{kv}$ (key/value tile size): The softmax instruction overhead and (when reloading $Q$) the number of redundant SMEM → RF copies both scale inversely with $B_{kv}$. Local softmax introduces overhead compared to vanilla softmax due to the need for block rescaling; a larger $B_{kv}$ amortizes this overhead over more work.

Footnotes

  1. This metric is l1tex__data_pipe_lsu_wavefronts_mem_shared_op_ld.sum.pct_of_peak_sustained_elapsed. It includes both ldmatrix and the SMEM → RF portion of loading $O$ before writing it to GMEM. The latter, however, makes up a negligible fraction of the overall loads.

  2. This metric is lts__t_sectors_srcunit_tex_op_read.avg.pct_of_peak_sustained_elapsed.

  3. I'll have a post up soon that goes into more detail about this, separate from this series.

  4. With the weight stationary API

  5. These rows are referred to as lanes in CUDA terminology.