Intro

In the previous part, we implemented swizzling and achieved a dramatic 2x performance improvement by eliminating bank conflicts. Our kernel now reaches 98% of the reference implementation's performance on the RTX 3090.

In this part, we'll implement standard GEMM optimization techniques used in Cutlass, NVIDIA's C++ template library for high-performance linear algebra kernels. Since we've already captured most of the available performance gains, the improvements may seem more modest. However, these techniques are crucial for:

  • Memory latency hiding: Overlapping data transfers with computation to reduce idle time
  • Register pressure reduction: Enabling larger block sizes and more kernel configurations

Some optimizations may even cause slight performance regressions with our current configuration, but they enable better performance with other block sizes and provide auto-tuning variables in later parts.

We'll implement three progressive optimizations:

Kernel 3: Eager K/V Loading - We'll implement double buffering at the GMEM → SMEM level by loading the next K and V blocks while computing on the current ones. This classic optimization technique overlaps GMEM transfers with computation, reducing GMEM stalls by 93%.

Kernel 4: Fragment Interleaving - Instead of loading entire tiles into RF before computation, we'll divide them into sub-tiles and interleave loading with matrix operations. This reduces the time between the first SMEM → RF transfer and the first GEMM operation while dramatically reducing register pressure across all block configurations.

Kernel 5: Double Buffering SMEM → RF - We'll extend double buffering to fragment loads, allocating extra register space to hide ldmatrix latency. While this shows a slight regression for our current configuration, it enables significant performance gains for other block sizes crucial for auto-tuning.

Kernel 3: Eagerly Loading K & V Blocks (Double Buffering GMEM → SMEM)

Our current kernel loads K and V blocks from GMEM just before copying them to the RF, causing warps to idle during the transfers.

This inefficiency shows up clearly in our profiling: 15.15% of warp stalls come from waiting on GMEM transfers (long_scoreboard), compared to ~0.43% on the reference kernel. This indicates we have substantial room for improvement through better memory scheduling.

The solution: Start loading blocks much earlier while computing on previous data, so transfers complete before they're needed.

Determining Safe Load Points

To implement this optimization safely, we must respect synchronization constraints:

Memory dependencies:

  • K and V: Require __syncthreads() between GMEM→SMEM and SMEM→RF (all warps cooperate)
  • Between iterations: Need __syncthreads() to ensure all warps finish reading before overwriting SMEM

Question: When can we start loading K and V from GMEM → SMEM without causing race conditions?

Here's a simplified version of our kernel in Python pseudocode with the barriers and cp.async.wait() removed.

# Prologue
cp.async_and_commit(Q_SM, Q_GM[offset])
 
for blk in range(T_c):
	cp.async_and_commit(K_SM, K_GM[blk])
	load_SMEM2RF(K_RF, K_SM)
	gemm(S_RF, Q_RF, K_RF^T)
	
	online_softmax(S_RF, m_RF, l_RF, O_RF)
	convert_to_b16(P_RF, S_RF)	
	cp.async_and_commit(V_SM, V_GM[blk])
	load_SMEM2RF(V_RF, V_SM)
	gemm(O_RF, P_RF, V_RF)
 
# Epilogue
# ...

The earliest point we can issue either load is once we no longer need to read the SMEM holding the previous block, which is right after the corresponding load_SMEM2RF. Placing the loads just after the gemms, however, gives the compiler more flexibility to schedule around them.

  • For K, that's after gemm(S_RF, Q_RF, K_RF^T) (line 7).
  • For V, that's after gemm(O_RF, P_RF, V_RF) (line 13).

Let's "push" each load upwards to these respective points. For , we'll need to add a special case for an initial load.

# Prologue
cp.async_and_commit(Q_SM, Q_GM[offset])
cp.async_and_commit(K_SM, K_GM[0])
 
for blk in range(T_c):
	cp.async_and_commit(V_SM, V_GM[blk])
	load_SMEM2RF(K_RF, K_SM)
	gemm(S_RF, Q_RF, K_RF^T)
	
	online_softmax(S_RF, m_RF, l_RF, O_RF)
 
	if blk < T_c - 1:
		cp.async_and_commit(K_SM, K_GM[blk+1])
	convert_to_b16(P_RF, S_RF)
	load_SMEM2RF(V_RF, V_SM)
	gemm(O_RF, P_RF, V_RF)
 
# Epilogue
# ...

Now we need to add synchronization barriers. The key insight is that we need two barrier points to prevent race conditions:

| Barrier location | Purpose |
|---|---|
| Loop start | Ensure the K transfer has completed before it's read; prevent the new V copy from overwriting SMEM that other warps may still be reading |
| After softmax | Ensure the V transfer has completed before it's read; ensure all warps have finished reading K before it's overwritten |

# Prologue
cp.async_and_commit(Q_SM, Q_GM[offset])
cp.async_and_commit(K_SM, K_GM[0])
 
for blk in range(T_c):
	cp.wait()
	__syncthreads()
	cp.async_and_commit(V_SM, V_GM[blk])
	load_SMEM2RF(K_RF, K_SM)
	gemm(S_RF, Q_RF, K_RF^T)
	
	online_softmax(S_RF, m_RF, l_RF, O_RF)
 
	cp.wait()
	__syncthreads()
	if blk < T_c - 1:
		cp.async_and_commit(K_SM, K_GM[blk+1])
	convert_to_b16(P_RF, S_RF)
	load_SMEM2RF(V_RF, V_SM)
	gemm(O_RF, P_RF, V_RF)
 
# Epilogue
# ...

This arrangement maximizes overlap between memory transfers and computation while ensuring correct synchronization.

Here's what the execution flow looks like now:

Double Buffering

What we've just implemented is a classic optimization technique called double buffering. The idea is simple: instead of waiting around doing nothing while data loads, we keep two buffers and alternate between them. While we're computing on one buffer, we're simultaneously loading data into the other.

In our case, we allocate extra SMEM space for both the K and V blocks. While we're busy computing on the current blocks, we're already loading the next ones in the background. This way, warps don't have to sit idle waiting for memory transfers.

The nice thing is we already allocated the extra SMEM slice back in Occupancy - we just weren't using it effectively until now.

Here's how double buffering transforms this pattern in the general case:
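To make the general pattern concrete in code, here is a minimal, self-contained CUDA sketch of GMEM → SMEM double buffering using the __pipeline intrinsics (Ampere and newer). It is not our attention kernel - the tile size, kernel name, and the trivial "compute" are purely illustrative - but it shows the two-buffer ping-pong and the two synchronization points discussed above.

double_buffer_sketch.cu
#include <cuda_pipeline.h>

constexpr int TILE = 128; // elements per tile; one thread per element (illustrative)

__global__ void double_buffered_sum(const float* __restrict__ in,
                                    float* __restrict__ out, int num_tiles) {
    __shared__ float buf[2][TILE];
    const float* src = in + (size_t)blockIdx.x * num_tiles * TILE;
    int tid = threadIdx.x; // assumes blockDim.x == TILE
    float acc = 0.0f;
    int stage = 0;

    // Prologue: start loading tile 0 into buffer 0.
    __pipeline_memcpy_async(&buf[stage][tid], &src[tid], sizeof(float));
    __pipeline_commit();

    for (int t = 0; t < num_tiles; ++t) {
        __pipeline_wait_prior(0); // tile t has landed in buf[stage]
        __syncthreads();          // all warps finished reading buf[stage ^ 1]

        // Kick off the next transfer into the *other* buffer while we compute.
        if (t + 1 < num_tiles) {
            __pipeline_memcpy_async(&buf[stage ^ 1][tid],
                                    &src[(size_t)(t + 1) * TILE + tid],
                                    sizeof(float));
            __pipeline_commit();
        }

        acc += buf[stage][tid]; // "compute" overlaps with the in-flight copy
        stage ^= 1;             // swap buffers for the next iteration
    }

    out[blockIdx.x * TILE + tid] = acc;
}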

Performance

How did this optimization improve performance? After profiling, we see a significant drop in warp stalls due to long_scoreboard, from 15.15% to 0.95%. This led to a decent improvement in performance from 98.3% → 99.4% of the reference implementation.

Kernel 4: Interleaving On-Chip LD/ST with Computation

With Kernel 3 successfully reducing GMEM stalls by 93%, we've addressed the GMEM → SMEM transfer bottleneck. However, our kernel still loads entire K and V tiles into RF before performing any computation. Kernel 4 tackles this inefficiency by interleaving memory transfers with computation, which will also reduce register pressure.

Currently, our kernel loads entire K and V tiles into RF before performing any mma operations.

This has two issues:

  1. We should be able to start some of the matmuls without waiting for the entire tile to be copied.
  2. Although the kernel configuration we picked isn't large enough to spill registers, some larger block dimensions are. If we want other block configurations to be feasible - the one we chose might not ultimately be the most optimal - we need to reduce register pressure.

Loading Fragments

Instead of loading all of the fragments at once, we'll divide them into subsets we'll call sub-tiles. Our mma loop will then first load one sub-tile of fragments, perform the mmas on those fragments, load the next sub-tile, mma on those fragments, and so on.

To maximize performance, we want to load fragments in a way that maximizes data reuse. There are two main strategies:

Strategy 1: Load A row-by-row, B column-by-column (inefficient)

  • Each column of B gets loaded multiple times (once per row of A)
  • Fragment intensity: 0.891 computations per memory transfer
  • Creates redundant loading that stresses the memory pipeline

1st iteration of loading A row-wise and B col-wise

2nd iteration of loading A row-wise and B col-wise

Strategy 2: Load A and B along the k dimension (optimal)

  • Load "slices" of both and along the inner dimension
  • Each fragment used exactly once
  • Fragment intensity: 1.62 computations per memory transfer - nearly double Strategy 1
  • This is the approach used by Cutlass

1st iteration of loading A & B along the inner k dimension

2nd iteration of loading A & B along the inner k dimension

We'll tile tensors according to the mma instruction we're using, m16n8k16, so we'll split along the k dimension with slices of width 16 elements, or 2 fragments (as depicted in the diagrams above). Some tiles are already entirely in the RF and don't need to be loaded from SMEM, like P - we'll leave these be.

Tile Dimension Table

| Tensor | Format | Full Shape (Fragments) | Sub-tile Shape (Fragments) | # Tiles (Fragments) | mma matrix variable |
|---|---|---|---|---|---|
| Q | Row major | | | | A |
| K | Row major | | | | B |
| V | Column major | | | | B |
Sub-tile RF shapes, tiled across k dimension

Register Pressure

This will also significantly reduce register pressure, which we discussed earlier. Here are the build-time spills for different block sizes.

| Configuration | spilled | stack_frame (bytes) | spill_stores (bytes) | spill_loads (bytes) | registers |
|---|---|---|---|---|---|
| (128, 64, 32, 4) | False | 0 | 0 | 0 | 212 |
| (128, 64, 64, 4) | False | 0 | 0 | 0 | 242 |
| (128, 128, 32, 4) | True | 304 | 356 | 324 | 255 |
| (128, 128, 64, 4) | True | 728 | 960 | 836 | 255 |
| (128, 128, 128, 4) | True | 1456 | 2148 | 1808 | 255 |

If we want larger block sizes to be feasible for auto-tuning (kernel 7), we must reduce register pressure significantly.

Some configurations spill nearly 1.5 kilobytes on the stack! The (128, 128, 128, 4) configuration alone generates 2148 bytes of spill stores and 1808 bytes of spill loads. This massive overhead clogs the memory pipeline and makes these configurations impractical.

What causes register spilling? When a kernel needs more registers than the per-thread budget allows, the compiler spills register values to LMEM (backed by L1 → L2 → DRAM). This creates overhead that pollutes caches, reduces bandwidth, and delays other memory transfers.

Performance impact varies: Spills in the prologue/epilogue have minimal impact, while spills in tight loops significantly slow down useful work. With high arithmetic intensity, small amounts of spilling can sometimes be hidden through optimal instruction scheduling.

You can configure nvcc to output these details on a per-kernel basis. See nvcc Flags for Register Spilling for more details.
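For quick reference, passing -Xptxas -v (equivalently --ptxas-options=-v) makes ptxas report these statistics at compile time; the architecture and source file name below are placeholders for whatever you're building:

nvcc -arch=sm_86 -Xptxas -v -c flash_attention.cu

For each kernel, ptxas then prints the registers used and, when spilling occurs, the stack frame, spill stores, and spill loads in bytes - the same quantities shown in the tables above.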

Code

We'll change our array sizes to only hold the number of sub-tiles that we'll store in RF. This only affects Q, K, and V.

uint32_t input[rows/8][cols/8];
// becomes
uint32_t input[rows/8][2];

To accommodate this, we make a few changes to warp_fragment_mma_f32_accum() and add a wrapper matmul() that calls warp_fragment_mma_f32_accum() over the sub-tiles.

gemm.cuh
// It's possible for K_fragments_A != K_fragments_B because either tensor can be buffered over sub-tiles.
template <typename value_t, const int M_fragments, const int N_fragments,
          const int K_fragments_A, const int K_fragments_B,
          typename accum_t = float>
__forceinline__ __device__ constexpr void warp_fragment_mma_f32_accum(
    uint32_t (&regs_A)[M_fragments][K_fragments_A],
    uint32_t (&regs_B)[N_fragments][K_fragments_B],
    accum_t (&regs_C)[M_fragments][N_fragments * N_REGS_PER_F32_ACCUM_FRAGMENT],
    int A_col_fragment_offset = 0, int B_col_fragment_offset = 0) {
    constexpr int K_iters = constexpr_min(K_fragments_A, K_fragments_B);
    #pragma unroll
    for (int k = 0; k < K_iters; k += MMA_K_FRAGMENTS_PER_ITER) {
        #pragma unroll
        for (int m = 0; m < M_fragments; m += MMA_M_FRAGMENTS_PER_ITER) {
            #pragma unroll
            for (int n = 0; n < N_fragments; n += MMA_N_FRAGMENTS_PER_ITER) {
                mma_m16n8k16_f32_accum<value_t>(
                    regs_C[m][n * 2],
                    regs_C[m][n * 2 + 1],
                    regs_C[m + 1][n * 2],
                    regs_C[m + 1][n * 2 + 1],
                    
                    regs_A[m][k + A_col_fragment_offset],
                    regs_A[m + 1][k + A_col_fragment_offset],
                    regs_A[m][k + 1 + A_col_fragment_offset],
                    regs_A[m + 1][k + 1 + A_col_fragment_offset],
                    
                    regs_B[n][k + B_col_fragment_offset],
                    regs_B[n][k + 1 + B_col_fragment_offset],
                    
                    regs_C[m][n * 2],
                    regs_C[m][n * 2 + 1],
                    regs_C[m + 1][n * 2],
                    regs_C[m + 1][n * 2 + 1]);
            }
        }
    }
}

Here's the wrapper. I wouldn't worry too much about the GEMM struct. It just contains the configurations for each matrix and the overall GEMM.

gemm.cuh
template <typename GEMM>
__device__ constexpr void matmul(typename GEMM::A_t &A, typename GEMM::B_t &B,
                                 typename GEMM::C_t &C) {
    // If ::load_entire_block_into_rf is set for either A_t or B_t, then
    // we assume the block has already been loaded.
    using A_t = typename GEMM::A_t; // Q or P
    using B_t = typename GEMM::B_t; // K or V
 
    constexpr int fragments_per_iter = 2;
 
    // GEMM::TotalKFragments is
    // - d_head / 8 for QK^T
    // - B_c / 8    for PV
    #pragma unroll
    for (int k = 0; k < GEMM::TotalKFragments; k += fragments_per_iter) {
        // Load the next k-slice of fragments (2 at a time).
        // A is Q or P: P is already computed and resident in RF, so we only
        // load A's fragments from SMEM when it's Q.
        if constexpr (!A_t::load_entire_block_into_rf) {
            A.copy_SM2RF(k);
        }
        // B is K or V and always comes from SMEM (2 fragments per iteration).
        B.copy_SM2RF(k);
 
        // If a tensor is fully resident in RF, index into it by k; otherwise
        // its buffer only holds the current sub-tile, so the offset is 0.
        int A_col_offset = A_t::load_entire_block_into_rf ? k : 0;
        int B_col_offset = B_t::load_entire_block_into_rf ? k : 0;
 
        // Outer product over this k-slice: every A row fragment with every
        // B column fragment. Loading along k gives optimal fragment reuse
        // compared to the row-by-row approach.
        warp_fragment_mma_f32_accum(A.data(), B.data(), C.data(),
                                    A_col_offset, B_col_offset);
    }
}

Register Usage Comparison

The register pressure reduction is dramatic across all configurations:

Key improvements:

  • Current config (128, 64, 64, 4): 242 → 207 registers
  • Reference config (128, 128, 32, 4): stack frame 304 → 0 bytes, eliminating all spills and making this configuration viable
  • Largest spiller (128, 128, 128, 4): spill stores 2208 → 1312 bytes (a 40.6% reduction)

The most significant win is making the (128, 128, 32, 4) configuration feasible, which is what the reference kernel uses on the RTX 3090. However, some large configurations like (128, 128, 64, 4) still need more work to become usable.

| Configuration | spilled | stack_frame (bytes) | spill_stores (bytes) | spill_loads (bytes) | used_registers |
|---|---|---|---|---|---|
| (128, 64, 32, 4) | False | 0 | 0 | 0 | 212 → 168 |
| (128, 64, 64, 4) | False | 0 | 0 | 0 | 242 → 207 |
| (128, 128, 32, 4) | True → False | 304 → 0 | 356 → 0 | 324 → 0 | 255 |
| (128, 128, 64, 4) | True | 728 → 272 | 964 → 336 | 840 → 304 | 255 |
| (128, 128, 128, 4) | True | 1360 → 840 | 2208 → 1312 | 1868 → 1120 | 255 |

Performance

Fragment interleaving achieves a significant milestone: 100.0% of reference performance. Interleaving SMEM → RF loads with computation, together with optimal fragment reuse (1.62 vs 0.891 intensity), crosses the reference performance threshold for the first time in our optimization journey.

Kernel 5: Double Buffering SMEM → RF Loads

Just as we double buffered GMEM → SMEM transfers in Kernel 3, we can also double buffer fragment loads from SMEM → RF. This helps hide ldmatrix instruction latency at the cost of higher register pressure.

To do this, we'll allocate registers for an extra set of fragments along the k dimension for each of Q, K, and V. Then, before our matmul loop, we'll load the first set of fragments into the first set of registers. Inside the loop, we'll load the next set before operating on the current one so that, by the time the tensor cores are finished with the current fragments, the next set is already loaded. Between iterations, we swap the set of registers we load into.

The concept is the same as the GMEM → SMEM double buffering in Kernel 3, except that, since the scope is warp-wide instead of CTA-wide and each operation is warp-synchronous, we don't need explicit barriers.

Code changes

Storage

We add another dimension to all the RF arrays for the number of buffers (stages).

For Q, K, and V:

uint32_t input[rows/8][2];
// becomes
uint32_t input[2][rows/8][2];

For P:

uint32_t input[rows/8][cols/8];
// becomes
uint32_t input[1][rows/8][cols/8];
 

For the S and O accumulators:

float accum[n/8][m/4];
// becomes
float accum[1][n/8][m/4];

matmul

gemm.cuh
template <typename GEMM>
__forceinline__ __device__ constexpr void matmul(typename GEMM::A_t &A,
                                                 typename GEMM::B_t &B,
                                                 typename GEMM::C_t &C) {
    // If tensor_t::load_entire_block_into_rf is set for either A_t or B_t, then
    // we assume the block has already been loaded.
    using A_t = typename GEMM::A_t; // Q or P
    using B_t = typename GEMM::B_t; // K or V
 
    constexpr int fragments_per_iter = 2;
 
    // This is 1 for Q and 0 for P.
    constexpr int A_stage_toggle = A_t::mma_load_stages - 1;
    constexpr int B_stage_toggle = B_t::mma_load_stages - 1; // 1 for K & V
 
    int A_stage = 0;
    int B_stage = 0;
 
    // True for Q, False for P (calculated & stored in RF)
    if constexpr (!A_t::load_entire_block_into_rf) {
        A.copy_SM2RF(A_stage, 0); // copy first Q sub-tile
    }
    B.copy_SM2RF(B_stage, 0); // copy first K or V sub-tile
 
    // GEMM::TotalKFragments is
    // - d_head / 8 for QK^T
    // - B_c / 8    for PV
    #pragma unroll
    for (int k = 0; k < GEMM::TotalKFragments; k += fragments_per_iter) {
        int k_load_fragment = k + fragments_per_iter;
        if (k_load_fragment < GEMM::TotalKFragments) {
            // True for Q, false for P.
            if constexpr (!A_t::load_entire_block_into_rf) {
                A.copy_SM2RF(A_stage_toggle ^ A_stage, k_load_fragment);
            }
            B.copy_SM2RF(B_stage_toggle ^ B_stage, k_load_fragment);
        }
 
        int A_col_offset = A_t::load_entire_block_into_rf ? k : 0;
        int B_col_offset = B_t::load_entire_block_into_rf ? k : 0;
        // Perform sub-tile-wise outer products.
        warp_fragment_mma_f32_accum(A.data(A_stage), B.data(B_stage), C.data(0),
                                    A_col_offset, B_col_offset);
 
        A_stage ^= A_stage_toggle;
        B_stage ^= B_stage_toggle;
    }
}
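A quick way to convince yourself what the A_stage_toggle / B_stage_toggle arithmetic does is to run the indexing logic on its own. This tiny standalone C++ snippet (illustrative, not part of the kernel) prints which buffer gets loaded into and which gets computed on each iteration: a double-buffered tensor (stages = 2, toggle = 1) ping-pongs between buffers 0 and 1, while a single-staged tensor like P (stages = 1, toggle = 0) always stays on buffer 0, with no branch needed.

stage_toggle_demo.cpp
#include <cstdio>

int main() {
    // stages == 2 -> toggle 1 (double buffered), stages == 1 -> toggle 0 (single buffer)
    for (int stages : {1, 2}) {
        int toggle = stages - 1;
        int stage = 0;
        printf("stages=%d:", stages);
        for (int k = 0; k < 4; ++k) {
            // buffer the next sub-tile is loaded into vs. buffer being computed on
            printf("  load->%d compute->%d", toggle ^ stage, stage);
            stage ^= toggle;
        }
        printf("\n");
    }
    return 0;
}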
 

Performance

Double buffering SMEM → RF shows a slight performance regression with our current configuration (128, 64, 64, 4), from 100% → 99.6% of reference. However, this regression is configuration-specific - other block configurations in the final kernel show up to 4% improvement, making this optimization essential for comprehensive auto-tuning.

Optional: Profile Analysis

The warp stall breakdown reveals some regressions:

| Stall | Kernel 4 | Kernel 5 | Delta |
|---|---|---|---|
| barrier | 3.66% | 4.37% | +0.71% |
| mio_throttle | 2.18% | 2.40% | +0.22% |

mio_throttle Stalls (+0.22%)

Double buffering creates denser memory instruction patterns that saturate the memory instruction queue (mio) faster.

Let's look at the start of the ldmatrix blocks in SASS.

Kernel 4 intersperses 6 ldmatrix instructions with 19 arithmetic operations, providing natural spacing that prevents queue saturation:

kernel_4.asm
        // ...
        LDSM.16.M88.4 R88, [R177+0x4000] ; // 6 ldmatrix instructions
        LDSM.16.M88.4 R148, [R157] ;
        LDSM.16.M88.4 R112, [R172+0x4000] ;
        LDSM.16.M88.4 R36, [R174+0x4000] ;
        LDSM.16.M88.4 R92, [R183+0x4000] ;
        LDSM.16.M88.4 R144, [R155] ;
        LDGDEPBAR ;
        MOV R128, R88 ;        // 19 instructions between
        IMAD.MOV.U32 R129, RZ, RZ, R90 ;
        LOP3.LUT R88, R156, 0x7, R5, 0x78, !PT ;
        MOV R90, R89 ;
        IMAD.SHL.U32 R178, R88, 0x8, RZ ;
        MOV R96, R112 ;
        IMAD.MOV.U32 R97, RZ, RZ, R114 ;
        MOV R141, R115 ;
        HMMA.16816.F32 R104, R148.reuse, R90, RZ ;
        LOP3.LUT R176, R178.reuse, 0x800, R159, 0xfe, !PT ;
        IMAD.MOV.U32 R140, RZ, RZ, R113 ;
        LOP3.LUT R175, R178, 0x1000, R159, 0xfe, !PT ;
        IMAD.MOV.U32 R133, RZ, RZ, R38 ;
        SHF.L.U32 R176, R176, 0x1, RZ ;
        IMAD.MOV.U32 R137, RZ, RZ, R94 ;
        SHF.L.U32 R175, R175, 0x1, RZ ;
        IMAD.MOV.U32 R38, RZ, RZ, R37 ;
        LOP3.LUT R188, R178, R159, RZ, 0xfc, !PT ;
        LDSM.16.M88.4 R88, [R176+0x4000] ; // next ldmatrix

On the other hand, kernel 5 front-loads an additional ldmatrix on top of the 6 from before and executes only 5 instructions before the next memory operation, creating higher queue pressure:

kernel_5.asm
        // ...
        LDSM.16.M88.4 R92, [R192+0x4000] ; // 7 ldmatrix instructions
        LDSM.16.M88.4 R144, [R159] ;
        LDSM.16.M88.4 R128, [R179+0x4000] ;
        LDSM.16.M88.4 R88, [R183+0x4000] ;
        LDSM.16.M88.4 R36, [R181+0x4000] ;
        LDSM.16.M88.4 R148, [R156] ;
        LDSM.16.M88.4 R132, [R186+0x4000] ;
        LDGDEPBAR ;
        IMAD.MOV.U32 R137, RZ, RZ, R94 ; // only 5 instructions
        MOV R94, R93 ;
        MOV R136, R92 ;
        HMMA.16816.F32 R104, R144.reuse, R94, RZ ;
        LDSM.16.M88.4 R92, [R195+0x4000] ; // next ldmatrix

This increases the portion of mio_throttle stalls attributable to these SMEM → RF loads from 47.6% → 50.2%.

stall_barrier Increase (+0.71%)

The increase in barrier stalls is harder to pin down and isn't directly tied to the mio_throttle bottleneck, although the extra memory-instruction-queue pressure likely contributes. It mostly reflects broader compiler scheduling changes and warp scheduling interactions on the Ampere microarchitecture that occur when the instruction mix changes.

Summary

We've successfully implemented core Cutlass GEMM optimizations and achieved 100% of reference performance on the RTX 3090. Our kernel now features:

  • Double buffering at the GMEM → SMEM level to reduce GMEM load stalls
  • Fragment interleaving to increase overlap between memory transfers and computation and to reduce register pressure
  • Double buffering of SMEM → RF fragment loads

Up Next

In Part 6, we'll push beyond 100% to surpass reference performance by ~1.5% on the RTX 3090. We'll tackle two final optimizations:

  • Fusing floating-point instructions to reduce FP instruction count
  • Auto-tuning to systematically search our configuration space for the top-performing configurations.

Footnotes