CUDA/GPU Programming Terms
- GMEM: global memory, stored in DRAM
- SMEM: shared memory, stored in SRAM
- LMEM: local memory
- RF: register file / register memory
- CTA: cooperative thread array; same as thread-block. I'll use CTA to avoid confusion with matrix blocks
- tiling: dividing a matrix into tiles, each processed by a CTA, warp, or thread. The shape of a tile should be inferable from the context.
- fragment: a tile of a matrix stored in registers. In this series, this specifically refers to an 8x8 tile in RF. Within a warp, each thread holds 2 of its 64 values. This requires a single 32-bit register per thread for 16-bit values, and two registers for 32-bit values.
- lane_id: thread index within a warp (`tid % 32`)
- LD/ST: load/store operations
- mnk variables: standard naming convention for GEMM dimensions. For `D = AB^T + C`:
  - A is `(m, k)`
  - B is `(n, k)`
  - C and D are `(m, n)`
  - In our context, `k` corresponds to the head dimension in `S = QK^T`, and to the key/value tile length in `O = PV`.
- Arithmetic Intensity: the number of floating-point operations performed divided by the number of bytes loaded. This can differ across levels of the memory hierarchy; for instance, the number of bytes loaded from the L1 cache can differ from the number loaded from L2.
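To make the last definition concrete, here is a minimal Python sketch (the function name and the "ideal traffic" assumption are mine) that computes the arithmetic intensity of a GEMM at the GMEM level, assuming each operand crosses the DRAM boundary exactly once:

```python
def gemm_arithmetic_intensity(m, n, k, bytes_per_elem=2):
    """Arithmetic intensity of D = A @ B.T + C at the GMEM level,
    counting only ideal traffic (each operand moved exactly once)."""
    flops = 2 * m * n * k                    # one multiply + one add per (m, n, k) triple
    bytes_moved = bytes_per_elem * (m * k    # load A
                                    + n * k  # load B
                                    + m * n  # load C
                                    + m * n) # store D
    return flops / bytes_moved

# For a square fp16 GEMM, intensity grows linearly with the problem size:
print(gemm_arithmetic_intensity(4096, 4096, 4096))  # 1024.0
```

Note the ratio simplifies to `mnk / (mk + nk + 2mn)`, which is why larger tiles kept in SMEM/RF raise intensity and shift a kernel from memory-bound toward compute-bound.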
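The fragment definition above corresponds to a concrete ownership pattern. The sketch below (mine; a hedged description of the row-major accumulator layout used by PTX `mma` instructions, so double-check against the PTX ISA for your exact instruction shape) maps each element of an 8x8 fragment to the lane that holds it — each lane owns two adjacent elements of one row:

```python
def fragment_owner(r, c):
    """Owner of element (r, c) in an 8x8 fragment, assuming the
    row-major mma accumulator layout: lane = r * 4 + c // 2,
    with each lane holding two adjacent column values of one row."""
    lane_id = r * 4 + c // 2   # 32 lanes x 2 values covers all 64 elements
    value_idx = c % 2          # which of the lane's two values
    return lane_id, value_idx

# Example: lane 0 holds (0, 0) and (0, 1); lane 31 holds (7, 6) and (7, 7).
print(fragment_owner(7, 7))  # (31, 1)
```

The two 16-bit values per lane pack into one 32-bit register, matching the register count given in the fragment definition.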
Flash Attention Terms
Following the notation from the paper, with a few simplifications:
- `Q`, `O`: query and output tiles handled by the current CTA
- `K_j`, `V_j`: the `j`-th key/value tile
- `m`: row-wise max of attention scores
- `l`: row-wise sum of exponentiated attention scores
- `B_q`: number of query rows in the `Q` block
- `B_kv`: number of key and value rows in the `K` and `V` blocks
- `d`: head dimension
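As a quick illustration of how the row-wise max and sum are maintained, here is a pure-Python sketch (mine, not the paper's algorithm verbatim) of the streaming-softmax update for a single query row, consuming key tiles one at a time:

```python
import math

def online_softmax_row(score_tiles):
    """Streaming softmax statistics for one query row, Flash Attention style.
    m: running max of scores seen so far; l: running sum of exp(score - m)."""
    m, l = -math.inf, 0.0
    for tile in score_tiles:           # one list of B_kv scores per K/V tile
        m_new = max(m, max(tile))
        # rescale the old sum to the new max, then add this tile's contribution
        l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in tile)
        m = m_new
    return m, l

# Processing [0.1, 2.0] then [-1.0, 0.5] gives the same (m, l) as one pass
# over all four scores, which is what lets the kernel avoid materializing
# the full attention matrix:
m, l = online_softmax_row([[0.1, 2.0], [-1.0, 0.5]])
```

The final probabilities are then `exp(s - m) / l` for each score `s`, identical to a one-shot softmax over the concatenated tiles.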