In the previous part, we implemented three major optimizations from the CUTLASS GEMM library: eager block loading, sub-tiling with fragment interleaving, and double buffering. These techniques helped us overlap memory transfers with computation, reaching 99.6% of reference performance on the RTX 3090.

In this part, we'll implement two final optimizations:

  • Floating point instruction fusion using a clever optimization from the official implementation (Kernel 6)
  • Auto-tuning to find optimal kernel configurations (Kernel 7)

By the end, we'll slightly outperform the official implementation, reaching 101.5% of the reference throughput.

Kernel 6. Improving FP32 Throughput

Our previous optimizations focused on tensor operations, and looking at the roofline graph, we can see that we're now hitting peak matmul FLOPs/s on the RTX 3090. To make further improvements, we need to look at other bottlenecks.

Aside: Roofline Graph

The roofline graph shows three different arithmetic intensities, with the L2 value being most relevant to our kernel:

  • L1 intensity appears artificially high:

    • We bypass the L1 cache when copying from GMEM to SMEM using cp.async with the .cg option
    • Only transfers from SMEM to GMEM (and any register spills) count toward L1 traffic
  • DRAM intensity only includes:

    • Initial data loads from main memory
    • Evicted sectors (cache lines that got pushed out and need reloading)
    • This gives us the "best case" memory bandwidth scenario
  • L2 intensity is most accurate for our analysis:

    • Captures all memory transfers since each CTA moves the same amount of data from L2
    • No L1 caching means L2 sees all our memory traffic
    • Best represents the true arithmetic intensity of our kernel

Finding Areas for Improvement

Since we've maximized tensor core utilization, let's examine what other operations consume significant compute cycles. Looking at the attention formula

O = softmax(Q Kᵀ / sqrt(d_head)) V

We can categorize our operations into two types:

  • Matrix multiplications: Execute on FP16/BF16 values using tensor cores
  • Softmax operations: Execute in FP32 using standard floating-point units

With tensor cores saturated, improving FP32 softmax performance becomes our next optimization target.

Fusing FP Multiplication and Addition in Softmax

We can use fused multiply-add (FFMA) instructions, which compute d = a * b + c in a single instruction, to reduce the number of instructions in our softmax computation. Let's consider the instructions we execute per KV tile at the warp level.
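As a minimal illustration (not taken from the kernel), the compiler contracts a multiply followed by an add into a single FFMA; writing fmaf() makes the fusion explicit:

__device__ __forceinline__ float fused_madd(float a, float b, float c) {
    // With nvcc's default -fmad=true, `a * b + c` is contracted into one FFMA.
    // fmaf() requests the fused form explicitly: one instruction, one rounding.
    return fmaf(a, b, c);
}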

Recall that each warp stores a matrix fragment distributed across a grid of threads, so the instruction counts below scale with the number of fragments each thread holds.

Currently, our softmax implementation explicitly scales the logits S by 1/sqrt(d_head), which requires dedicated multiply instructions (line 1). Let's examine the computational cost line by line.
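As a sketch, the baseline per-tile update is the standard online-softmax formulation; the numbering here is chosen to match the line references used below:

    (1)  S      = S * (1 / sqrt(d_head))                  // scale the logits
    (2)  m_next = max(m_cur, rowmax(S))                   // update the running row max
    (3)  P      = exp(S - m_next)                         // exponentiate
    (4)  alpha  = exp(m_cur - m_next);  l = alpha * l;  O_accum = alpha * O_accum

(The row sums of P are then accumulated into l and P feeds the P·V matmul; those steps are unchanged by this optimization.)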

Counting the adds, multiplies, exponentials, and max/compare operations on each of these lines gives the baseline instruction total per tile.

The optimization from the official implementation: instead of scaling S up front, we fuse the scale and the max subtraction into the exponent argument using a single FFMA (fused multiply-add). Here's how (sketched after the list):

  1. Scale the max value (line 2'): compute m_scaled = m_next * softmax_scale, adding one extra multiply per row fragment
  2. Fuse the exponent argument (line 3'): compute S * softmax_scale - m_scaled with a single FFMA per element right before the exponential
  3. Update the scaling factors (line 4'): adjust the scaling used in the l and O_accum updates to include softmax_scale
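Putting these together, with softmax_scale = 1/sqrt(d_head) at this stage, the modified lines look roughly like:

    (2') m_next   = max(m_cur, rowmax(S));  m_scaled = m_next * softmax_scale   // scale the max, not S
    (3') P        = exp(S * softmax_scale - m_scaled)                           // one FFMA per element feeds exp
    (4') alpha    = exp((m_cur - m_next) * softmax_scale);  l = alpha * l;  O_accum = alpha * O_accum

Line (1) disappears entirely: because softmax_scale > 0, taking the row max of the unscaled S and then scaling it gives the same result as scaling S first.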

A prime (') indicates the line was modified.

This optimized approach drops the standalone scaling pass over S and lowers the per-tile instruction count.

Making the Fast Exponential Explicit

Our current code already uses the fast exponential path, in which expf() is implemented in terms of exp2f() via the identity e^x = 2^(x * log2(e)). We'll make this approximation explicit in our code and embed the log2(e) factor directly into our scaling coefficient:
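A minimal sketch of the idea (the constant and helper names here are illustrative, not the kernel's actual identifiers):

// Fold log2(e) into the scale so exp() becomes exp2() with no extra multiply.
// Assumes softmax_scale_log2 = log2(e) / sqrt(d_head) and that
// m_scaled = m * softmax_scale_log2 was hoisted out of the inner loop.
__device__ __forceinline__ float softmax_exp(float s, float m_scaled,
                                             float softmax_scale_log2) {
    // exp((s - m) / sqrt(d_head)) == exp2(s * softmax_scale_log2 - m_scaled)
    return exp2f(fmaf(s, softmax_scale_log2, -m_scaled));
}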

Note that the instruction count remains the same; we're just making the existing fast approximation explicit.

Final Algorithm

The reference implementation uses a slight variation: rather than storing the already-scaled running max, it keeps m unscaled and recomputes m * softmax_scale for each tile (this is what exponentiate_tensor does below). This has marginally better performance despite the extra per-tile multiplies, so we'll adopt it as well.

Numerical Precision

These changes in kernel 6 slightly alter floating-point error characteristics:

  • Fusing multiplication and addition via FMA generally reduces rounding error.
  • Approximating the exponential by scaling the logits by log2(e) and using exp2f() slightly increases approximation error.

Whether this is acceptable depends on the use case. If your application is especially sensitive to numerical error, you may want to investigate these trade‑offs and their consequences more deeply before adopting any changes.

Summary

By fusing scaling operations with exponential calculations, we reduced the softmax instruction count per tile:

  • Before: a separate pass that scales S, followed by the max subtraction and exponential
  • After: a single FFMA per element feeding the exponential

This removes roughly 11.1% of the softmax FP instructions per warp tile, and the saving is repeated for every KV tile processed by the block.

Code

To implement this, we'll change two functions:

  • scale_l_O()
  • exponentiate_tensor()

We no longer call scale_S_accum(), since its work has been folded into these two functions.

Note on scaling and exp2f: in code, softmax_scale is exactly log2(e) / sqrt(d_head). We keep the math in natural exponentials for readability, but implement exponentiation via exp2f using the identity e^x = 2^(x * log2(e)), so the FFMA produces the final exponent argument directly.

softmax.cuh
template <int QO_fragments, int d_head_accum_fragments, typename accum_t = float>
__forceinline__ __device__ constexpr void
scale_l_O(
        accum_t (&m_next)[QO_fragments],
        accum_t (&m_cur)[QO_fragments],
        accum_t (&l)[QO_fragments],
        accum_t (&O_accum)[QO_fragments][d_head_accum_fragments],
        accum_t softmax_scale
) {
    #pragma unroll
    for (int q = 0; q < QO_fragments; ++q) {
        // softmax_scale already contains the log2(e) factor, so exp2f here
        // computes exp((m_cur - m_next) / sqrt(d_head)).
        accum_t scale = exp2f((m_cur[q] - m_next[q]) * softmax_scale);

        m_cur[q] = m_next[q];
        // Rescale the running row sum and the output accumulator to the new max.
        l[q] *= scale;
        #pragma unroll
        for (int d_head = 0; d_head < d_head_accum_fragments; ++d_head) {
            O_accum[q][d_head] *= scale;
        }
    }
}
softmax.cuh
template <int QO_fragments, int KV_accum_fragments,
          typename accum_t = float>
__forceinline__ __device__ constexpr void
exponentiate_tensor(
        accum_t (&S_accum)[QO_fragments][KV_accum_fragments],
        accum_t (&m)[QO_fragments],
        accum_t softmax_scale
) {
    #pragma unroll
    for (int q = 0; q < QO_fragments; ++q) {
        // Scale the row max once per row instead of once per element.
        accum_t max_scaled = m[q] * softmax_scale;
        #pragma unroll
        for (int k = 0; k < KV_accum_fragments; ++k) {
            // A single FFMA forms the exponent argument; EX2 does the rest.
            S_accum[q][k] = exp2f(S_accum[q][k] * softmax_scale - max_scaled);
        }
    }
}
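For context, here is a hypothetical wrapper (not part of the actual kernel) showing how the two functions fit together for each KV tile; only the two calls are real, the wrapper itself is illustrative:

template <int QO_frag, int KV_frag, int d_head_frag>
__forceinline__ __device__ void
softmax_step(float (&S_accum)[QO_frag][KV_frag],
             float (&m_cur)[QO_frag], float (&m_next)[QO_frag],
             float (&l)[QO_frag], float (&O_accum)[QO_frag][d_head_frag],
             float softmax_scale /* log2(e) / sqrt(d_head) */) {
    // m_next is assumed to already hold the row maxima of the current S tile.
    // Rescale l and O_accum for the new maxima and advance m_cur.
    scale_l_O<QO_frag, d_head_frag>(m_next, m_cur, l, O_accum, softmax_scale);
    // Turn the logits into unnormalized probabilities: one FFMA + EX2 per element.
    exponentiate_tensor<QO_frag, KV_frag>(S_accum, m_cur, softmax_scale);
    // The row sums of S_accum are then accumulated into l, and S_accum feeds
    // the P·V matmul, exactly as in the previous kernels.
}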

Performance

Our FP instruction fusion optimization slightly improved our throughput from 67.11 TFLOPs → 67.23 TFLOPs, which is 99.9% of reference performance.

Profiling Results

By reducing the number of softmax FP instructions by 11.1%, we correspondingly reduced FP pipeline utilization by 12.41%.

This reduction in FP pipeline pressure had a secondary benefit: tensor core utilization ticked up slightly from 47.11% to 47.19%, a relative gain of 0.17%.

Kernel 7. Auto-Tuning

Up to this point, we've been optimizing our kernel with a single fixed block configuration. However, each of the improvements we've implemented is a configurable component that can be toggled independently.

Auto-tuning is the standard practice for systematically exploring this configuration space to find optimal build-time parameters. Our configuration space includes:

flash_attention.cuh
struct FlashForwardKernelConfig {
	const int d_head; // [128]
	const int B_r; // [64, 128]
	const int B_c; // [32, 64]
	const int n_warps; // [4]
 
	const bool async_copy; // always true. this was for testing purposes.
 
	// Kernel #2
	const bool swizzled;
 
	// Kernel #3
	const bool eager_load_blocks;
 
	// Kernel #4
	// This can be either 0 or 2.
	// If it is:
	// - 0: load the entire tile into the RF at once before executing any matmuls
	//   - additionally for Q, persist without reloading.
	// - 2: load sub-tiles 2 fragments wide at a time
	const int Q_mma_load_K_fragments;
	const int K_mma_load_K_fragments;
	const int V_mma_load_K_fragments;
 
	// Kernel #5
	const bool mma_double_buffer_loads;
	
	// Kernel #6: fusing FP multiplication and addition instructions
	const bool optimized_softmax;
};

Each new optimization adds another axis to the configuration space, growing it exponentially, so to make auto-tuning tractable we'll filter out configurations we know will be suboptimal (sketched after this list). This includes kernels that:

  • Don't use swizzling
  • Have excessive register spilling
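A hypothetical host-side sketch of the pruned sweep (the real tuner instantiates one kernel per configuration at build time; everything here apart from the parameter ranges is illustrative):

#include <cstdio>
#include <vector>

struct TuneConfig {
    int B_r, B_c;
    bool swizzled, double_buffer, opt_softmax;
    int Q_frag;   // 0 = persist Q in the RF, 2 = reload 2-fragment sub-tiles
};

int main() {
    std::vector<TuneConfig> candidates;
    for (int B_r : {64, 128})
        for (int B_c : {32, 64})
            for (bool swizzled : {false, true})
                for (bool double_buffer : {false, true})
                    for (bool opt_softmax : {false, true})
                        for (int Q_frag : {0, 2}) {
                            if (!swizzled) continue;  // non-swizzled layouts are never competitive
                            candidates.push_back({B_r, B_c, swizzled,
                                                  double_buffer, opt_softmax, Q_frag});
                        }
    // Each survivor is compiled, rejected if ptxas reports heavy register
    // spilling, then benchmarked and ranked by TFLOPs.
    std::printf("%zu candidate configurations\n", candidates.size());
    return 0;
}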

Configuration Notation

For clarity, we'll use a standardized format to describe kernel configurations:

({d_head}, {B_r}, {B_c}, {n_warps}): {async}+{eager}+{swizzled}+load_{q_fragments}_{k_fragments}_{v_fragments}_fragments+{buffer}+{opt_softmax}

The block configuration and fragment counts are always present, while other options appear only when enabled.

Optimal Configuration Results

Auto-tuning reveals the optimal configuration

(128, 64, 64, 4): async+eager+swizzled+load_0_2_0_fragments+buffer+opt_softmax

which pushes us over the performance of the official implementation to 101.5%!

Kernel 6 vs Kernel 7

The main difference between kernel 6 and kernel 7 is that in kernel 7, Q persists in the RF throughout the mainloop, so each warp loads Q from SMEM → RF only once. This reduces the number of ldmatrix instructions executed in each iteration and yields a decent reduction in SMEM-related warp stalls:

| Stall            | Kernel 6 | Kernel 7 | Delta  |
|------------------|----------|----------|--------|
| barrier          | 4.82%    | 2.72%    | -2.11% |
| mio_throttle     | 2.46%    | 1.95%    | -0.51% |
| short_scoreboard | 1.94%    | 1.70%    | -0.24% |

Brief Block Size Analysis

Here are the top performing configurations for other block sizes and how they compare to the top performing one. Note: From now on, all kernels will be configured with async+eager+swizzled, so I'll remove the async+eager+swizzled+ prefix to make it easier to read.

| Configuration | TFLOPs (seqlen=4096) | Perf Rel. to Kernel 7 |
|---|---|---|
| (128, 64, 64, 4): load_0_2_0_fragments+buffer+opt_softmax | 68.31 | 100.0% |
| (128, 64, 32, 4): load_2_2_0_fragments | 67.39 | 98.64% |
| (128, 128, 32, 4): load_2_2_0_fragments+buffer+opt_softmax | 67.36 | 98.61% |
| (128, 128, 64, 4): load_2_2_0_fragments+opt_softmax | 54.26 | 79.42% |

The occupancy numbers:

| Configuration | Registers per Thread | SMEM per CTA | Warps per SM |
|---|---|---|---|
| (128, 64, 64, 4) | 229 | 48KiB | 8 |
| (128, 64, 32, 4) | 168 | 32KiB | 12 |
| (128, 128, 32, 4) | 255 (0B spilled) | 48KiB | 8 |
| (128, 128, 64, 4) | 255 (272B spilled) | 64KiB | 4 |

Key observations (a rough occupancy check follows this list):

  • the TFLOPs of (64, 32) and (128, 32) are not far behind (64, 64)
  • (64, 32) can fit 3 CTAs per SM, but for our workload 2 CTAs are already enough to saturate the RTX 3090's tensor cores
  • (128, 64) can only fit 1 CTA per SM and suffers from register spilling (272B per thread), resulting in poor performance
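As a rough sanity check on these numbers (assuming 65,536 32-bit registers and up to 100 KiB of SMEM per SM on the RTX 3090's GA102, and ignoring register allocation granularity):

  • (64, 64): 229 regs × 32 threads × 8 warps ≈ 58.6K registers fits, but 12 warps would need ≈ 87.9K > 65,536, so we cap at 8 warps = 2 CTAs; SMEM: 2 × 48 KiB = 96 KiB ≤ 100 KiB
  • (64, 32): 168 regs × 32 threads × 12 warps = 64,512 registers just fits (16 warps would need 86,016), giving 12 warps = 3 CTAs; SMEM: 3 × 32 KiB = 96 KiB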

Performance on A100

Up until this point, we've been benchmarking exclusively on the RTX 3090. How does our optimized kernel perform on the A100?

| Configuration | TFLOPs (seqlen=4096) | Perf Rel. to Reference |
|---|---|---|
| (128, 64, 64, 4): load_0_2_0_fragments+opt_softmax | 149.71 | 80.31% |
| (128, 128, 32, 4): load_2_2_2_fragments+buffer | 142.82 | 76.62% |
| (128, 64, 32, 4): load_2_2_2_fragments+opt_softmax | 135.24 | 72.55% |
| (128, 128, 64, 4): load_2_2_2_fragments+opt_softmax | 130.14 | 69.81% |
| Reference | 186.41 | 100.00% |

Performance Regression on A100

We've gone from exceeding reference performance on the RTX 3090 (101.5%) to reaching only ~80% on the A100.

This roughly 20-point gap suggests we'll need to look at A100-specific bottlenecks to close it.

Summary

FP Instruction Fusion (Kernel 6)

  • Fused the attention logit scaling with the max subtraction in the online softmax, reducing softmax FP instructions by 11.1%
  • Achieved 99.9% of reference performance on the RTX 3090

Auto-Tuning (Kernel 7)

  • Systematically explored configuration space to find optimal parameters
  • Best configuration: (128, 64, 64, 4) with Q persisted in the RF
  • Exceeded reference performance: 101.5% on RTX 3090

Critical Finding

  • The same kernel achieves only 80.3% of reference on the A100, highlighting architecture-specific optimization needs

What's Next?

In Part 7, we'll profile kernel 7 on the A100 and find the causes for the significant performance gap between the RTX 3090 and the A100.