Intro

In Part 6, we improved our kernel to slightly outperform the reference kernel on the RTX 3090, but found that on the A100, it only reached 80.3% of the reference. In Part 7, we profiled kernel 7 and identified the key bottlenecks:

  1. Excessive scalar instructions: We're executing roughly 2x the scalar instructions and 6.8-11.3x more dependency-creating instructions compared to the reference kernel. These dependencies block our cp.async, ldmatrix, and mma operations, forcing the high-throughput tensor cores to sit idle.

  2. Register pressure limits block size: Kernel 7 spills 272 bytes of registers per thread at the optimal block configuration on the A100. These expensive memory accesses negate the benefits of larger blocks, forcing us to use smaller, less efficient configurations.

  3. Throughput imbalance amplifies the problem: The A100 has much higher tensor throughput but lower FP32 throughput than the RTX 3090. This lopsided ratio means our scalar instruction overhead creates a severe bottleneck - the tensor cores can process work much faster than our scalar units can prepare it.

In this part, we'll systematically eliminate these unnecessary instructions through four kernel iterations:

  1. Kernel 8: Implement strided swizzling to eliminate logic and bit-shift instructions while reducing register pressure
  2. Kernel 9: Optimize fragment storage to remove redundant register copies
  3. Kernel 10: Remove CS2R instructions and optimize the initial softmax iteration
  4. Kernel 11: Complete strided swizzling for RF→SMEM transfers

By the end, we'll achieve an 18.6% performance improvement, jumping from 149.71 TFLOPs (kernel 7) to 177.40 TFLOPs (kernel 11). That puts us within striking distance of the reference performance (186.4 TFLOPs) on the A100.

Kernel 8: Reducing Logic and Bit Shift Instructions & Register Usage

When we profiled kernel 7, we noticed excessive logic and bit shift instructions. These shouldn't be here - we're running an FP- and tensor-heavy workload. The culprits are LOP3.LUT, IMAD.SHL.U32, and SHF.L.U32, which typically show up in memory address calculations. Let's dig into the assembly code to confirm this hypothesis.

Grouping Instructions by Operation Type

The compiler often uses different instructions for similar purposes. For instance:

  • certain parts of address calculation might use either LEA or IMAD
  • a reduction in IMAD.MOV.U32 instructions might be offset by an increase in MOV instructions

Instead of trying to reduce a specific instruction, group related instructions together and work to eliminate the total sum.
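
If you want to do this grouping mechanically, a small host-side sketch like the following is usually enough to spot the trend (a hypothetical tool - it tallies static instructions in a cuobjdump -sass listing, whereas the profiler's executed-instruction metrics remain the authoritative numbers):

Grouping SASS instructions (sketch)
#include <fstream>
#include <iostream>
#include <map>
#include <string>

// Map one SASS line to a coarse group so that shifts between equivalent
// opcodes (LEA vs IMAD, MOV vs IMAD.MOV) don't hide the overall trend.
static std::string group_of(const std::string& line) {
    if (line.find("LOP3") != std::string::npos ||
        line.find("SHF.") != std::string::npos ||
        line.find("IMAD.SHL") != std::string::npos) return "logic/shift";
    if (line.find("IMAD.MOV") != std::string::npos ||
        line.find(" MOV ") != std::string::npos) return "register copy";
    if (line.find("LEA") != std::string::npos ||
        line.find("IMAD") != std::string::npos) return "integer/address math";
    return "other";
}

int main(int argc, char** argv) {
    if (argc < 2) return 1;                 // usage: ./group_sass <sass_listing>
    std::ifstream sass(argv[1]);
    std::map<std::string, int> counts;
    for (std::string line; std::getline(sass, line);) counts[group_of(line)]++;
    for (const auto& [group, n] : counts) std::cout << group << ": " << n << "\n";
}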

Kernel 7 SASS Examination

Let's examine snippets from kernel 7's SASS code (full code here) to see where these instructions appear.

Notice how these logic and shift instructions appear right before each memory transfer - they're calculating addresses for LDGSTS (cp.async) and LDSM (ldmatrix) operations:

Logic and Bit Shift Instruction Interleaving
	// ...
    SHF.L.U32 R215, R215, 0x1, RZ ;
    LOP3.LUT R178, R221, 0x1800, R8, 0xfe, !PT ;
    LDGSTS.E.BYPASS.LTC128B.128 [R205+0x8000], [R28.64] ;
    LOP3.LUT R30, R9, 0x7, R0, 0x78, !PT ;
    IMAD.SHL.U32 R191, R191, 0x2, RZ ;
    SHF.L.U32 R178, R178, 0x1, RZ ;
    LDGSTS.E.BYPASS.LTC128B.128 [R206+0x8000], [R28.64+0x80] ;
    IMAD.SHL.U32 R125, R30, 0x8, RZ ;
    LDSM.16.M88.4 R24, [R207+0x4000] ;
    LOP3.LUT R192, R125.reuse, R8, RZ, 0xfc, !PT ;
    LDSM.16.M88.4 R68, [R183+0x4000] ;
    LOP3.LUT R181, R125, 0x800, R8, 0xfe, !PT ;
    SHF.L.U32 R192, R192, 0x1, RZ ;
    LDSM.16.M88.4 R28, [R177+0x4000] ;
    LOP3.LUT R176, R125, 0x1000, R8.reuse, 0xfe, !PT ;
    SHF.L.U32 R181, R181, 0x1, RZ ;
    LDSM.16.M88.4 R120, [R192+0x4000] ;
    LOP3.LUT R173, R125, 0x1800, R8, 0xfe, !PT ;
    IMAD.SHL.U32 R176, R176, 0x2, RZ ;
    // ...	

Address Calculation

These instructions are indeed computing memory addresses. For LDSM, the second argument is the SMEM address. Here are a couple examples:

SMEM Address Calculation 1
364:        LOP3.LUT R192, R125.reuse, R8, RZ, 0xfc, !PT ;
367:        SHF.L.U32 R192, R192, 0x1, RZ ;
371:        LDSM.16.M88.4 R120, [R192+0x4000] ;
SMEM Address Calculation 2
369:        LOP3.LUT R176, R125, 0x1000, R8.reuse, 0xfe, !PT ;
373:        IMAD.SHL.U32 R176, R176, 0x2, RZ ;
376:        LDSM.16.M88.4 R156, [R176+0x4000] ;

Source of Instructions

These address calculations are far more complex than they should be. The culprit? Our swizzling implementation from Part 4. While it eliminated bank conflicts, it made our memory access patterns so complex that the compiler can't optimize them efficiently.

Let's test this hypothesis by comparing the instruction counts for swizzled and non-swizzled versions of kernel 7.

Swizzled vs Non-Swizzled

Instruction Count

After profiling the instruction counts, we can confirm our hypothesis. The swizzled version of our kernel executes 12x-42x as many of these instructions as the non-swizzled one.

Register Usage

Beyond instruction overhead, swizzling causes another problem: register pressure. The compiler must allocate registers to store all these computed memory addresses. Let's examine how many registers are needed by comparing the LDGSTS and LDSM operations in swizzled and non-swizzled versions.

LDGSTS (GMEM → SMEM)

Here is an example of an LDGSTS instruction:

LDGSTS.E.BYPASS.LTC128B.128 [R19], [R4.64]

  • The first register is the destination address (SMEM, 32-bit)
  • and the second is the source address (GMEM, 64-bit)
    • The .64 modifier on the source register indicates that it's a 64-bit value.
  • The .BYPASS modifier specifies that we bypass the L1 cache (see the cp.async sketch below).
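
For reference, here's roughly what the CUDA side of one of these copies looks like - a minimal inline-PTX sketch (hypothetical helper, not kernel 7's actual code) of a 16-byte cp.async; the .cg cache hint skips L1, which is what the .BYPASS modifier appears to reflect at the SASS level:

cp_async_16B (sketch)
// Hypothetical helper: a 16-byte asynchronous GMEM -> SMEM copy.
__device__ __forceinline__ void cp_async_16B(void* smem_dst, const void* gmem_src) {
    // The SMEM destination is passed as a 32-bit shared-state-space address and
    // the GMEM source as a 64-bit generic pointer - the same 32-bit/64-bit
    // operand pair visible in the LDGSTS instruction above.
    unsigned smem_addr = static_cast<unsigned>(__cvta_generic_to_shared(smem_dst));
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n"
                 :: "r"(smem_addr), "l"(gmem_src));
}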

These snippets contain every LDGSTS instruction from both kernels.

Kernel 7 (No Swizzling) All LDGSTS Instructions
LDGSTS.E.BYPASS.LTC128B.128 [R147],        [R6.64]
...                         [R145+0x80],   [R8.64]
                            [R145+0x400],  [R10.64]
                            [R145+0x480],  [R10.64+0x80]
                            [R145+0x800],  [R12.64]
                            [R145+0x880],  [R12.64+0x80]
                            [R145+0xc00],  [R4.64]
                            [R145+0xc80],  [R4.64+0x80]
                            [R147+0x4000], [R14.64]
                            [R145+0x4080], [R16.64]
                            [R145+0x4400], [R18.64]
                            [R145+0x4480], [R18.64+0x80]
                            [R145+0x4800], [R8.64]
                            [R145+0x4880], [R8.64+0x80]
                            [R145+0x4c00], [R6.64]
                            [R145+0x4c80], [R6.64+0x80]
                            [R147+0x8000], [R40.64]
                            [R145+0x8080], [R44.64]
                            [R145+0x8400], [R58.64]
                            [R145+0x8480], [R58.64+0x80]
                            [R145+0x8800], [R60.64]
                            [R145+0x8880], [R60.64+0x80]
                            [R145+0x8c00], [R56.64]
                            [R145+0x8c80], [R56.64+0x80]
                            [R147+0x4000], [R136.64]
                            [R145+0x4080], [R138.64]
                            [R145+0x4400], [R154.64]
                            [R145+0x4480], [R154.64+0x80]
                            [R145+0x4800], [R148.64]
                            [R145+0x4880], [R148.64+0x80]
                            [R145+0x4c00], [R164.64]
                            [R145+0x4c80], [R164.64+0x80]
Kernel 7 (Swizzling) All LDGSTS Instructions
LDGSTS.E.BYPASS.LTC128B.128 [R39],         [R24.64]
...                         [R41],         [R16.64]
                            [R43],         [R26.64]
                            [R45],         [R26.64+0x80]
                            [R39+0x800],   [R28.64]
                            [R41+0x800],   [R28.64+0x80]
                            [R47],         [R24.64]
                            [R49],         [R24.64+0x80]
                            [R39+0x4000],  [R30.64]
                            [R41+0x4000],  [R32.64]
                            [R43+0x4000],  [R34.64]
                            [R45+0x4000],  [R34.64+0x80]
                            [R39+0x4800],  [R36.64]
                            [R41+0x4800],  [R36.64+0x80]
                            [R47+0x4000],  [R26.64]
                            [R49+0x4000],  [R26.64+0x80]
                            [R194+0x8000], [R24.64]
                            [R198+0x8000], [R26.64]
                            [R200+0x8000], [R30.64]
                            [R202+0x8000], [R30.64+0x80]
                            [R203+0x8000], [R68.64]
                            [R204+0x8000], [R68.64+0x80]
                            [R205+0x8000], [R28.64]
                            [R206+0x8000], [R28.64+0x80]
                            [R194+0x4000], [R144.64]
                            [R198+0x4000], [R158.64]
                            [R200+0x4000], [R146.64]
                            [R202+0x4000], [R146.64+0x80]
                            [R203+0x4000], [R156.64]
                            [R204+0x4000], [R156.64+0x80]
                            [R205+0x4000], [R160.64]
                            [R206+0x4000], [R160.64+0x80]
SMEM addresses

The destination registers (first argument) reveal a stark difference. The non-swizzled kernel uses a small number of base SMEM addresses, each referenced numerous times with different offsets. Since the offsets are equidistant, the compiler embeds them directly in the instructions.

The swizzled version uses far more registers, each referenced only a few times. The graph below quantifies this: the x-axis represents how many times a register is used, while the y-axis shows the number of registers with that usage count. The non-swizzled kernel uses only 2 registers for SMEM addresses while the swizzled version uses 14 - a significant difference!

GMEM addresses

On the other hand, there isn't much of a difference between the GMEM registers for the swizzled and non-swizzled instructions. This is because we don't swizzle GMEM offsets, only SMEM.

LDSM

LDSM shows a similar pattern to LDGSTS, but the difference is even more pronounced. The second argument here is the SMEM address. I'll go into more detail on LDSM in the section for the next kernel.

Kernel 7 (No Swizzling) LDSM Snippet
...
LDSM.16.M88.4 R4,   [R205+0x8000]
LDSM.16.M88.4 R140, [R178]
LDSM.16.M88.4 R156, [R205+0x9000]
LDSM.16.M88.4 R160, [R205+0x8020]
LDSM.16.M88.4 R168, [R178+0x20]
LDSM.16.M88.4 R164, [R178+0x1000]
LDSM.16.M88.4 R172, [R178+0x1020]
LDSM.16.M88.4 R8,   [R205+0x9020]
LDSM.16.M88.4 R156, [R205+0x8040]
LDSM.16.M88.4 R168, [R178+0x40]
LDSM.16.M88.4 R160, [R205+0x9040]
LDSM.16.M88.4 R172, [R178+0x1040]
LDSM.16.M88.4 R8,   [R205+0x8060]
LDSM.16.M88.4 R168, [R178+0x60]
...
Kernel 7 (Swizzling) LDSM Snippet
...
LDSM.16.M88.4 R32,  [R18] ;
LDSM.16.M88.4 R36,  [R38] ;
LDSM.16.M88.4 R40,  [R40] ;
LDSM.16.M88.4 R44,  [R44] ;
LDSM.16.M88.4 R48,  [R48] ;
LDSM.16.M88.4 R52,  [R52] ;
LDSM.16.M88.4 R56,  [R56] ;
LDSM.16.M88.4 R60,  [R60] ;
LDSM.16.M88.4 R24,  [R207+0x4000] ;
LDSM.16.M88.4 R68,  [R183+0x4000] ;
LDSM.16.M88.4 R28,  [R177+0x4000] ;
LDSM.16.M88.4 R120, [R192+0x4000] ;
...

In total, the non-swizzled kernel uses only 3 registers for LDSM SMEM addresses - one of them appears in 64 different instructions! The swizzled kernel, on the other hand, uses 40 unique registers.

Summary

Compared to the non-swizzled version of kernel 7, the swizzled version

  • executes far more logic and bit-shift instructions
  • occupies far more registers to store SMEM addresses

Swizzling Patterns

Why do we execute so many instructions and use so many registers for swizzling?

If we examined the instruction dependencies, we'd find that virtually all address offsets are partially recomputed each iteration, even when base addresses are shared. The compiler can't optimize our swizzling computation pattern, so it allocates extra registers to cache intermediate values and inserts instructions to recompute offsets.

The solution: make swizzling more explicit. If we can expose the underlying pattern to the compiler, it should be able to cache offsets instead of recalculating them for every access.

Let's revisit the 4 bank swizzling example from kernel 2, where we read a (4, 4) tile with 4 threads, column by column.

Let's trace the access pattern for each thread:

  • Thread 0: columns 0, 1, 2, 3
  • Thread 1: columns 1, 0, 3, 2
  • Thread 2: columns 2, 3, 0, 1
  • Thread 3: columns 3, 2, 1, 0

See the pattern? Each thread follows a unique but predictable stride through memory. Looking closer:

  • The starting column for each thread equals the thread ID

  • Between the 0th and 1st iteration & the 2nd and 3rd iteration,

    • threads 0 and 2 increase their index by 1, while
    • threads 1 and 3 decrease their index by 1
  • Between the 0th and 2nd iteration,

    • threads 0 and 1 increase their index by 2, while
    • threads 2 and 3 decrease their index by 2

This suggests we can encode each thread's access pattern as a stride with a base offset. If we can expose this pattern to the compiler, it should be able to cache these offsets instead of recalculating them for every access.

Let's formalize this pattern:

  • Thread starts at column (base offset)
  • Between iterations, the column changes by powers of 2: ±1, ±2, ±4, etc.
  • The sign depends on which thread we're looking at

This gives us a strided indexing scheme where each thread has its own unique stride pattern and base offset:

Thread   Stride    Base Offset
0        (2, 1)    0
1        (2, -1)   1
2        (-2, 1)   2
3        (-2, -1)  3

We can extend the table to an (8, 8) tile of memory:

Thread   Stride        Base Offset
0        (4, 2, 1)     0
1        (4, 2, -1)    1
2        (4, -2, 1)    2
3        (4, -2, -1)   3
4        (-4, 2, 1)    4
5        (-4, 2, -1)   5
6        (-4, -2, 1)   6
7        (-4, -2, -1)  7

Here's the key insight: we can calculate any thread's stride pattern using bit operations:

  • The magnitude of the i-th stride (counting from right to left) is 2^i
  • The sign is determined by the i-th bit of the thread ID: if set, negative; if clear, positive

Let's see how this bit-based stride calculation works in practice:

// Thread 3 has strides (-4, -2, -1) and base offset 3
int thread_id = 3;
int base_offset = 3;  // Starting column for thread 3
 
// Calculate swizzled column for each iteration
for (int iter = 0; iter < 8; iter++) {
    int col = base_offset;
    
    // Apply stride based on iteration bits
    if (iter & 4) col += (thread_id & 4) ? -4 : 4;  // Bit 2: stride ±4
    if (iter & 2) col += (thread_id & 2) ? -2 : 2;  // Bit 1: stride ±2  
    if (iter & 1) col += (thread_id & 1) ? -1 : 1;  // Bit 0: stride ±1
    
    col = col & 7;  // Wrap around within 8 columns
    
    // Thread 3 accesses columns: 3, 2, 1, 0, 7, 6, 5, 4
}

Scaling to Larger Tiles: Swizzle Regions

What happens when we have more than 8 columns or rows? Since we only have 8 banks, the swizzling pattern repeats every 8 rows/columns.

We'll call each 8×128-byte segment a swizzle region. Within each region, threads follow the same access pattern we just defined. When we move to the next region, we simply reuse the same pattern with a different base address.


Will this formulation actually reduce the number of swizzling instructions? With our previous swizzling subroutine, most (if not all) of the offsets are recomputed every time they're accessed. Since the access pattern is identical across swizzle regions, we now only have to compute the strides and offsets once and then reuse them for every region. That should eliminate the redundant computation.
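
In code, the reuse boils down to something like this rough sketch (hypothetical names, byte offsets): the swizzled per-thread offset is computed once, and moving to another swizzle region is a plain constant-stride add that the compiler can fold into an immediate.

Swizzle region reuse (sketch)
// One swizzle region covers 8 rows of 128 bytes.
constexpr int kRegionBytes = 8 * 128;

// swizzled_thread_offset is computed once per thread; every later access only
// adds the region's base, so no per-access swizzle math is needed.
__device__ int region_address(int swizzled_thread_offset, int region_idx) {
    return region_idx * kRegionBytes + swizzled_thread_offset;
}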

How Memory Transfers Change with Strided Swizzling

For every direction of memory transfer, we need to calculate

  1. the base per-thread offset and
  2. the swizzle stride

In this kernel, we'll only update the swizzling code for copying between GMEM & SMEM and from SMEM → RF, but not from RF → SMEM. The improvement there would be smaller since that transfer only executes once (when copying the output back). There are higher-impact changes to make first, so we'll revisit it in Kernel 11 Strided Swizzling from RF → SMEM.

Note

Between kernels 7-10, the code changes significantly between iterations. To save you the trouble of reading drastically changing code, from here on I'll only include the changes that reflect the core ideas, or omit the code entirely.

Strided Swizzling Between GMEM ↔ SMEM

The Problem: Multiple Offsets Per Thread

In the example above, we were loading 1 column per iteration for 8 iterations - a column stride of 1 with 8 iterations per swizzle region. Remember that happens 4 times per warp, once for each octet of threads.

However, when we load between GMEM and SMEM, each warp loads a tile per iteration, giving us only 2 iterations per swizzle region:

Each iteration within a swizzle region requires a unique offset calculation. With 2 iterations per region, we need to compute 2 offsets instead of 8, but that's still one more than ideal.

The Solution: CTA-Level Cooperation

Can we reduce this to just one offset calculation per thread? Yes! The key is reorganizing how we distribute work across the CTA (Cooperative Thread Array).

Currently, each warp loads its own set of contiguous rows independently. Instead, we can have all 4 warps in the CTA cooperate, which changes the row stride from warp-level to CTA-level. Now each iteration spans 2 complete swizzle regions, meaning we only need to compute one initial offset per thread - the compiler can encode everything else directly into the SASS instructions as immediate offsets.

This means the warps will need to cooperate CTA-wide to load these blocks. Accordingly, to avoid race conditions, we'll have to change any warp-scoped barriers (__syncwarp()) to CTA-scoped barriers (__syncthreads()).

Implementation Details

Instead of using the lane_id = thread_id % 32 to calculate the per thread memory offset, we'll use thread_id directly for CTA-wide coordination.

load_store.cuh
// Configuration for GMEM to SMEM transfers
// Handles both swizzled and non-swizzled memory layouts
template <typename Swizzle_, typename OpStride_, typename TensorShape_,
          typename SmemStride_>
struct GSMemLdstConfig {
    using Swizzle = Swizzle_;
    using OpStride = OpStride_;
    using TensorShape = TensorShape_;
    using SmemStride = SmemStride_;
 
    using OpIters = TShape<TensorShape::rows() / OpStride::row(),
                           TensorShape::cols() / OpStride::col()>;
 
    static constexpr int thrs_per_row = 8;
 
    // Convert thread ID to row coordinate in CTA grid
    static constexpr int tid_to_thr_row(int tid) { return tid / thrs_per_row; }
 
    // Convert thread ID to column coordinate in CTA grid
    static constexpr int tid_to_thr_col(int tid) {
        return (tid % thrs_per_row) * COLS_PER_FRAGMENT;
    }
 
    // Calculate per-thread offset for GMEM access
    static constexpr int gmem_thr_offset(int tid, RuntimeStride stride) {
        return tid_to_thr_row(tid) * stride.row +
               tid_to_thr_col(tid) * stride.col;
    }
 
    // Calculate per-thread offset for SMEM access (with swizzling)
    static constexpr int smem_thr_offset(int tid) {
        return Swizzle::apply(tid_to_thr_row(tid) * SmemStride::row() +
                              tid_to_thr_col(tid) * SmemStride::col());
    }
};
// ...

In the tensor wrapper class, we'll use the above helpers to set the pointers of the blocks for each thread.

tensor.cuh
// ...
        gmem_ptr = gmem_block_ptr + gsmem_::gmem_thr_offset(tid, gmem_stride);
        smem_gsmem_ptr = _smem_ptr + gsmem_::smem_thr_offset(tid);
// ...
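
To make the payoff concrete, here's a rough sketch of what the GMEM → SMEM copy loop then boils down to (hypothetical names and strides, reusing the cp_async_16B helper sketched earlier - not kernel 8's actual code): the per-thread offsets are applied once up front, and each iteration only adds compile-time constants, which the compiler folds into the LDGSTS immediate offsets.

GMEM → SMEM copy loop (sketch)
// kIters and the per-iteration strides are template constants, so
// `it * stride` is a known immediate rather than a freshly computed register.
template <int kIters, int kSmemIterStride, int kGmemIterStride>
__device__ void gsmem_copy_tile(char* smem_gsmem_ptr, const char* gmem_ptr) {
#pragma unroll
    for (int it = 0; it < kIters; ++it) {
        cp_async_16B(smem_gsmem_ptr + it * kSmemIterStride,
                     gmem_ptr + it * kGmemIterStride);
    }
}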

We'll also change the corresponding __syncwarp() barriers to __syncthreads().

Strided Swizzling from SMEM → RF

The situation is different when we copy from SMEM to RF. We load a fixed number of tiles per iteration, for a total of 4 iterations per swizzle region. However, we can't reuse the earlier "trick" of having all warps in the CTA cooperate, because each warp independently loads its own tiles.

To resolve this, we'll explicitly encode the swizzling stride within a region. Note that instead of a stride with 3 dimensions, we only need one with 2 dimensions, because we can ignore the smallest stride element - it's covered by the other column of threads.

Code

We'll create an object containing the swizzling strides.

layout.cuh
// ...
struct SwizzleStride {
    int s0; // (always 64, this is for when we cross from one swizzle region to another)
    int s1; // (+- 32)
    int s2; // (+- 16)
 
    // `iteration` is the index of the inner loop.
    constexpr int offset(int iteration) const {
        int i0 = (iteration >> 2) & 1;
        int i1 = (iteration >> 1) & 1;
        int i2 = iteration & 1;
        return i0 * s0 + i1 * s1 + i2 * s2;
    }
};
// ...
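
As a quick sanity check of offset(): with hypothetical strides {s0, s1, s2} = {64, 32, -16}, iteration 2 (binary 010) selects only s1 and yields 32, while iteration 5 (binary 101) selects s0 and s2 and yields 64 - 16 = 48.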

This calculates the initial per thread offset for ldmatrix calls.

load_store.cuh
// ...
    static constexpr int lane_to_thr_offset_s2rmem(int lane_id) {
        int thread_row = lane_id % 16;
        int thread_col = (lane_id / 16) * COLS_PER_FRAGMENT;
        return Swizzle::apply(
            thread_row * SmemStride::row() +
            thread_col * SmemStride::col());
    }
// ...
tensor.cuh
        // ...
        smem_s2rmem_ptr =
            smem_srmem_ptr + srmem_::lane_to_thr_offset_s2rmem(lane_id);
        // ...

And the code to calculate the strides. We need to use lane_id = thread_id % 32 here since each warp has to load its own data.

load_store.cuh
// ...
    static constexpr SwizzleStride lane_to_thr_swizzle_stride(int lane_id) {
        if constexpr (std::is_same_v<Swizzle, NoSwizzle>) {
            return SwizzleStride{64, 32, 16};
        } else {
            // Calculate initial offset.
            int base_swizzle_offset = lane_to_thr_offset_s2rmem(lane_id);
            int base_offset_cmp = Swizzle::yy_mask_lowest_bit << 1;
            int s1 = 32 * binary_to_pm1((base_swizzle_offset &
                                         (base_offset_cmp << 1)) == 0);
            int s2 = 16 * binary_to_pm1(
                              (base_swizzle_offset & base_offset_cmp) == 0);
 
			// The 64 stride is for when we cross from one swizzle boundary to
			// the next, for instance when d_head = 128.
            return SwizzleStride{64, s1, s2};
        }
    }
// ...
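
Putting the pieces together, here's a rough sketch of how the precomputed base pointer and SwizzleStride might be consumed (hypothetical function and loop bounds, not the kernel's actual code; offsets assumed to be in half elements, matching the element-indexed pointers above): all the swizzle math happens once, and each ldmatrix in the inner loop only adds stride.offset(i).

SMEM → RF load loop (sketch)
#include <cuda_fp16.h>
#include <cstdint>

__device__ void load_fragments_x4(const SwizzleStride stride,
                                  const half* smem_s2rmem_ptr,
                                  uint32_t frags[][4], int n_iters) {
    for (int i = 0; i < n_iters; ++i) {
        // Base pointer was swizzled once; only the strided offset varies.
        const half* src = smem_s2rmem_ptr + stride.offset(i);
        unsigned addr = static_cast<unsigned>(__cvta_generic_to_shared(src));
        asm volatile(
            "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
            : "=r"(frags[i][0]), "=r"(frags[i][1]),
              "=r"(frags[i][2]), "=r"(frags[i][3])
            : "r"(addr));
    }
}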

Kernel 8 Memory Address Register Count

Let's make the same comparison we did earlier for the changes we just made.

For LDGSTS, only a single register is used, which is even better than the previous kernel without swizzling.

For LDSM, we've significantly improved SMEM address register reuse. It's slightly less efficient than the non-swizzled version - which is expected given the more complex pattern - but still significant progress, and it translates into a substantial reduction in register pressure. For the larger block configuration, this resulted in a 90%+ drop in spills! It should be a serious contender now.

Kernel      #stack_frame   spill_stores   spill_loads
7           272            336            304
8           16             32             16
Reduction   94.1%          90.5%          94.7%

Executed Instructions Count

Before we take a look at the performance, let's see the difference in the number of executed instructions.

We eliminated the vast majority of all logic and shift instructions, leaving the counts comparable to the reference kernel. However, we somewhat increased the number of IMAD.MOV.U32 copy instructions.

Configuration Specification

Note: all the kernels from now on will use async copies + eager KV loads + swizzling + opt_softmax, so we'll strip those settings from the specifications.

(d_head, B_r, B_c, n_warps): {async}+{eager}+{swizzled}+load_{q_fragments}_{k_frags}_{v_frags}_fragments+{buffer}+{opt_softmax}

becomes

(d_head, B_r, B_c, n_warps): load_{q_fragments}_{k_frags}_{v_frags}_fragments+{buffer}

Summary of Optimizations

By encoding swizzling as explicit strides rather than recomputed offsets, we've achieved:

  • Reduced address calculation instructions by 90%+
  • Cut SMEM address registers:
    • for LDSM, from 40 to 5
    • for LDGSTS, from 14 to 1
  • Eliminated 90%+ of register spilling for larger block sizes

Performance

The massive reduction in register pressure finally allows us to use the larger block size - the same configuration the reference kernel uses. This was previously underperforming due to register spilling. As we discussed in Block Size Limitations, this larger block size has far more potential for higher tensor throughput.

Kernel   Top Performing Config
7        (128, 64, 64, 4): load_0_2_2_fragments
8        (128, 128, 64, 4): load_2_2_2_fragments

That's a solid 9.38% performance jump, from 149.71 TFLOPs to 163.76 TFLOPs.

Kernel Metrics

By doubling B_r to 128 while keeping 4 warps per CTA, we halve our cp.async calls and cut ldmatrix calls by 37.5%. See Appendix B - Block Size Configuration for further details.

This translates to:

  • SMEM load utilization dropping 31.59%, from 60.49% to 41.38%
  • L2 load utilization dropping 44.90%, from 42.53% to 23.43%

The scalar pipeline utilization also drops:

  • ALU: down 17.98%, from 21.20% to 17.39%
  • FMA: slightly up 3.19%, from 29.23% to 30.16%
    • This FMA increase aligns with the rise in executed IMAD.MOV.U32 instructions1
  • Overall: down 5.71%, from 50.43% to 47.55%

The total number of cycles went down, so tensor pipeline utilization increased 9.4%, from 60.47% to 66.15%.

Kernel 9: Reducing Copy Instructions

Time to tackle those register copy instructions. We're executing over 100x more IMAD.MOV.U32 instructions and 4x as many MOV instructions compared to the reference.

Examining the SASS

Here's a snippet from kernel 8's assembly code. Notice the IMAD.MOV.U32 and MOV instructions interleaved between HMMA and LDSM operations. Some destination registers aren't even used before being overwritten - R201 at line 355, for example. What's going on here?

kernel_8_A100.asm
    ...
    MOV R200, R9 ;
    HMMA.16816.F32 R20, R192, R184, R20 ;
    IMAD.MOV.U32 R201, RZ, RZ, R11 ;
    IMAD.MOV.U32 R184, RZ, RZ, R57 ;
    MOV R185, R59 ;
    HMMA.16816.F32 R68, R192.reuse, R200, R68 ;
    HMMA.16816.F32 R84, R192, R184, R84 ;
    LDSM.16.M88.4 R184, [R232+0x1000] ;
    IMAD.MOV.U32 R200, RZ, RZ, R4 ;
    MOV R201, R6 ;
    HMMA.16816.F32 R196, R192, R200, R196 ;
    IMAD.MOV.U32 R200, RZ, RZ, R5 ;
    IMAD.MOV.U32 R201, RZ, RZ, R7 ;
    HMMA.16816.F32 R192, R192, R200, R88 ;
    ...

LDSM and HMMA

To understand why we have all these extra IMAD.MOV.U32 and MOV instructions, we need to examine how PTX instructions map to SASS on Ampere. The key insight is that ldmatrix and mma have strict register alignment requirements at the SASS level that aren't obvious from the PTX. Let's look at how these instructions translate to LDSM and HMMA respectively.

The ldmatrix PTX instruction has the signature,

ldmatrix.sync.aligned.m8n8.x4.b16 {d0, d1, d2, d3}, [addr];

Here's a concrete example from our kernel:

ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%r213, %r214, %r215, %r216}, [%r217];

This compiles to the following SASS instruction:

LDSM.16.M88.4 R160, [R196]

Notice that PTX specifies 4 registers (one for each fragment), while SASS only specifies 1. What's happening here?

The answer lies in how SASS encodes these special instructions:

  • SASS registers (R[0-255]) are 32-bit registers
  • LDSM and HMMA implicitly operate on groups of contiguous registers starting from the base register specified
  • In our example, R160 actually refers to registers R160, R161, R162, and R163 - the four fragments map to these four consecutive SASS registers
  • Unlike 64-bit registers which use an explicit .64 suffix (like R28.64 for GMEM addresses), there's no special notation for these 4-register groups - the instruction format itself determines how many registers are used

The mma instruction follows a similar pattern, but with one crucial difference - asymmetric operand sizes:

mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
  {%Rd0, %Rd1, %Rd2, %Rd3}, // D (output accumulator)
  {%Ra0, %Ra1, %Ra2, %Ra3}, // A (left multiplicand)
  {%Rb0, %Rb1},             // B (right multiplicand - only 2 registers!)
  {%Rc0, %Rc1, %Rc2, %Rc3}; // C (input accumulator)

Here's a concrete example from kernel 8:

mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
  {%f1210, %f1211, %f1212, %f1213},
  {%r646, %r647, %r648, %r649},
  {%r658, %r659},
  {%f1146, %f1147, %f1148, %f1149};

Which compiles to:

HMMA.16816.F32 R64, R8, R12, R64

The SASS operands correspond to D, A, B, C respectively. Notice the critical difference:

  • D, A, and C each use 4 consecutive registers (e.g., R64 means R64-R67)
  • B uses only 2 consecutive registers (e.g., R12 means R12-R13)

Register Alignment Requirements

The SASS encoding imposes strict alignment requirements on the base registers:

  • D, A, and C operands (4 registers each): base register must be a multiple of 4 (like R0, R4, R8, ...)
  • B operand (2 registers): base register must be a multiple of 2 (like R0, R2, R4, ...)

This means we'll never see odd register numbers as arguments to HMMA or LDSM. These alignment constraints become important when we look at how fragments flow through our kernel - let's examine what happens conceptually when we store and manipulate these fragments.

Conceptual Model of Fragment Storage

This changes how we should think about tile storage in RF. Even though PTX lets us address individual registers, SASS does not. We currently store matrices indexed by individual (fragment row, fragment column) coordinates. To better reflect how tiles are actually stored, we'll retile our RF layout to have 2 layers of strides and shapes: the outer layer will contain fragment tiles, and the inner layer will index into the individual fragments within a tile.

You can essentially think of it as

uint32_t mat[fragment_rows / 2][fragment_cols / 2][2][2];

instead of

uint32_t input[fragment_rows][fragment_cols];

This conceptual model reveals why certain operations require extra instructions - the hardware's view of register organization differs from PTX's.
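
A minimal sketch of that two-level view in code (hypothetical struct, not the kernel's actual types): the outer coordinates pick a (16, 16) fragment tile, and the inner coordinates pick one of its four 8×8 fragments - i.e. one register of the aligned quartet a single LDSM or HMMA touches.

Fragment storage (sketch)
#include <cstdint>

// One fixed inner [2][2] ordering is shown for illustration; a real
// implementation picks the ordering per operand so that a whole tile lands in
// the register order the consuming HMMA operand expects.
template <int kFragRows, int kFragCols>
struct FragmentStore {
    uint32_t regs[kFragRows / 2][kFragCols / 2][2][2];

    // Address a single 8x8 fragment by its global fragment coordinates.
    __device__ uint32_t& at(int frag_row, int frag_col) {
        return regs[frag_row / 2][frag_col / 2][frag_row % 2][frag_col % 2];
    }

    // The four contiguous values of one (16, 16) tile - the unit that a single
    // ldmatrix fills and that an aligned HMMA operand consumes.
    __device__ uint32_t* tile(int tile_row, int tile_col) {
        return &regs[tile_row][tile_col][0][0];
    }
};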

Where Do The Extra Instructions Come From?

Now that we understand SASS register alignment requirements, we can trace where these extra instructions originate.

When HMMA operands aren't in the expected contiguous format, the compiler must insert copy instructions. A closer examination of the SASS code reveals something interesting: the IMAD.MOV.U32 and MOV instructions only affect the B operand of the HMMA instructions.

Here's one such pair:

kernel_8_A100.asm
    ...
    MOV R200, R9 ;
    HMMA.16816.F32 R20, R192, R184, R20 ;
    IMAD.MOV.U32 R201, RZ, RZ, R11 ;
    IMAD.MOV.U32 R184, RZ, RZ, R57 ;
    MOV R185, R59 ;
    HMMA.16816.F32 R68, R192.reuse, R200, R68 ;
    HMMA.16816.F32 R84, R192, R184, R84 ;    
    ...

Which of the kernel's two matrix multiplications do these copies belong to - or both? We can use a simple trick to distinguish them. ldmatrix compiles to LDSM.16.M88.4, while the transposed version compiles to LDSM.16.MT88.4, and the two multiplications load their B operands with different variants. We can then compare the destination register of each LDSM with the source operands of the corresponding HMMA instruction(s).

After examining the SASS code, we find that the culprit is the operand loaded with the non-transposed LDSM.16.M88.4, as there are no corresponding copy instructions for the LDSM.16.MT88.4 loads.

Let's visualize our layout pattern for copying SMEM → RF. Each ldmatrix instruction loads a (16, 16) block corresponding to a (2, 2) tile of fragments:

ldmatrix loads the values within each tile in column major order. This is optimal for the A operand, because HMMA expects A as a contiguous quartet of registers in exactly this order - which is why we see no copy instructions for A.

The B operand identified above, however, presents a problem: it requires 2 registers from the same row of the fragment tile, but column major loading places each aligned pair in the same column instead. The solution: load these fragments in row major order.

With row major loading, that operand's fragment storage now aligns correctly:

Why doesn't have the same issue?

  • lives entirely in RF, so no loading is involved
  • is stored in RF in column major order, matching the layout within fragment tiles - the same pattern as

Optimal ldmatrix Fragment Shape

When I introduced ldmatrix in Which Fragments to Load?, I stated that (16, 16) is the optimal fragment shape without going into further detail. Now that we understand register alignment and swizzling constraints, let's examine why the other options don't work.

The three possible fragment shapes are: (16, 16), (32, 8), and (8, 32). Let's examine the trade-offs of each.

(32, 8) Fragment Shape

The (32, 8) shape - a column of four (8, 8) fragments - seems promising since it matches mma's layout for . But it's incompatible for and :

  • requires a tile in column major, where each column has fragments in the same row as each other
  • requires a tile in row major

So if we utilized this layout, all the work we just did would be for naught and we'd have more copies than before!

In addition, in this layout each fragment resides in a separate swizzle region. This would force us to calculate strides for all 8 iterations (covering the entire swizzle region) instead of just 4. The result would be increased register usage and additional instructions for address calculation.

(8, 32) Fragment Shape

The (8, 32) shape represents a row of four (8, 8) fragments. Its key advantage is that all fragments live in the same swizzle region, cutting iterations from 4 (for (16, 16)) down to just 2.

The issue, however, is that it doesn't match the layout for or , which both require tiles in column major. So (8, 32) would result in more copies - similar to (32, 8).

(16, 16) - The Sweet Spot

The (16, 16) shape is fully compatible with the mma instruction layout: the A fragment layout for and the B fragments for and (when transposed).

Beyond compatibility, the (16, 16) shape provides two additional benefits:

  1. Shared address calculations: When both and use this shape, they share swizzle offsets and strides. Most of the common calculations execute once with shared storage.

  2. Single-instruction loads: For A matrix layouts, a single ldmatrix provides all fragments needed for an mma.m16n8k16 instruction. The alternatives - (8, 32) or (32, 8) - require at least two ldmatrix calls before the first mma, reducing our ability to interleave memory transfers with computation.

Solution Summary

The register alignment constraints of SASS HMMA instructions were forcing the compiler to insert copy instructions to reorder fragments. By switching to row major loads for the affected B operand, we ensure fragments arrive in the exact order HMMA expects, eliminating the need for these expensive register shuffles.

SASS Instruction Counts

Build Time Instructions

This table contains the difference in the number of register copy instructions in the kernel SASS assembly.

Kernel      IMAD.MOV.U32   MOV
8           187            159
9           36             18
Reduction   81%            89%

What a huge drop! We've eliminated all register copy instructions related to LDSM and HMMA. Note that these are build-time instruction counts - the runtime reduction will vary based on branching, sequence length, and other dynamic factors.

Performance

Kernel   Top Performing Configuration
7        (128, 64, 64, 4): load_0_2_2_fragments
8        (128, 128, 64, 4): load_2_2_2_fragments
9        (128, 128, 64, 4): load_2_2_2_fragments

Another solid gain: 8.51% improvement from 163.76 to 177.68 TFLOPs. We've nearly closed the gap with the reference kernel.

Profile

Much better! We're still executing more IMAD.MOV.U32 instructions than the reference (FP pipeline), but the dramatic reduction in MOV instructions (INT pipeline) more than compensates.

Summed together, we're only executing 2.5% more copy instructions than the reference.

The scalar pipeline utilization drops significantly:

  • ALU: down 47.89%, from 17.39% to 9.06%
  • FMA: down 26.75%, from 30.16% to 22.09%

This represents an overall drop of 34.48%, from 47.55% → 31.15%.

The total number of cycles went down, so tensor pipeline utilization increased 8.51% from 66.15% to 71.78%.

Kernel 10: Removing CS2R Instructions + Optimizing Initial Softmax Iteration

Time to eliminate those CS2R instructions. They zero-initialize one of the two MMA accumulators in RF. You might expect the kernel to use CS2R for the other accumulator as well, but it doesn't - that one gets initialized by passing the zero register RZ as the C operand to HMMA. Here's what that looks like:

		HMMA.16816.F32 R80, R16.reuse, R32, RZ ;
		HMMA.16816.F32 R60, R16.reuse, R26, RZ ;
		HMMA.16816.F32 R76, R16.reuse, R50, RZ ;

Kernel 9's assembly has no equivalent RZ-based initialization for the first accumulator, hence the CS2R instructions. The fix is straightforward: explicitly handle the initial loop iteration by refactoring the loop body into a function with a boolean template parameter that distinguishes the first iteration from subsequent ones. When the parameter is true, we pass RZ as the C operand to zero the matrix (sketched below).
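
A minimal sketch of that boolean-template approach (hypothetical function and names, not the kernel's exact code): on the first iteration we feed literal zeros as the C operand, which the compiler is then free to encode as RZ in the HMMA instead of emitting a separate CS2R.

First-iteration mma (sketch)
#include <cstdint>

template <bool kFirstIteration>
__device__ void mma_m16n8k16(float d[4], const uint32_t a[4],
                             const uint32_t b[2], const float c[4]) {
    if constexpr (kFirstIteration) {
        // C operand is all zeros - no prior zero-initialization of the
        // accumulator is required.
        asm volatile(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
            "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
            : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
            : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
              "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f));
    } else {
        asm volatile(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
            "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
            : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
            : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
              "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
    }
}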

While we're at it, we can also optimize the first softmax iteration. Currently, we initialize the row statistics - the running max and the running sum - to their identity values (-∞ and 0) and then treat the first iteration like any other. But we can do better:

  • Initialize the running max and sum directly from the first block of scores
  • Skip the rescaling of the accumulators entirely, since they are still zero (see the sketch below)
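
Here's a rough sketch of the idea in standard online-softmax terms (hypothetical function and names, not the kernel's actual code): the first iteration takes its statistics straight from the freshly computed tile and skips the exp(old_max - new_max) correction, because the running sum and the output accumulator are still zero.

First softmax iteration (sketch)
// Returns the correction factor the caller applies to the running output
// accumulator; updates the running max and rescales the running sum in place.
template <bool kFirstIteration>
__device__ float rescale_row_stats(float& row_max, float& row_sum, float tile_max) {
    if constexpr (kFirstIteration) {
        row_max = tile_max;   // no previous max to compare against
        return 1.0f;          // nothing to rescale - accumulators are zero
    }
    float new_max = fmaxf(row_max, tile_max);
    float correction = __expf(row_max - new_max);   // exp(old_max - new_max)
    row_sum *= correction;
    row_max = new_max;
    return correction;
}
// Afterwards the caller computes the exponentials against row_max, adds their
// row sums to row_sum, and accumulates into the output scaled by the returned
// correction factor.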

The code changes are straightforward, so I won't detail them here - check out the diff if you're curious.

Performance

Kernel   Top Performing Configuration
7        (128, 64, 64, 4): load_0_2_2_fragments
8        (128, 128, 64, 4): load_2_2_2_fragments
9        (128, 128, 64, 4): load_2_2_2_fragments
10       (128, 128, 64, 4): load_2_2_2_fragments

Unexpectedly, we regressed 1.46%, dropping from 177.68 to 175.10 TFLOPs.

This is puzzling given that we eliminated the extra softmax and CS2R instructions - these instruction counts now match the reference kernel.

The FP pipeline utilization dropped from 22.1% to 20.42% after removing the extra softmax instructions. This reduction should scale inversely with sequence length.

Regression Analysis

What went wrong? Comparing stall profiles between kernels 9 and 10 reveals the culprit:

Stall Type          Kernel 9   Kernel 10   Delta
short_scoreboard    2.36%      7.79%       +5.44%
wait                36.99%     38.48%      +1.49%

That's a massive jump in short_scoreboard stalls. Let's examine where these stalls occur. In the snippets below, the last instruction is where the stall happens, while the first instruction is the dependency it's waiting on:

kernel_9_A100.asm
    ...
    LDSM.16.M88.4 R216, [R213+0x2000] ;
    HMMA.16816.F32 R60, R204, R4, R60 ;
    HMMA.16816.F32 R64, R204, R6, R64 ;
    LDSM.16.M88.4 R4, [R213+0x3000] ;
    LDSM.16.M88.4 R204, [R235+-0x40] ;
    HMMA.16816.F32 R92, R228, R224, R92 ;
    HMMA.16816.F32 R96, R228, R226, R96 ;
    HMMA.16816.F32 R68, R228, R220, R68 ;
    HMMA.16816.F32 R72, R228, R222, R72 ;
    HMMA.16816.F32 R76, R228, R216, R76 ;
    ...
Link to full assembly
kernel_10_A100.asm
    ...
    LDSM.16.M88.4 R136, [R242] ;
    LDSM.16.M88.4 R4, [R232+-0x40] ;
    HMMA.16816.F32 R128, R224, R136, R128 ;
    ...
Link to full assembly

The short_scoreboard stalls are happening because HMMA instructions are waiting on data from LDSM operations. The problem repeats throughout the kernel.

Here's what's interesting: Kernel 9 handles this much better. It packs around 8 independent instructions between the dependency and its use, helping mask the LDSM latency. Kernel 10, on the other hand, only has 1-2 instructions in those critical gaps. The compiler scheduled the LDSM → HMMA sequence less optimally, creating more opportunities for short_scoreboard stalls.

Why did the compiler make this choice? It comes down to slightly increased register pressure - more values need to stay live across the branches between basic blocks.

Despite the regression, Kernel 10's changes lay important groundwork. Removing the CS2R instructions and optimizing the initial softmax iteration eliminate unnecessary work. The compiler just needs better instruction scheduling to realize the benefits - we'll recover this performance and more in subsequent kernels.

Kernel 11: Strided Swizzling from RF → SMEM

In this iteration, we implement the final piece of strided swizzling: writing from RF → SMEM. Remember back in Kernel 8 when we said we'd revisit this optimization later? Now's the time. We saved it for last because this transfer only executes once per CTA - right at the very end when writing the final output. Unlike the GMEM ↔ SMEM and SMEM → RF transfers that happen repeatedly throughout execution, the impact here is inherently limited. Still, with all the major optimizations complete, this is worth implementing for completeness.

Why RF → SMEM Swizzling Differs

The RF → SMEM transfer has some unique characteristics compared to the other directions we've optimized:

  • Threads within a warp are arranged in an (8, 4) grid (8 rows, 4 threads per row)
  • Unlike the vectorized loads we saw earlier, individual store operations aren't vectorized. This means we copy:
    • 4B per thread
    • 16B per row (4 threads × 4B each)
    • 128B total per instruction warp-wide
  • Each thread in a row shares the same stride but has a different offset: (tid % 4) * 2

This arrangement lets us share stride calculations across threads in the same row. Here's what the mapping looks like for the second iteration of a copy within a swizzle region:
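
In code, the per-thread piece of that mapping is tiny - a rough sketch (hypothetical names, offsets in elements): the swizzled row base carries all of the swizzle logic and is shared by the four threads of a row, and each thread only adds its fixed offset.

RF → SMEM per-thread offset (sketch)
__device__ int rf_to_smem_offset(int swizzled_row_base, int tid) {
    // 4B (2 half elements) per thread, four threads per row.
    return swizzled_row_base + (tid % 4) * 2;
}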

Performance

Kernel   Top Performing Configuration
7        (128, 64, 64, 4): load_0_2_2_fragments
8        (128, 128, 64, 4): load_2_2_2_fragments
9        (128, 128, 64, 4): load_2_2_2_fragments
10       (128, 128, 64, 4): load_2_2_2_fragments
11       (128, 128, 64, 4): load_2_2_2_fragments

We improved 1.31%, from 175.10 to 177.40 TFLOPs - nearly recovering the regression from Kernel 10. The improvement might seem modest, but it's exactly what we'd expect. The RF → SMEM transfer only executes once per CTA at the very end, unlike the GMEM → SMEM and SMEM → RF transfers that happen repeatedly throughout execution.

Summary

Through kernels 8-11, we systematically eliminated unnecessary instructions by:

  1. Kernel 8:
    1. Implementing strided swizzling to eliminate virtually all unnecessary logic/shift instructions
    2. Significantly reducing register pressure, which made the larger block configuration the top performer
  2. Kernel 9: Optimizing fragment storage to eliminate redundant register copies
  3. Kernel 10: Removing CS2R instructions and optimizing the initial softmax iteration (minor regression due to instruction scheduling)
  4. Kernel 11: Completing strided swizzling for RF→SMEM transfers

The net result: 18.6% performance improvement from Kernel 7 (149.7 TFLOPs) to Kernel 11 (177.4 TFLOPs), bringing us within striking distance of reference performance (186.4 TFLOPs) on the A100.

In Part 9, we'll tackle the final stretch with a different assortment of optimizations. These minor refinements will push our final performance from 95.2% to 99.2% of reference.

Footnotes

  1. IMAD.MOV.U32 instructions are executed on the FMA pipeline, as we discussed in Appendix A - Ampere Microarchitecture.