CUDA/GPU Programming Terms

  • GMEM: global memory, stored in DRAM
  • SMEM: shared memory, stored in SRAM
  • LMEM: local memory
  • RF: register file / register memory
  • CTA: cooperative thread array; same as thread-block. I'll use CTA to avoid confusion with matrix blocks
  • tiling: dividing a matrix into tiles, each processed by a CTA, warp, or thread. The shape of the tile should be inferable from the context.
  • fragment: a tile of a matrix stored in registers. In this series, this specifically refers to a tile held in RF
    • Within a warp, each thread holds 2 values per 8×8 sub-tile
    • This requires a single 32-bit register for 16-bit values, and two registers for 32-bit values
  • lane_id: thread index within a warp (tid % 32)
  • LD/ST: load/store operations
  • mnk variables: standard naming convention for GEMM dimensions:
    • For D = AB^T + C:
      • A is (m, k)
      • B is (n, k)
      • C and D are (m, n)
    • In our context, k corresponds to:
      • d in S = QK^T
      • B_c in O = PV
  • Arithmetic Intensity: # fp operations performed divided by # bytes loaded. This can differ at each level of the memory hierarchy; for instance, the # bytes loaded from the L1 cache can differ from the # loaded from the L2 cache.
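To make the mnk convention and arithmetic intensity concrete, here is a minimal sketch (not part of the series' kernels; the fp16 byte count and the "each operand moved exactly once" assumption are simplifications for illustration):

```python
def gemm_ref(A, B, C):
    """Reference D = A @ B^T + C with mnk naming:
    A is (m, k), B is (n, k), C and D are (m, n)."""
    m, k = len(A), len(A[0])
    n = len(B)
    return [[sum(A[i][p] * B[j][p] for p in range(k)) + C[i][j]
             for j in range(n)] for i in range(m)]

def arithmetic_intensity(m, n, k, bytes_per_elem=2):
    """Ideal GMEM arithmetic intensity for the GEMM above:
    A, B, C each read once, D written once (perfect caching assumed;
    bytes_per_elem=2 corresponds to fp16)."""
    flops = 2 * m * n * k  # one multiply + one add per inner-product term
    bytes_moved = bytes_per_elem * (m * k + n * k + 2 * m * n)
    return flops / bytes_moved
```

Note that arithmetic intensity grows with the tile sizes, which is the core motivation for tiling: larger tiles amortize each byte loaded over more fp operations.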

Flash Attention Terms

Following the notation from the paper, with a few simplifications:

  • Q, O: Query and output tensors handled by the current CTA
  • K_j, V_j: the j-th key/value tile
  • m: row-wise max of attention scores
  • ℓ: row-wise sum of exponentiated attention scores
  • B_r: number of query rows in the Q block
  • B_c: number of key and value rows in the K and V blocks
  • d: head dimension
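A minimal sketch of how the row statistics m and ℓ are maintained across key tiles (pure Python, one score row at a time; the tile width B_c = 2 here is an arbitrary choice for illustration):

```python
import math

def online_softmax_stats(scores, B_c=2):
    """Process one row of attention scores tile-by-tile, maintaining
    the running row max m and the running sum l of exp(score - m)."""
    m, l = float("-inf"), 0.0
    for start in range(0, len(scores), B_c):
        tile = scores[start:start + B_c]
        m_new = max(m, max(tile))
        # rescale the old sum to the new max, then add this tile's terms
        l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in tile)
        m = m_new
    return m, l
```

The final (m, ℓ) match what a single pass over the full row would produce, which is what lets the output be normalized by ℓ once at the end instead of re-reading earlier tiles.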