In the previous part, we implemented three major optimizations from the CUTLASS GEMM library: eager block loading, sub-tiling with fragment interleaving, and double buffering. These techniques helped us overlap memory transfers with computation, reaching 99.6% of reference performance on the RTX 3090.
In this part, we'll implement two final optimizations:
- Floating point instruction fusion using a clever optimization from the official implementation (Kernel 6)
- Auto-tuning to find optimal kernel configurations (Kernel 7)
By the end, we'll slightly outperform the official implementation, achieving 101.5% of the reference throughput.
Kernel 6. Improving FP32 Throughput
Our previous optimizations focused on tensor operations, and looking at the roofline graph, we can see that we're now hitting peak matmul FLOPs/s on the RTX 3090. To make further improvements, we need to look at other bottlenecks.
Aside: Roofline Graph
The roofline graph shows three different arithmetic intensities, with the L2 value being most relevant to our kernel:
- L1 intensity appears artificially high:
  - We bypass the L1 cache when copying from GMEM to SMEM using `cp.async` with the `.cg` option
  - Only transfers from SMEM to GMEM (and any register spills) count toward L1 traffic
- DRAM intensity only includes:
  - Initial data loads from main memory
  - Evicted sectors (cache lines that got pushed out and need reloading)
  - This gives us the "best case" memory bandwidth scenario
- L2 intensity is most accurate for our analysis:
  - It captures all memory transfers, since each CTA moves the same amount of data through L2
  - With no L1 caching, L2 sees all of our memory traffic
  - It best represents the true arithmetic intensity of our kernel (a back-of-the-envelope estimate follows this list)
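As a rough sanity check on these numbers, we can estimate the L2 arithmetic intensity of the inner KV loop by hand. The sketch below is my own back-of-the-envelope calculation, not profiler output: it only counts the two matmuls and the K/V tile loads, and ignores Q/O traffic, softmax FLOPs, and evictions.

```cuda
#include <cstdio>

// Rough L2 arithmetic intensity of one KV-loop iteration for a single CTA,
// assuming half-precision (2-byte) K and V tiles of shape B_c x d_head.
double l2_intensity_per_kv_tile(int B_r, int B_c, int d_head) {
    double bytes = 2.0 /* tiles: K and V */ * B_c * d_head * 2.0 /* bytes per half */;
    // Two matmuls per KV tile: S = Q K^T and O += P V, each 2*B_r*B_c*d_head FLOPs.
    double flops = 2.0 * (2.0 * B_r * B_c * d_head);
    return flops / bytes;  // algebraically this simplifies to B_r FLOPs/byte
}

int main() {
    // For the (B_r, B_c) = (64, 64), d_head = 128 block size used later:
    std::printf("%.0f FLOPs/byte\n", l2_intensity_per_kv_tile(64, 64, 128));  // 64
}
```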
Finding Areas for Improvement
Since we've maximized tensor core utilization, let's examine which other operations consume significant compute cycles. Looking at the attention formula:

$$O = \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_{head}}}\right)V$$
We can categorize our operations into two types:
- Matrix multiplications: Execute on FP16/BF16 values using tensor cores
- Softmax operations: Execute in FP32 using standard floating-point units
With tensor cores saturated, improving FP32 softmax performance becomes our next optimization target.
Fusing FP Multiplication and Addition in Softmax
We can use fused multiply-add (FFMA) instructions, which compute d = a * b + c as a single instruction, to reduce the instruction count of our softmax computation. Let's count the instructions we execute per KV tile at the warp level.
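As a quick illustration of what the hardware gives us here (my own toy example, not code from the kernel): an explicit fmaf() always maps to one FFMA, while a plain a * b + c takes two instructions unless the compiler contracts it (nvcc does so by default with --fmad=true).

```cuda
// Toy example, not kernel code: counting FP instructions per expression.
__device__ float mul_then_add(float a, float b, float c) {
    return a * b + c;      // FMUL + FADD, unless nvcc contracts it into an FFMA
}

__device__ float fused(float a, float b, float c) {
    return fmaf(a, b, c);  // always a single FFMA
}
```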
Recall how each warp stores a matrix fragment across its registers.
Currently, our softmax implementation explicitly scales every element of S by the softmax scale in a separate step before subtracting the row max and exponentiating.
Note: The expressions on the right show the number of instructions for each line (adds, muls, exponentials, and max/compare operations):
which sums to a total of
The optimization from the official implementation: instead of scaling every element of S, we:
- Scale the max value (line 2'): compute `m * softmax_scale` once per row, adding a small number of instructions
- Fuse the exponent argument (line 3): compute `S * softmax_scale - m * softmax_scale` with a single FFMA before the exponential
- Update the scaling factors (line 4): adjust the scaling for the `l` and `O` calculations accordingly
* indicates the line was modified.
This optimized approach requires fewer FP instructions per KV tile overall.
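To make the restructuring concrete, here is a simplified per-element view of the exponent computation. This is my own paraphrase of the change, not the kernel's actual code; `softmax_scale` in the "after" version also absorbs the log2(e) factor discussed in the next section.

```cuda
// Paraphrased per-element view; names are illustrative.

// Before: scale_S_accum() multiplies every element of S by the softmax scale in
// its own pass, then the exponentiation pass subtracts the (scaled) row max.
__device__ float p_before(float S_scaled, float m_scaled) {
    return expf(S_scaled - m_scaled);            // FADD + exponential, after an earlier FMUL per element
}

// After: the scale is folded into the exponent argument, and m * softmax_scale
// is computed once per row. S * scale - m_scaled compiles to a single FFMA.
__device__ float p_after(float S, float m_scaled, float softmax_scale) {
    return exp2f(S * softmax_scale - m_scaled);  // FFMA + EX2
}
```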
Making the Fast Exponential Explicit
Our current code already uses the fast exponential approximation: expf() is internally implemented in terms of exp2f() using the identity $e^x = 2^{x \log_2 e}$.

Note that the instruction count stays the same; we're just making the existing fast approximation explicit.
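Putting the pieces together (my notation: $S_{ij}$ is the raw logit, $m_i$ the running row max, and $s$ the combined scale), the exponent we actually evaluate is

$$\exp\!\left(\frac{S_{ij} - m_i}{\sqrt{d_{head}}}\right) = 2^{(S_{ij} - m_i)\,\frac{\log_2 e}{\sqrt{d_{head}}}} = 2^{\,S_{ij}\cdot s \;-\; m_i\cdot s}, \qquad s = \frac{\log_2 e}{\sqrt{d_{head}}},$$

with $m_i \cdot s$ computed once per row and $S_{ij}\cdot s - m_i\cdot s$ produced by a single FFMA per element.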
Final Algorithm
The reference implementation uses a slight variation:
Numerical Precision
These changes in kernel 6 slightly alter floating-point error characteristics:
- Fusing multiplication and addition via FMA generally reduces rounding error.
- Approximating the exponential by scaling the logits by $\log_2 e$ and using exp2f() slightly increases approximation error.

Whether this is acceptable depends on the use case. If your application is especially sensitive to numerical error, you may want to investigate these trade-offs and their consequences more deeply before adopting these changes.
Summary
By fusing the scaling operations with the exponential calculations, we reduced the softmax instruction count per KV tile by roughly 11.1%.
Code
To implement this, we'll change two functions:
- `scale_l_O()`
- `exponentiate_tensor()`

We no longer call `scale_S_accum()`, since its scaling has been folded into these two functions.
Note on scaling and exp2f: in the code, `softmax_scale` is exactly $\frac{\log_2 e}{\sqrt{d_{head}}}$. The math above is written with natural exponentials for readability, but we implement exponentiation via exp2f(), using the identity $e^x = 2^{x \log_2 e}$ and folding the $\log_2 e$ factor into `softmax_scale` so that the FFMA produces the final exponent argument directly.
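For concreteness, here's how such a scale could be built on the host (a minimal sketch with made-up names, not the blog's actual launcher code):

```cuda
#include <cmath>

// log2(e); folding it into the scale lets the device code call exp2f() directly.
constexpr float kLog2e = 1.4426950408889634f;

// Hypothetical helper: combines the usual 1/sqrt(d_head) attention scaling with
// the exp2 conversion factor into the single softmax_scale the kernel receives.
inline float make_softmax_scale(int d_head) {
    return (1.0f / std::sqrt(static_cast<float>(d_head))) * kLog2e;
}
```

With that folded in, the two updated device functions look like this: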
```cuda
template <int QO_fragments, int d_head_accum_fragments, typename accum_t = float>
__forceinline__ __device__ constexpr void
scale_l_O(
    accum_t (&m_next)[QO_fragments],
    accum_t (&m_cur)[QO_fragments],
    accum_t (&l)[QO_fragments],
    accum_t (&O_accum)[QO_fragments][d_head_accum_fragments],
    accum_t softmax_scale
) {
#pragma unroll
    for (int q = 0; q < QO_fragments; ++q) {
        // Rescale the running softmax state; the max delta is multiplied by
        // softmax_scale inside the exponent, so exp2f can be used directly.
        accum_t scale = exp2f((m_cur[q] - m_next[q]) * softmax_scale);
        m_cur[q] = m_next[q];
        l[q] *= scale;
        for (int d_head = 0; d_head < d_head_accum_fragments; ++d_head) {
            O_accum[q][d_head] *= scale;
        }
    }
}
```

```cuda
template <int QO_fragments, int KV_accum_fragments, typename accum_t = float>
__forceinline__ __device__ constexpr void
exponentiate_tensor(
    accum_t (&S_accum)[QO_fragments][KV_accum_fragments],
    accum_t (&m)[QO_fragments],
    accum_t softmax_scale
) {
#pragma unroll
    for (int q = 0; q < QO_fragments; ++q) {
        // Scale the row max once instead of scaling every element of S.
        accum_t max_scaled = m[q] * softmax_scale;
#pragma unroll
        for (int k = 0; k < KV_accum_fragments; ++k) {
            // S * softmax_scale - max_scaled compiles to a single FFMA, followed by EX2.
            S_accum[q][k] = exp2f(S_accum[q][k] * softmax_scale - max_scaled);
        }
    }
}
```

Performance
Our FP instruction fusion optimization slightly improved our throughput from 67.11 TFLOPs → 67.23 TFLOPs, which is 99.9% of reference performance.
Profiling Results
By reducing the number of softmax FP instructions by 11.1%, we correspondingly reduced FP pipeline utilization by 12.41%.
This reduction in FP pipeline pressure had a secondary benefit: it slightly increased tensor core utilization from 47.11% to 47.19% (a relative increase of about 0.17%).
Kernel 7. Auto-Tuning
Up to this point, we've been optimizing our kernel with a single fixed block configuration, but there is a whole space of build-time parameters we could be choosing from.
Auto-tuning is the standard practice for systematically exploring this configuration space to find optimal build-time parameters. Our configuration space includes:
```cuda
struct FlashForwardKernelConfig {
    const int d_head;   // [128]
    const int B_r;      // [64, 128]
    const int B_c;      // [32, 64]
    const int n_warps;  // [4]
    const bool async_copy;  // always true. this was for testing purposes.
    // Kernel #2
    const bool swizzled;
    // Kernel #3
    const bool eager_load_blocks;
    // Kernel #4
    // This can be either 0 or 2.
    // If it is:
    // - 0: load the entire tile into the RF at once before executing any matmuls
    //      - additionally for Q, persist without reloading.
    // - 2: load sub-tiles 2 fragments wide at a time
    const int Q_mma_load_K_fragments;
    const int K_mma_load_K_fragments;
    const int V_mma_load_K_fragments;
    // Kernel #5
    const bool mma_double_buffer_loads;
    // Kernel #6: fusing FP multiplication and addition instructions
    const bool optimized_softmax;
};
```

Each new optimization exponentially increases our configuration space, so to make auto-tuning tractable, we'll filter out configurations we know will be suboptimal (a sketch of the resulting sweep follows the list below). This includes kernels that:
- Don't use swizzling
- Have excessive register spilling
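Here is what such a sweep might look like on the host. This is hypothetical harness code, not the post's actual auto-tuner; the field names mirror FlashForwardKernelConfig, and the sweep fixes d_head = 128, n_warps = 4, and async+eager+swizzled as noted above.

```cuda
#include <vector>

// Hypothetical candidate description mirroring FlashForwardKernelConfig.
struct Candidate {
    int B_r, B_c;
    int Q_frags, K_frags, V_frags;   // 0 (whole tile, Q persisted) or 2 (sub-tiles)
    bool double_buffer, optimized_softmax;
};

// Enumerate the space for d_head = 128, n_warps = 4, async+eager+swizzled only
// (non-swizzled configurations are filtered out up front).
std::vector<Candidate> candidates() {
    std::vector<Candidate> out;
    for (int B_r : {64, 128})
        for (int B_c : {32, 64})
            for (int Q : {0, 2}) for (int K : {0, 2}) for (int V : {0, 2})
                for (bool buf : {false, true})
                    for (bool opt : {false, true})
                        out.push_back({B_r, B_c, Q, K, V, buf, opt});
    return out;
}

// Each candidate is instantiated as build-time template parameters, compiled,
// and timed; configurations with heavy register spilling (visible in
// `nvcc -Xptxas -v` output) are dropped before benchmarking.
```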
Configuration Notation
For clarity, we'll use a standardized format to describe kernel configurations:
({d_head}, {B_r}, {B_c}, {n_warps}): {async}+{eager}+{swizzled}+load_{q_fragments}_{k_fragments}_{v_fragments}_fragments+{buffer}+{opt_softmax}
The block configuration and fragment counts are always present, while other options appear only when enabled.
Optimal Configuration Results
Auto-tuning reveals the optimal configuration
(128, 64, 64, 4): async+eager+swizzled+load_0_2_0_fragments+buffer+opt_softmax
which pushes us past the official implementation, to 101.5% of reference throughput!
Kernel 6 vs Kernel 7
The main difference between kernel 6 and kernel 7 is that kernel 7 executes fewer `ldmatrix` instructions in each iteration (Q is loaded into the register file once and persisted), which results in a decent reduction in SMEM-related warp stalls:
| Stall | Kernel 6 | Kernel 7 | Delta |
|---|---|---|---|
| barrier | 4.82% | 2.72% | -2.11% |
| mio_throttle | 2.46% | 1.95% | -0.51% |
| short_scoreboard | 1.94% | 1.70% | -0.24% |
Brief Block Size Analysis
Here are the top-performing configurations for the other block sizes and how they compare to the overall best. Note: from now on, all kernels are configured with async+eager+swizzled, so I'll drop the async+eager+swizzled+ prefix to make the tables easier to read.
| Configuration | TFLOPs (seqlen=4096) | Perf Rel. To Kernel 7 |
|---|---|---|
| (128, 64, 64, 4): load_0_2_0_fragments+buffer+opt_softmax | 68.31 | 100.0% |
| (128, 64, 32, 4): load_2_2_0_fragments | 67.39 | 98.64% |
| (128, 128, 32, 4): load_2_2_0_fragments+buffer+opt_softmax | 67.36 | 98.61% |
| (128, 128, 64, 4): load_2_2_0_fragments+opt_softmax | 54.26 | 79.42% |
The occupancy table:
| Configuration | Registers per Thread | SMEM per CTA | Warps per SM |
|---|---|---|---|
| (128, 64, 64, 4) | 229 | 48KiB | 8 |
| (128, 64, 32, 4) | 168 | 32KiB | 12 |
| (128, 128, 32, 4) | 255 (0B spilled) | 48KiB | 8 |
| (128, 128, 64, 4) | 255 (272B spilled) | 64KiB | 4 |
Key observations:
- The TFLOPs/s of (64, 32) and (128, 32) are not far behind (64, 64)
- (64, 32) can fit 3 CTAs per SM, but for our workload, 2 CTAs are enough to saturate the RTX 3090's tensor cores
- (128, 64) can only fit 1 CTA per SM and suffers from register spilling (272B per thread), resulting in poor performance (the occupancy arithmetic is sketched below)
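The occupancy numbers above follow from simple limit arithmetic. The sketch below is my own approximation: it assumes GA102's 64K 32-bit registers and up to 100 KiB of SMEM per SM, and ignores register-allocation granularity and the other occupancy limits.

```cuda
#include <algorithm>
#include <cstdio>

// Approximate CTAs per SM from the two limits that matter for these configs.
int ctas_per_sm(int regs_per_thread, int smem_per_cta_kib, int warps_per_cta = 4) {
    const int regs_per_sm = 64 * 1024;  // assumed GA102 register file size per SM
    const int smem_per_sm_kib = 100;    // assumed max SMEM carveout per SM
    int by_regs = regs_per_sm / (regs_per_thread * 32 * warps_per_cta);
    int by_smem = smem_per_sm_kib / smem_per_cta_kib;
    return std::min(by_regs, by_smem);
}

int main() {
    std::printf("(64, 64):  %d CTAs/SM\n", ctas_per_sm(229, 48));  // 2 -> 8 warps
    std::printf("(64, 32):  %d CTAs/SM\n", ctas_per_sm(168, 32));  // 3 -> 12 warps
    std::printf("(128, 32): %d CTAs/SM\n", ctas_per_sm(255, 48));  // 2 -> 8 warps
    std::printf("(128, 64): %d CTAs/SM\n", ctas_per_sm(255, 64));  // 1 -> 4 warps
}
```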
Performance on A100
Up until this point, we've been benchmarking exclusively on the RTX 3090. How does our optimized kernel perform on the A100?
| Configuration | TFLOPs (seqlen=4096) | Perf Rel. To Reference |
|---|---|---|
| (128, 64, 64, 4): load_0_2_0_fragments+opt_softmax | 149.71 | 80.31% |
| (128, 128, 32, 4): load_2_2_2_fragments+buffer | 142.82 | 76.62% |
| (128, 64, 32, 4): load_2_2_2_fragments+opt_softmax | 135.24 | 72.55% |
| (128, 128, 64, 4): load_2_2_2_fragments+opt_softmax | 130.14 | 69.81% |
| Reference | 186.41 | 100.00% |
Performance Regression on A100
We've gone from exceeding reference performance on RTX 3090 (101.5%) to achieving only ~80% on the A100.
This 20% performance gap suggests that we'll need to look at A100-specific bottlenecks to increase our performance.
Summary
FP Instruction Fusion (Kernel 6)
- Fused attention logit scaling with max subtraction in online softmax, reducing FP instructions by 11.1%
- Achieved 99.9% of reference performance on the RTX 3090
Auto-Tuning (Kernel 7)
- Systematically explored configuration space to find optimal parameters
- Best configuration: (128, 64, 64, 4) with Q persisted in the RF
- Exceeded reference performance: 101.5% on RTX 3090
Critical Finding
- Same kernel achieves only 80.3% on the A100, highlighting architecture-specific optimization needs
What's Next?
In Part 7, we'll profile kernel 7 on the A100 and find the causes for the significant performance gap between the RTX 3090 and the A100.