Intro
In Part 6, we improved our kernel to slightly outperform the reference kernel on the RTX 3090, but found that on the A100, it only reached 80.3% of the reference. In Part 7, we profiled kernel 7 and identified the key bottlenecks:
- Excessive scalar instructions: We're executing roughly 2x the scalar instructions and 6.8-11.3x more dependency-creating instructions compared to the reference kernel. These dependencies block our `cp.async`, `ldmatrix`, and `mma` operations, forcing the high-throughput tensor cores to sit idle.
- Register pressure limits block size: Kernel 7 spills 272 bytes of registers per thread at the optimal block configuration on the A100. These expensive memory accesses negate the benefits of larger blocks, forcing us to use smaller, less efficient configurations.
- Throughput imbalance amplifies the problem: The A100 has much higher tensor throughput but lower FP32 throughput than the RTX 3090. This lopsided ratio means our scalar instruction overhead creates a severe bottleneck - the tensor cores can process work much faster than our scalar units can prepare it.
In this part, we'll systematically eliminate these unnecessary instructions through four kernel iterations:
- Kernel 8: Implement strided swizzling to eliminate logic and bit-shift instructions while reducing register pressure
- Kernel 9: Optimize fragment storage to remove redundant register copies
- Kernel 10: Remove CS2R instructions and optimize the initial softmax iteration
- Kernel 11: Complete strided swizzling for RF→SMEM transfers
By the end, we'll achieve an 18.6% performance improvement, jumping from 149.71 TFLOPs (kernel 7) to 177.40 TFLOPs (kernel 11). That puts us within striking distance of the reference performance (186.4 TFLOPs) on the A100.
Kernel 8: Reducing Logic and Bit Shift Instructions & Register Usage
When we profiled kernel 7, we noticed excessive logic and bit shift instructions. These shouldn't be here - we're running an FP- and tensor-heavy workload. The culprits are LOP3.LUT, IMAD.SHL.U32, and SHF.L.U32, which typically show up in memory address calculations. Let's dig into the assembly code to confirm this hypothesis.
Grouping Instructions by Operation Type
The compiler often uses different instructions for similar purposes. For instance:
- certain parts of address calculation might use either `LEA` or `IMAD`
- a reduction in `IMAD.MOV.U32` instructions might be offset by an increase in `MOV` instructions

Instead of trying to reduce a specific instruction, group related instructions together and work to eliminate the total sum.
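As a rough illustration of that bookkeeping (this is a hypothetical helper, not part of the kernel code), a small host-side program can bucket the opcodes from a SASS dump into coarse groups before comparing kernels:

```cpp
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>

// Bucket SASS opcodes into coarse groups so totals can be compared across
// kernels. Group names are illustrative; extend them to taste.
std::string group_of(const std::string& op) {
    if (op.rfind("LOP3", 0) == 0 || op.rfind("SHF", 0) == 0 ||
        op.rfind("IMAD.SHL", 0) == 0 || op.rfind("LEA", 0) == 0)
        return "address math";
    if (op.rfind("IMAD.MOV", 0) == 0 || op.rfind("MOV", 0) == 0)
        return "register copies";
    if (op.rfind("HMMA", 0) == 0) return "tensor";
    return "other";
}

int main(int argc, char** argv) {
    if (argc < 2) return 1;
    // Assumes the input has been reduced to one opcode per line
    // (e.g. by stripping the address/operand columns from `cuobjdump -sass`).
    std::ifstream sass(argv[1]);
    std::map<std::string, int> counts;
    std::string op;
    while (std::getline(sass, op)) {
        std::istringstream iss(op);
        std::string mnemonic;
        if (iss >> mnemonic) counts[group_of(mnemonic)]++;
    }
    for (const auto& [group, n] : counts)
        std::cout << group << ": " << n << "\n";
}
```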
Kernel 7 SASS Examination
Let's examine snippets from kernel 7's SASS code (full code here) to see where these instructions appear.
Notice how these logic and shift instructions appear right before each memory transfer - they're calculating addresses for LDGSTS (cp.async) and LDSM (ldmatrix) operations:
// ...
SHF.L.U32 R215, R215, 0x1, RZ ;
LOP3.LUT R178, R221, 0x1800, R8, 0xfe, !PT ;
LDGSTS.E.BYPASS.LTC128B.128 [R205+0x8000], [R28.64] ;
LOP3.LUT R30, R9, 0x7, R0, 0x78, !PT ;
IMAD.SHL.U32 R191, R191, 0x2, RZ ;
SHF.L.U32 R178, R178, 0x1, RZ ;
LDGSTS.E.BYPASS.LTC128B.128 [R206+0x8000], [R28.64+0x80] ;
IMAD.SHL.U32 R125, R30, 0x8, RZ ;
LDSM.16.M88.4 R24, [R207+0x4000] ;
LOP3.LUT R192, R125.reuse, R8, RZ, 0xfc, !PT ;
LDSM.16.M88.4 R68, [R183+0x4000] ;
LOP3.LUT R181, R125, 0x800, R8, 0xfe, !PT ;
SHF.L.U32 R192, R192, 0x1, RZ ;
LDSM.16.M88.4 R28, [R177+0x4000] ;
LOP3.LUT R176, R125, 0x1000, R8.reuse, 0xfe, !PT ;
SHF.L.U32 R181, R181, 0x1, RZ ;
LDSM.16.M88.4 R120, [R192+0x4000] ;
LOP3.LUT R173, R125, 0x1800, R8, 0xfe, !PT ;
IMAD.SHL.U32 R176, R176, 0x2, RZ ;
// ...

Address Calculation
These instructions are indeed computing memory addresses. For LDSM, the second argument is the SMEM address. Here are a couple examples:
364: LOP3.LUT R192, R125.reuse, R8, RZ, 0xfc, !PT ;
367: SHF.L.U32 R192, R192, 0x1, RZ ;
371: LDSM.16.M88.4 R120, [R192+0x4000] ;

369: LOP3.LUT R176, R125, 0x1000, R8.reuse, 0xfe, !PT ;
373: IMAD.SHL.U32 R176, R176, 0x2, RZ ;
376: LDSM.16.M88.4 R156, [R176+0x4000] ;

Source of Instructions
These address calculations are far more complex than they should be. The culprit? Our swizzling implementation from Part 4. While it eliminated bank conflicts, it made our memory access patterns so complex that the compiler can't optimize them efficiently.
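To make the connection concrete, here's a sketch of the kind of swizzled-offset computation involved. The masks, strides, and shift below are assumptions for illustration, not the kernel's real constants; the point is that each access pays for an XOR/AND (lowered to LOP3.LUT) plus a shift (SHF.L.U32 or IMAD.SHL) to convert to a byte offset:

```cpp
#include <cuda_fp16.h>

// Illustrative XOR-based swizzle in the spirit of Part 4's implementation.
// The row stride and the bits being XORed are made-up values for this sketch.
__device__ __forceinline__ int swizzled_byte_offset(int row, int col) {
    int linear = row * 64 + col;                 // assumed 64-element row stride
    int swizzled = linear ^ ((row & 0x7) << 3);  // XOR row bits into the column bits -> LOP3.LUT
    return swizzled * sizeof(__half);            // *2 -> the SHF.L.U32 / IMAD.SHL.U32 x2
}
```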
Let's test this hypothesis by comparing the instruction counts for swizzled and non-swizzled versions of kernel 7.
Swizzled vs Non-Swizzled
Instruction Count
After profiling the instruction counts, we can confirm our hypothesis. The swizzled version of our kernel executes 12x-42x as many of these instructions as the non-swizzled one.
Register Usage
Beyond instruction overhead, swizzling causes another problem: register pressure. The compiler must allocate registers to store all these computed memory addresses. Let's examine how many registers are needed by comparing the LDGSTS and LDSM operations in swizzled and non-swizzled versions.
LDGSTS (GMEM → SMEM)
Here is an example of an LDGSTS instruction:
LDGSTS.E.BYPASS.LTC128B.128 [R19], [R4.64]
- The first register is the destination address (SMEM, 32-bit) and the second is the source address (GMEM, 64-bit)
- The `.64` modifier on the source register indicates that it's a 64-bit value
- The `.BYPASS` modifier specifies that we bypass the L1 cache
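For reference, LDGSTS is what `cp.async` lowers to. Here's a minimal sketch of the PTX we emit for one 16-byte copy (treat the wrapper itself as illustrative; the `.cg` qualifier skips L1, which is where the `.BYPASS` modifier comes from):

```cpp
#include <cstdint>

// Issue one 16-byte GMEM -> SMEM asynchronous copy. cp.async.cg bypasses L1,
// which shows up in SASS as LDGSTS.E.BYPASS.LTC128B.128.
__device__ __forceinline__ void cp_async_16B(void* smem_dst, const void* gmem_src) {
    // LDGSTS takes a 32-bit shared-state-space address as its destination.
    uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n"
                 :
                 : "r"(smem_addr), "l"(gmem_src));
}
```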
These snippets contain every LDGSTS instruction from both kernels.
LDGSTS.E.BYPASS.LTC128B.128 [R147], [R6.64]
... [R145+0x80], [R8.64]
[R145+0x400], [R10.64]
[R145+0x480], [R10.64+0x80]
[R145+0x800], [R12.64]
[R145+0x880], [R12.64+0x80]
[R145+0xc00], [R4.64]
[R145+0xc80], [R4.64+0x80]
[R147+0x4000], [R14.64]
[R145+0x4080], [R16.64]
[R145+0x4400], [R18.64]
[R145+0x4480], [R18.64+0x80]
[R145+0x4800], [R8.64]
[R145+0x4880], [R8.64+0x80]
[R145+0x4c00], [R6.64]
[R145+0x4c80], [R6.64+0x80]
[R147+0x8000], [R40.64]
[R145+0x8080], [R44.64]
[R145+0x8400], [R58.64]
[R145+0x8480], [R58.64+0x80]
[R145+0x8800], [R60.64]
[R145+0x8880], [R60.64+0x80]
[R145+0x8c00], [R56.64]
[R145+0x8c80], [R56.64+0x80]
[R147+0x4000], [R136.64]
[R145+0x4080], [R138.64]
[R145+0x4400], [R154.64]
[R145+0x4480], [R154.64+0x80]
[R145+0x4800], [R148.64]
[R145+0x4880], [R148.64+0x80]
[R145+0x4c00], [R164.64]
[R145+0x4c80], [R164.64+0x80]

LDGSTS.E.BYPASS.LTC128B.128 [R39], [R24.64]
... [R41], [R16.64]
[R43], [R26.64]
[R45], [R26.64+0x80]
[R39+0x800], [R28.64]
[R41+0x800], [R28.64+0x80]
[R47], [R24.64]
[R49], [R24.64+0x80]
[R39+0x4000], [R30.64]
[R41+0x4000], [R32.64]
[R43+0x4000], [R34.64]
[R45+0x4000], [R34.64+0x80]
[R39+0x4800], [R36.64]
[R41+0x4800], [R36.64+0x80]
[R47+0x4000], [R26.64]
[R49+0x4000], [R26.64+0x80]
[R194+0x8000], [R24.64]
[R198+0x8000], [R26.64]
[R200+0x8000], [R30.64]
[R202+0x8000], [R30.64+0x80]
[R203+0x8000], [R68.64]
[R204+0x8000], [R68.64+0x80]
[R205+0x8000], [R28.64]
[R206+0x8000], [R28.64+0x80]
[R194+0x4000], [R144.64]
[R198+0x4000], [R158.64]
[R200+0x4000], [R146.64]
[R202+0x4000], [R146.64+0x80]
[R203+0x4000], [R156.64]
[R204+0x4000], [R156.64+0x80]
[R205+0x4000], [R160.64]
[R206+0x4000], [R160.64+0x80]

SMEM addresses
The destination registers (first argument) reveal a stark difference. The non-swizzled kernel uses a small number of base SMEM addresses, each referenced numerous times with different offsets. Since the offsets are equidistant, the compiler embeds them directly in the instructions.
The swizzled version uses far more registers, each referenced only a few times. The graph below quantifies this: the x-axis represents how many times a register is used, while the y-axis shows the number of registers with that usage count. The non-swizzled kernel uses only 2 registers for SMEM addresses while the swizzled version uses 14 - a significant difference!
GMEM addresses
On the other hand, there isn't much of a difference between the GMEM registers for the swizzled and non-swizzled instructions. This is because we don't swizzle GMEM offsets, only SMEM.
LDSM
LDSM shows a similar pattern to LDGSTS, but the difference is even more pronounced. The second argument here is the SMEM address. I'll go into more detail on LDSM in the section for the next kernel.
...
LDSM.16.M88.4 R4, [R205+0x8000]
LDSM.16.M88.4 R140, [R178]
LDSM.16.M88.4 R156, [R205+0x9000]
LDSM.16.M88.4 R160, [R205+0x8020]
LDSM.16.M88.4 R168, [R178+0x20]
LDSM.16.M88.4 R164, [R178+0x1000]
LDSM.16.M88.4 R172, [R178+0x1020]
LDSM.16.M88.4 R8, [R205+0x9020]
LDSM.16.M88.4 R156, [R205+0x8040]
LDSM.16.M88.4 R168, [R178+0x40]
LDSM.16.M88.4 R160, [R205+0x9040]
LDSM.16.M88.4 R172, [R178+0x1040]
LDSM.16.M88.4 R8, [R205+0x8060]
LDSM.16.M88.4 R168, [R178+0x60]
...

...
LDSM.16.M88.4 R32, [R18] ;
LDSM.16.M88.4 R36, [R38] ;
LDSM.16.M88.4 R40, [R40] ;
LDSM.16.M88.4 R44, [R44] ;
LDSM.16.M88.4 R48, [R48] ;
LDSM.16.M88.4 R52, [R52] ;
LDSM.16.M88.4 R56, [R56] ;
LDSM.16.M88.4 R60, [R60] ;
LDSM.16.M88.4 R24, [R207+0x4000] ;
LDSM.16.M88.4 R68, [R183+0x4000] ;
LDSM.16.M88.4 R28, [R177+0x4000] ;
LDSM.16.M88.4 R120, [R192+0x4000] ;
...

In total for LDSM, the non-swizzled kernel only uses 3 registers - one of those registers is used in 64 different instructions! On the other hand, the swizzled kernel uses a total of 40 unique registers.
Summary
Compared to the non-swizzled version of kernel 7, the swizzled version
- executes far more logic and bit-shift instructions
- occupies far more registers to store SMEM addresses
Swizzling Patterns
Why do we execute so many instructions and use so many registers for swizzling?
If we examined the instruction dependencies, we'd find that virtually all address offsets are partially recomputed each iteration, even when base addresses are shared. The compiler can't optimize our swizzling computation pattern, so it allocates extra registers to cache intermediate values and inserts instructions to recompute offsets.
The solution: make swizzling more explicit. If we can expose the underlying pattern to the compiler, it should be able to cache offsets instead of recalculating them for every access.
Let's revisit the 4 bank swizzling example from kernel 2, where we read a (4, 4) tile with 4 threads, column by column.
Let's trace the access pattern for each thread:
- Thread 0: columns 0, 1, 2, 3
- Thread 1: columns 1, 0, 3, 2
- Thread 2: columns 2, 3, 0, 1
- Thread 3: columns 3, 2, 1, 0
See the pattern? Each thread follows a unique but predictable stride through memory. Looking closer:
- The starting column for each thread equals the thread ID
- Between the 0th and 1st iterations & the 2nd and 3rd iterations,
  - threads 0 and 2 increase their index by 1, while
  - threads 1 and 3 decrease their index by 1
- Between the 0th and 2nd iterations,
  - threads 0 and 1 increase their index by 2, while
  - threads 2 and 3 decrease their index by 2
This suggests we can encode each thread's access pattern as a stride with a base offset. If we can expose this pattern to the compiler, it should be able to cache these offsets instead of recalculating them for every access.
Let's formalize this pattern:
- Each thread starts at the column equal to its thread ID (its base offset)
- Between iterations, the column changes by powers of 2: ±1, ±2, ±4, etc.
- The sign depends on which thread we're looking at
This gives us a strided indexing scheme where each thread has its own unique stride pattern and base offset:
| Thread | Stride | Base Offset |
|---|---|---|
| 0 | (2, 1) | 0 |
| 1 | (2, -1) | 1 |
| 2 | (-2, 1) | 2 |
| 3 | (-2, -1) | 3 |
We can extend the table to an (8, 8) tile of memory:
| Thread | Stride | Base Offset |
|---|---|---|
| 0 | (4, 2, 1) | 0 |
| 1 | (4, 2, -1) | 1 |
| 2 | (4, -2, 1) | 2 |
| 3 | (4, -2, -1) | 3 |
| 4 | (-4, 2, 1) | 4 |
| 5 | (-4, 2, -1) | 5 |
| 6 | (-4, -2, 1) | 6 |
| 7 | (-4, -2, -1) | 7 |
Here's the key insight: we can calculate any thread's stride pattern using bit operations:
- The magnitude of the i-th stride is 2^i (counting from right to left)
- The sign is determined by the i-th bit of the thread ID: if set, negative; if clear, positive
Let's see how this bit-based stride calculation works in practice:
// Thread 3 has strides (4, -2, -1) and base offset 3
int thread_id = 3;
int base_offset = 3; // Starting column for thread 3

// Calculate swizzled column for each iteration
for (int iter = 0; iter < 8; iter++) {
    int col = base_offset;
    // Apply stride based on iteration bits
    if (iter & 4) col += (thread_id & 4) ? -4 : 4; // Bit 2: stride ±4
    if (iter & 2) col += (thread_id & 2) ? -2 : 2; // Bit 1: stride ±2
    if (iter & 1) col += (thread_id & 1) ? -1 : 1; // Bit 0: stride ±1
    col = col & 7; // Wrap around within 8 columns
    // Thread 3 accesses columns: 3, 2, 1, 0, 7, 6, 5, 4
}

Scaling to Larger Tiles: Swizzle Regions
What happens when we have more than 8 columns or rows? Since we only have 8 banks, the swizzling pattern repeats every 8 rows/columns.
We'll call each 8×128-byte segment a swizzle region. Within each region, threads follow the same access pattern we just defined. When we move to the next region, we simply reuse the same pattern with a different base address.
This pattern is repeated.
Will using this formulation actually reduce the number of swizzling instructions? With our previous swizzling subroutine, most, if not all, of the offsets are recomputed every time they're accessed. However, since the access pattern is the same across swizzle regions, we only have to compute the strides and offsets once and can then reuse those values for all regions. That should eliminate the redundant computation.
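Here's a minimal sketch of that hoisting (all names are illustrative, not the kernel's): the swizzled offset within a region is computed once per thread, and the loop over regions only adds a plain base.

```cpp
// One swizzle region: 8 rows x 128 bytes of SMEM.
constexpr int kRegionBytes = 8 * 128;

// `in_region_offset` is the thread's swizzled offset, computed once outside
// the loop. The per-region loop is then a single add - no LOP3/SHF inside it.
template <typename LoadFn>
__device__ void for_each_swizzle_region(const char* smem_base, int in_region_offset,
                                        int n_regions, LoadFn load) {
    for (int r = 0; r < n_regions; ++r) {
        load(smem_base + r * kRegionBytes + in_region_offset);
    }
}
```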
How Memory Transfers Change with Strided Swizzling
For every direction of memory transfer, we need to calculate
- the base per-thread offset and
- the swizzle stride
In this kernel, we'll only update the swizzling code for copying between GMEM & SMEM and from SMEM → RF, but not from RF → SMEM. The improvement there would be smaller since that transfer only executes once per CTA (for copying the output tile back to SMEM at the very end).
Note
Between kernels 7-10, the code changes significantly between iterations. To save you the trouble of reading drastically changing code, from here on I'll only include the changes that reflect the core ideas, or possibly omit the code entirely.
Strided Swizzling Between GMEM ↔ SMEM
The Problem: Multiple Offsets Per Thread
In the example above, we were loading 1 column per iteration for 8 iterations - a column stride of 1 with 8 iterations per swizzle region. Remember that happens 4 times per warp, once for each octet of threads.
However, when we load between GMEM and SMEM, each warp covers only part of a swizzle region per vectorized load, so sweeping the full 8 rows takes 2 iterations. Each iteration within a swizzle region requires a unique offset calculation. With 2 iterations per region, we need to compute 2 offsets instead of 8, but that's still one more than ideal.
The Solution: CTA-Level Cooperation
Can we reduce this to just one offset calculation per thread? Yes! The key is reorganizing how we distribute work across the CTA (Cooperative Thread Array).
Currently, each warp loads contiguous rows independently. Instead, we can have all 4 warps in the CTA cooperate on the same rows. With 128 threads and 8 threads per 128-byte row, one CTA-wide load covers 16 rows - two full swizzle regions - so each thread keeps the same position within a swizzle region across iterations and needs only a single offset.

This also means we have to switch from warp-scoped barriers (`__syncwarp()`) to CTA-scoped barriers (`__syncthreads()`).
Implementation Details
Instead of using `lane_id = thread_id % 32` to calculate the per-thread memory offset, we'll use `thread_id` directly for CTA-wide coordination.
// Configuration for GMEM to SMEM transfers
// Handles both swizzled and non-swizzled memory layouts
template <typename Swizzle_, typename OpStride_, typename TensorShape_,
typename SmemStride_>
struct GSMemLdstConfig {
using Swizzle = Swizzle_;
using OpStride = OpStride_;
using TensorShape = TensorShape_;
using SmemStride = SmemStride_;
using OpIters = TShape<TensorShape::rows() / OpStride::row(),
TensorShape::cols() / OpStride::col()>;
static constexpr int thrs_per_row = 8;
// Convert thread ID to row coordinate in CTA grid
static constexpr int tid_to_thr_row(int tid) { return tid / thrs_per_row; }
// Convert thread ID to column coordinate in CTA grid
static constexpr int tid_to_thr_col(int tid) {
return (tid % thrs_per_row) * COLS_PER_FRAGMENT;
}
// Calculate per-thread offset for GMEM access
static constexpr int gmem_thr_offset(int tid, RuntimeStride stride) {
return tid_to_thr_row(tid) * stride.row +
tid_to_thr_col(tid) * stride.col;
}
// Calculate per-thread offset for SMEM access (with swizzling)
static constexpr int smem_thr_offset(int tid) {
return Swizzle::apply(tid_to_thr_row(tid) * SmemStride::row() +
tid_to_thr_col(tid) * SmemStride::col());
}
};
// ...

In the tensor wrapper class, we'll use the above helpers to set the pointers of the blocks for each thread.
// ...
gmem_ptr = gmem_block_ptr + gsmem_::gmem_thr_offset(tid, gmem_stride);
smem_gsmem_ptr = _smem_ptr + gsmem_::smem_thr_offset(tid);
// ...

We'll also replace the `__syncwarp()` barriers with `__syncthreads()`.
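Putting the pieces together, the copy loop itself ends up looking roughly like this (a sketch with made-up names and template parameters, not the kernel's actual helper). The swizzle math is baked into `smem_gsmem_ptr` once; each iteration just advances both pointers by a fixed CTA-wide step:

```cpp
#include <cuda_pipeline.h>

// kIters: iterations needed to cover the tile with the whole CTA
// kGmemStepBytes / kSmemStepBytes: fixed per-iteration advance for this thread
template <int kIters, int kGmemStepBytes, int kSmemStepBytes>
__device__ void cta_copy_tile_async(const char* gmem_ptr, char* smem_gsmem_ptr) {
    for (int it = 0; it < kIters; ++it) {
        // 16-byte vectorized async copy; no swizzle recomputation in the loop.
        __pipeline_memcpy_async(smem_gsmem_ptr + it * kSmemStepBytes,
                                gmem_ptr + it * kGmemStepBytes, 16);
    }
    __pipeline_commit();
    // In the real kernel the wait is deferred so it overlaps with compute;
    // here we wait immediately and synchronize the whole CTA for simplicity.
    __pipeline_wait_prior(0);
    __syncthreads();
}
```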
Strided Swizzling from SMEM → RF
The situation is different when we copy between SMEM and RF. Here each warp loads its own (16, 16) tiles with ldmatrix, issuing several loads per swizzle region, so a single per-thread offset is no longer enough.

To resolve this, we'll explicitly encode the swizzling stride within a region. Note that instead of a per-column stride like (±4, ±2, ±1) from the example above, we store the offsets a thread actually adds between iterations: (64, ±32, ±16).
Code
We'll create an object containing the swizzling strides.
// ...
struct SwizzleStride {
int s0; // (always 64, this is for when we cross from one swizzle region to another)
int s1; // (+- 32)
int s2; // (+- 16)
// `iteration` is the index of the inner loop.
constexpr int offset(int iteration) const {
int i0 = (iteration >> 2) & 1;
int i1 = (iteration >> 1) & 1;
int i2 = iteration & 1;
return i0 * s0 + i1 * s1 + i2 * s2;
}
};
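// Worked example (illustrative values, not taken from the kernel): a thread
// whose strides come out as {64, -32, 16} visits the relative offsets
//   iter:   0   1    2    3   4   5   6   7
//   offset: 0  16  -32  -16  64  80  32  48
// so all the swizzle-dependent work is captured in the three cached strides.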
// ...

This calculates the initial per-thread offset for ldmatrix calls.
// ...
static constexpr int lane_to_thr_offset_s2rmem(int lane_id) {
int thread_row = lane_id % 16;
int thread_col = (lane_id / 16) * COLS_PER_FRAGMENT;
return Swizzle::apply(
thread_row * SmemStride::row() +
thread_col * SmemStride::col());
}
// ...

// ...
smem_s2rmem_ptr =
smem_srmem_ptr + srmem_::lane_to_thr_offset_s2rmem(lane_id);
// ...

And the code to calculate the strides. We need to use lane_id = thread_id % 32 here since each warp has to load its own data.
// ...
static constexpr SwizzleStride lane_to_thr_swizzle_stride(int lane_id) {
if constexpr (std::is_same_v<Swizzle, NoSwizzle>) {
return SwizzleStride{64, 32, 16};
} else {
// Calculate initial offset.
int base_swizzle_offset = lane_to_thr_offset_s2rmem(lane_id);
int base_offset_cmp = Swizzle::yy_mask_lowest_bit << 1;
int s1 = 32 * binary_to_pm1((base_swizzle_offset &
(base_offset_cmp << 1)) == 0);
int s2 = 16 * binary_to_pm1(
(base_swizzle_offset & base_offset_cmp) == 0);
// The 64 stride is for when we cross from one swizzle boundary to
// the next, for instance when d_head = 128.
return SwizzleStride{64, s1, s2};
}
}
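// Illustrative use (not the kernel's actual loop): the strides are computed
// once per thread, then each ldmatrix within a tile only adds a precomputed
// offset to the base pointer.
//
//   SwizzleStride stride = srmem_::lane_to_thr_swizzle_stride(lane_id);
//   for (int it = 0; it < iters; ++it) {
//     ldmatrix_x4(frag[it], smem_s2rmem_ptr + stride.offset(it));
//   }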
// ...

Kernel 8 Memory Address Register Count
Let's make the same comparison we did earlier for the changes we just made.
For LDGSTS, only a single register is used, which is even better than the previous kernel without swizzling.
For LDSM, we've significantly improved SMEM address register reuse. It's slightly less efficient than the non-swizzled version - which is expected given the more complex pattern - but still significant progress. This should result in a substantial reduction in register pressure. For the larger block configuration, the compiler's spill statistics confirm it:
| Kernel # | stack_frame | spill_stores | spill_loads |
|---|---|---|---|
| 7 | 272 | 336 | 304 |
| 8 | 16 | 32 | 16 |
| Reduction | 94.1% | 90.5% | 94.7% |
Executed Instructions Count
Before we take a look at the performance, let's see the difference in the number of executed instructions.
We eliminated the vast majority of all logic and shift instructions, leaving the counts comparable to the reference kernel. However, we somewhat increased the number of IMAD.MOV.U32 copy instructions.
Configuration Specification
Note: all the kernels from now on will use async copies + eager KV loads + swizzling + opt_softmax, so we'll strip those settings from the specifications.
(d_head, B_r, B_c, n_warps): {async}+{eager}+{swizzled}+load_{q_fragments}_{k_frags}_{v_frags}_fragments+{buffer}+{opt_softmax}
becomes
(d_head, B_r, B_c, n_warps): load_{q_fragments}_{k_frags}_{v_frags}_fragments+{buffer}
Summary of Optimizations
By encoding swizzling as explicit strides rather than recomputed offsets, we've achieved:
- Reduced address calculation instructions by 90%+
- Cut SMEM address registers:
  - for `LDSM`, from 40 to 5
  - for `LDGSTS`, from 14 to 1
- Eliminated 90%+ of register spilling for larger block sizes
Performance
The massive reduction in register pressure finally allows us to use the larger (128, 128, 64, 4) block configuration:
| Kernel | Top Performing Config |
|---|---|
| 7 | (128, 64, 64, 4): load_0_2_2_fragments |
| 8 | (128, 128, 64, 4): load_2_2_2_fragments |
That's a solid 9.38% performance jump, from 149.71 TFLOPs to 163.76 TFLOPs.
Kernel Metrics
By doubling B_r, we reduce the number of cp.async calls and cut ldmatrix calls by 37.5%. See Appendix B - Block Size Configuration for further details.
This translates to:
- SMEM load utilization dropping 31.59%, from 60.49% to 41.38%
- L2 load utilization dropping 44.90%, from 42.53% to 23.43%
The scalar pipeline utilization drops significantly:
- ALU: down 17.98%, from 21.20% to 17.39%
- FMA: slightly up 3.19%, from 29.23% to 30.16%
  - This FMA increase aligns with the rise in executed `IMAD.MOV.U32` instructions¹
- Overall, this is a decrease of 5.71%, from 50.43% to 47.55%
The total number of cycles went down, so tensor pipeline utilization increased 9.4%, from 60.47% to 66.15%
Kernel 9: Reducing Copy Instructions
Time to tackle those register copy instructions. We're executing over 100x more IMAD.MOV.U32 instructions and 4x as many MOV instructions compared to the reference.
Examining the SASS
Here's a snippet from kernel 8's assembly code. Notice the IMAD.MOV.U32 and MOV instructions interleaved between HMMA and LDSM operations. Some destination registers aren't even used before being overwritten - R201 at line 355, for example. What's going on here?
...
MOV R200, R9 ;
HMMA.16816.F32 R20, R192, R184, R20 ;
IMAD.MOV.U32 R201, RZ, RZ, R11 ;
IMAD.MOV.U32 R184, RZ, RZ, R57 ;
MOV R185, R59 ;
HMMA.16816.F32 R68, R192.reuse, R200, R68 ;
HMMA.16816.F32 R84, R192, R184, R84 ;
LDSM.16.M88.4 R184, [R232+0x1000] ;
IMAD.MOV.U32 R200, RZ, RZ, R4 ;
MOV R201, R6 ;
HMMA.16816.F32 R196, R192, R200, R196 ;
IMAD.MOV.U32 R200, RZ, RZ, R5 ;
IMAD.MOV.U32 R201, RZ, RZ, R7 ;
HMMA.16816.F32 R192, R192, R200, R88 ;
...

LDSM and HMMA
To understand why we have all these extra IMAD.MOV.U32 and MOV instructions, we need to examine how PTX instructions map to SASS on Ampere. The key insight is that ldmatrix and mma have strict register alignment requirements at the SASS level that aren't obvious from the PTX. Let's look at how these instructions translate to LDSM and HMMA respectively.
The ldmatrix PTX instruction has the signature,
ldmatrix.sync.aligned.m8n8.x4.b16 {d0, d1, d2, d3}, [addr];
Here's a concrete example from our kernel:
ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%r213, %r214, %r215, %r216}, [%r217];
This compiles to the following SASS instruction:
LDSM.16.M88.4 R160, [R196]
Notice that PTX specifies 4 registers (one for each fragment), while SASS only specifies 1. What's happening here?
The answer lies in how SASS encodes these special instructions:
- SASS registers (`R0`-`R255`) are 32-bit registers
- `LDSM` and `HMMA` implicitly operate on groups of contiguous registers, starting from the base register specified
- In our example, `R160` actually refers to registers `R160`, `R161`, `R162`, and `R163` - the four fragments map to these four consecutive SASS registers
- Unlike 64-bit registers, which use an explicit `.64` suffix (like `R28.64` for GMEM addresses), there's no special notation for these 4-register groups - the instruction format itself determines how many registers are used
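For context, this is roughly how such an ldmatrix call gets emitted from the CUDA side (a sketch; the kernel's actual wrapper may differ):

```cpp
#include <cstdint>

// x4 ldmatrix: loads four (8, 8) b16 fragments; each thread receives one
// 32-bit register per fragment.
__device__ __forceinline__ void ldmatrix_x4(uint32_t (&d)[4], const void* smem_ptr) {
    uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile(
        "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
        : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3])
        : "r"(addr));
}
```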
The mma instruction follows a similar pattern, but with one crucial difference - asymmetric operand sizes:
mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
{%Rd0, %Rd1, %Rd2, %Rd3}, // D (output accumulator)
{%Ra0, %Ra1, %Ra2, %Ra3}, // A (left multiplicand)
{%Rb0, %Rb1}, // B (right multiplicand - only 2 registers!)
{%Rc0, %Rc1, %Rc2, %Rc3}; // C (input accumulator)
Here's a concrete example from kernel 8:
mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
{%f1210, %f1211, %f1212, %f1213},
{%r646, %r647, %r648, %r649},
{%r658, %r659},
{%f1146, %f1147, %f1148, %f1149};
Which compiles to:
HMMA.16816.F32 R64, R8, R12, R64
The SASS operands correspond to D, A, B, C respectively. Notice the critical difference:
- D, A, and C each use 4 consecutive registers (e.g., `R64` means `R64`-`R67`)
- B uses only 2 consecutive registers (e.g., `R12` means `R12`-`R13`)
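And the matching mma wrapper, again as a sketch rather than the kernel's exact code - note the asymmetric operand sizes from the list above:

```cpp
#include <cstdint>

// D, A, C each occupy 4 registers; B occupies only 2.
__device__ __forceinline__ void mma_m16n8k16(float (&d)[4], const uint32_t (&a)[4],
                                             const uint32_t (&b)[2], const float (&c)[4]) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
        : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}
```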
Register Alignment Requirements
The SASS encoding imposes strict alignment requirements on the base registers:
- D, A, and C operands (4 registers each): the base register must be a multiple of 4 (like `R0`, `R4`, `R8`, ...)
- B operand (2 registers): the base register must be a multiple of 2 (like `R0`, `R2`, `R4`, ...)
This means we'll never see odd register numbers as arguments to HMMA or LDSM. These alignment constraints become important when we look at how fragments flow through our kernel - let's examine what happens conceptually when we store and manipulate these fragments.
Conceptual Model of Fragment Storage
This changes how we should conceptualize how tiles are stored in RF. Even though in PTX we can specify individual registers, this is not the case in SASS, where the registers are grouped into aligned quartets. We currently store matrices indexed by fragment row and fragment column.

You can essentially think of it as

uint32_t mat[fragment_rows / 2][fragment_cols / 2][2][2];

instead of

uint32_t input[fragment_rows][fragment_cols];

This conceptual model reveals why certain operations require extra instructions - the hardware's view of register organization differs from PTX's.
Where Do The Extra Instructions Come From?
Now that we understand SASS register alignment requirements, we can trace where these extra instructions originate.
When HMMA operands aren't in the expected contiguous format, the compiler must insert copy instructions. A closer examination of the SASS code reveals something interesting: the IMAD.MOV.U32 and MOV instructions only affect the B operand of the HMMA instructions.
Here's one pair
...
MOV R200, R9 ;
HMMA.16816.F32 R20, R192, R184, R20 ;
IMAD.MOV.U32 R201, RZ, RZ, R11 ;
IMAD.MOV.U32 R184, RZ, RZ, R57 ;
MOV R185, R59 ;
HMMA.16816.F32 R68, R192.reuse, R200, R68 ;
HMMA.16816.F32 R84, R192, R184, R84 ;
...
Which kernel operation do they correspond to? The plain ldmatrix corresponds to LDSM.16.M88.4, while the transposed version corresponds to LDSM.16.MT88.4. We can take each LDSM instruction and find which operand position its destination registers occupy in the corresponding HMMA instruction(s).

After examining the SASS code, we find that the culprit is LDSM.16.MT88.4 (the transposed ldmatrix).

Let's visualize our layout pattern for copying SMEM → RF. Each ldmatrix instruction loads a (16, 16) tile made up of four (8, 8) fragments.

ldmatrix loads the fragments within each (16, 16) tile in column major order. HMMA expects the A operand as a contiguous quartet of registers in exactly this order - which is why we see no copy instructions for A.

The B operand, however, requires 2 registers holding fragments from the same row, but column major loading places each aligned register pair in the same column instead. The solution: load these tiles with their fragments in row major order.

With row major loading, each aligned register pair holds two fragments from the same row - exactly what HMMA's B operand expects - so no copies are needed.

Why doesn't the second mma need the same treatment? Its A operand (the attention scores) lives entirely in RF, so no loading is involved; it is stored in RF in column major order, matching the layout within fragment tiles - the same pattern as the A operand of the first mma.
Optimal ldmatrix Fragment Shape
When I introduced ldmatrix in Which Fragments to Load?, I stated that we'd load (16, 16) fragments without explaining why that shape is optimal.
The three possible fragment shapes are: (16, 16), (32, 8), and (8, 32). Let's examine the trade-offs of each.
(32, 8) Fragment Shape
The (32, 8) shape - a column of four (8, 8) fragments - seems promising since it lines up with mma's B operand layout (a column of stacked (8, 8) fragments). The problem is that our two matrix multiplies have conflicting requirements: one operand needs its tile in column major order, where each aligned register pair holds fragments from the same row, while the other needs its tile in row major order.

So if we utilized this layout, all the work we just did would be for naught and we'd have more copies than before!
In addition, in this layout each fragment resides in a separate swizzle region. This would force us to calculate strides for all 8 iterations (covering the entire swizzle region) instead of just 4. The result would be increased register usage and additional instructions for address calculation.
(8, 32) Fragment Shape
The (8, 32) shape represents a row of four (8, 8) fragments. Its key advantage is that all fragments live in the same swizzle region, cutting iterations from 4 (for (16, 16)) down to just 2.
The issue, however, is that it doesn't match the fragment layout mma expects, so loading (8, 32) tiles would result in more copies - similar to (32, 8).
(16, 16) - The Sweet Spot
The (16, 16) shape is fully compatible with the mma instruction layout: it matches the A fragment layout exactly and holds two side-by-side B fragments.

Beyond compatibility, it brings two more advantages:
- Shared address calculations: When both operands loaded from SMEM use this shape, they share swizzle offsets and strides. Most of the common calculations execute once with shared storage.
- Single-instruction loads: For the A matrix layout, a single `ldmatrix` provides all fragments needed for an `mma.m16n8k16` instruction. The alternatives - `(8, 32)` or `(32, 8)` - require at least two `ldmatrix` calls before the first `mma`, reducing our ability to interleave memory transfers with computation.
Solution Summary
The register alignment constraints of SASS HMMA instructions were forcing the compiler to insert copy instructions to shuffle fragments into place at runtime. By loading fragments into RF in exactly the order HMMA expects, we eliminate the need for these expensive register shuffles.
SASS Instruction Counts
Build Time Instructions
This table contains the difference in the number of register copy instructions in the kernel SASS assembly.
| Kernel | IMAD.MOV.U32 | MOV |
|---|---|---|
| 8 | 187 | 159 |
| 9 | 36 | 18 |
| Reduction | 81% | 89% |
What a huge drop! We've eliminated all register copy instructions related to LDSM and HMMA. Note that these are build-time instruction counts - the runtime reduction will vary based on branching, sequence length, and other dynamic factors.
Performance
| Kernel | Top Performing Configuration |
|---|---|
| 7 | (128, 64, 64, 4): load_0_2_2_fragments |
| 8 | (128, 128, 64, 4): load_2_2_2_fragments |
| 9 | (128, 128, 64, 4): load_2_2_2_fragments |
Another solid gain: 8.51% improvement from 163.76 to 177.68 TFLOPs. We've nearly closed the gap with the reference kernel.
Profile
Much better! We're still executing more IMAD.MOV.U32 instructions than the reference (FP pipeline), but the dramatic reduction in MOV instructions (INT pipeline) more than compensates.
Summed together, we're only executing 2.5% more copy instructions than the reference.
The scalar pipeline utilization drops significantly:
- ALU: down 47.89%, from 17.39% to 9.06%
- FMA: down 26.75%, from 30.16% to 22.09%
This represents an overall drop of 34.48%, from 47.55% → 31.15%.
The total number of cycles went down, so tensor pipeline utilization increased 8.51% from 66.15% to 71.78%.
Kernel 10: Removing CS2R Instructions + Optimizing Initial Softmax Iteration
Time to eliminate those CS2R instructions. These zero-initialize the accumulator registers before the first mma. Instead, we can substitute RZ (the always-zero register) as the C operand of the first HMMA. Here's what that looks like:
HMMA.16816.F32 R80, R16.reuse, R32, RZ ;
HMMA.16816.F32 R60, R16.reuse, R26, RZ ;
HMMA.16816.F32 R76, R16.reuse, R50, RZ ;
Kernel 9's assembly doesn't have equivalent code - it never uses RZ as the C operand to zero the accumulator, relying on CS2R instructions instead.
While we're at it, we can also optimize the first softmax iteration. Currently, we initialize the row statistics (the running row max and row sum) to their identity values and treat the first iteration like any other. Instead, we can:

- Initialize the row max and row sum directly from the first block of scores
- Skip the unnecessary rescaling of the running sum and output accumulator, since they're zero
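As a rough scalar sketch of the idea (the real kernel operates on register fragments, and the names below are illustrative, not the kernel's):

```cpp
// First KV block: the running max/sum start directly from the scores, and
// the rescaling of the (still zero) accumulators is skipped entirely.
__device__ void softmax_first_block(const float* scores, int n, float& row_max,
                                    float& row_sum, float* probs) {
    row_max = scores[0];
    for (int i = 1; i < n; ++i) row_max = fmaxf(row_max, scores[i]);
    row_sum = 0.f;
    for (int i = 0; i < n; ++i) {
        probs[i] = expf(scores[i] - row_max);
        row_sum += probs[i];
    }
    // Later blocks would additionally rescale the previous row_sum and the
    // output accumulator by expf(old_max - new_max); here both are zero.
}
```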
The code changes are straightforward, so I won't detail them here - check out the diff if you're curious.
Performance
| Kernel | Top Performing Configuration |
|---|---|
| 7 | (128, 64, 64, 4): load_0_2_2_fragments |
| 8 | (128, 128, 64, 4): load_2_2_2_fragments |
| 9 | (128, 128, 64, 4): load_2_2_2_fragments |
| 10 | (128, 128, 64, 4): load_2_2_2_fragments |
Unexpectedly, we regressed 1.46%, dropping from 177.68 to 175.10 TFLOPs.
This is puzzling given that we eliminated the extra softmax and CS2R instructions - these instruction counts now match the reference kernel.
The FP pipeline utilization dropped from 22.1% to 20.42% after removing the extra softmax instructions. This reduction should scale inversely with sequence length.
Regression Analysis
What went wrong? Comparing stall profiles between kernels 9 and 10 reveals the culprit:
| Stall Type | Kernel 9 | Kernel 10 | Delta |
|---|---|---|---|
short_scoreboard | 2.36% | 7.79% | +5.44% |
wait | 36.99% | 38.48% | +1.49% |
That's a massive jump in short_scoreboard stalls. Let's examine where these stalls occur. In the snippets below, the last instruction is where the stall happens, while the first instruction is the dependency it's waiting on:
...
LDSM.16.M88.4 R216, [R213+0x2000] ;
HMMA.16816.F32 R60, R204, R4, R60 ;
HMMA.16816.F32 R64, R204, R6, R64 ;
LDSM.16.M88.4 R4, [R213+0x3000] ;
LDSM.16.M88.4 R204, [R235+-0x40] ;
HMMA.16816.F32 R92, R228, R224, R92 ;
HMMA.16816.F32 R96, R228, R226, R96 ;
HMMA.16816.F32 R68, R228, R220, R68 ;
HMMA.16816.F32 R72, R228, R222, R72 ;
HMMA.16816.F32 R76, R228, R216, R76 ;
...

...
LDSM.16.M88.4 R136, [R242] ;
LDSM.16.M88.4 R4, [R232+-0x40] ;
HMMA.16816.F32 R128, R224, R136, R128 ;
...

The short_scoreboard stalls are happening because HMMA instructions are waiting on data from LDSM operations. The problem repeats throughout the kernel.
Here's what's interesting: Kernel 9 handles this much better. It packs around 8 independent instructions between the dependency and its use, helping mask the LDSM latency. Kernel 10, on the other hand, only has 1-2 instructions in those critical gaps. The compiler scheduled the LDSM → HMMA sequence less optimally, creating more opportunities for short_scoreboard stalls.
Why did the compiler make this choice? It comes down to slightly increased register pressure - more values need to stay live across branch targets.
Despite the regression, Kernel 10's changes lay important groundwork. Removing the CS2R instructions and optimizing the initial softmax iteration eliminate unnecessary work. The compiler just needs better instruction scheduling to realize the benefits - we'll recover this performance and more in subsequent kernels.
Kernel 11: Strided Swizzling from RF → SMEM
In this iteration, we implement the final piece of strided swizzling: writing the output tile from RF back to SMEM.
Why RF → SMEM Swizzling Differs
The RF → SMEM transfer has some unique characteristics compared to the other directions we've optimized:
- Threads within a warp are arranged in an 8 × 4 grid (8 rows, 4 threads per row)
- Unlike the vectorized loads we saw earlier, individual store operations aren't vectorized. This means we copy:
  - 4B per thread
  - 16B per row (4 threads × 4B each)
  - 128B total per instruction warp-wide
- Each thread in a row shares the same stride but has a different offset: `(tid % 4) * 2`
This arrangement lets us share stride calculations across threads in the same row. Here's what the mapping looks like for the second iteration of a copy within a swizzle region:
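A minimal sketch of that per-thread mapping (the helper name and the exact Swizzle type are assumptions; only the thread arrangement and the `(tid % 4) * 2` offset come from the description above):

```cpp
// Map a lane to its swizzled SMEM offset for the RF -> SMEM store.
// Threads form an (8, 4) grid; each thread covers 2 elements (4 bytes).
template <typename Swizzle>
__device__ int rf_to_smem_thr_offset(int lane_id, int smem_row_stride) {
    int thr_row = lane_id / 4;        // 8 rows of 4 threads
    int thr_col = (lane_id % 4) * 2;  // per-thread column offset within the row
    // Computed once; threads in the same row then share the same stride
    // between iterations and differ only in thr_col.
    return Swizzle::apply(thr_row * smem_row_stride + thr_col);
}
```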
Performance
| Kernel | Top Performing Configuration |
|---|---|
| 7 | (128, 64, 64, 4): load_0_2_2_fragments |
| 8 | (128, 128, 64, 4): load_2_2_2_fragments |
| 9 | (128, 128, 64, 4): load_2_2_2_fragments |
| 10 | (128, 128, 64, 4): load_2_2_2_fragments |
| 11 | (128, 128, 64, 4): load_2_2_2_fragments |
We improved 1.31%, from 175.10 to 177.40 TFLOPs - nearly recovering the regression from Kernel 10. The improvement might seem modest, but it's exactly what we'd expect. The RF → SMEM transfer only executes once per CTA at the very end, unlike the GMEM → SMEM and SMEM → RF transfers that happen repeatedly throughout execution.
Summary
Through kernels 8-11, we systematically eliminated unnecessary instructions by:
- Kernel 8:
  - Implementing strided swizzling to virtually eliminate all unnecessary logic/shift instructions
  - Significantly reducing register pressure, making the larger (128, 128, 64, 4) block configuration the most performant
- Kernel 9: Optimizing fragment storage to eliminate redundant register copies
- Kernel 10: Removing `CS2R` instructions and optimizing the initial softmax (minor regression due to instruction scheduling)
- Kernel 11: Completing the strided swizzling implementation for RF→SMEM transfers
The net result: 18.6% performance improvement from Kernel 7 (149.7 TFLOPs) to Kernel 11 (177.4 TFLOPs), bringing us within striking distance of reference performance (186.4 TFLOPs) on the A100.
In Part 9, we'll tackle the final stretch with a different assortment of optimizations. These minor refinements will push our final performance from 95.2% to 99.2% of reference.
Footnotes
1. `IMAD.MOV.U32` instructions are executed on the FMA pipeline, as we discussed in Appendix A - Ampere Microarchitecture. ↩