Skip to content

[docs/examples] Blackwell cute tutorials: narrow TMEM_LOAD atoms (32dp32b1x) carry a large per-load lowering cost — prefer wider atoms (32dp32b32x) in t2r epilogues#3313

Open
cfregly wants to merge 1 commit into
NVIDIA:mainfrom
cfregly:docs/blackwell-tutorial-tmem-load-atom-width

Conversation

@cfregly

@cfregly cfregly commented Jun 11, 2026

Copy link
Copy Markdown

What

All five Blackwell CuTe tutorials demonstrate the TMEM->register epilogue
with the narrowest TMEM_LOAD copy atom:

// examples/cute/tutorial/blackwell/01_mma_sm100.cu (and 02..05 likewise)
TiledCopy tiled_t2r_copy = make_tmem_copy(SM100_TMEM_LOAD_32dp32b1x{}, tCtAcc);

On sm_103 (CUDA 13.2 ptxas; also observed on the 13.x line generally) every
tcgen05.ld.sync.aligned.32x32b.x1.b32 this atom emits is lowered to a
per-load convergence-helper subroutine call in SASS:

LEPC R20, 0x460 ;
CALL.ABS.NOINC R2 ;
WARPSYNC.ALL ;
...
LDTM R2, tmem[UR4] ;

For a 128x256 fp32 accumulator that is 256 loads per thread = 256 helper
calls per warp (~60 cycles each), which can dominate a small kernel's fixed
cost. In a warp-specialized 2048^3 fp16 GEMM kernel built from these
tutorials we measured the t2r phase at 7.65-7.81 us of a 23.8-24.2 us
kernel
(in-kernel %globaltimer stamps, per-CTA medians, 128 CTAs).
Switching the one line to SM100_TMEM_LOAD_32dp32b32x (8 loads/thread)
cut t2r to 0.42 us and the whole kernel from 23.8 -> 16.0 us (1.49x),
bit-identically (torch.equal on the outputs — the atom only changes how
many columns one instruction moves; each thread keeps the same
(row, all-columns) fragment, so per-element conversion and store mapping
are untouched). Width sweep at the same shape: x16 ties x32 (0.51 vs
0.42 us); x128 REGRESSES (0.99 us t2r + slower writeback — the 128-output
asm serializes register writeback), so widest is not best; x32 was the
sweet spot measured.

A tutorial that demonstrates the 1x atom without comment teaches a ~1.5x
performance bug as the canonical epilogue.

Proposed change

Minimal, docs-only (no behavior change to any library kernel):

  1. In tutorials 01-05, either switch the epilogue atom to
    SM100_TMEM_LOAD_32dp32b32x (and refresh the affected print
    annotations), or keep 32dp32b1x for pedagogical simplicity and add a
    short comment block, e.g.:

    // PERF NOTE: 32dp32b1x is the simplest TMEM_LOAD atom but issues one
    // tcgen05.ld per accumulator column; ptxas (CUDA 13.x) lowers each load
    // through a per-load convergence-helper call, so the t2r phase can
    // dominate small/medium kernels. Prefer a wider atom (e.g.
    // SM100_TMEM_LOAD_32dp32b32x: 32 columns per instruction) in real
    // epilogues; on sm_103 we measured 1.49x on the whole kernel from this
    // one line. Very wide atoms (x128) can regress on register-writeback
    // serialization — sweep the width for your tile shape.
  2. Optionally, a sentence in
    media/docs/cpp/cute/0y_tmem_tensor.md (or the blackwell functionality
    doc) noting atom width as a first-class performance knob.

Standalone evidence

A ~120-line reproducer (attached: tmem_load_atom_repro.cu) containing
only the TMEM alloc + make_tmem_copy t2r + store (no mainloop, no MMA),
built with nvcc -std=c++20 -arch=sm_103a against current CUTLASS headers:

32dp32b1x 32dp32b32x
LDTM in SASS 256 8
CALL.ABS.NOINC / LEPC 256 / 256 8 / 8
kernel time (GB300, 1 CTA, 2000 launches) 10.91 us 5.86 us

cuobjdump -sass one-liners to see it:

cuobjdump -sass repro_1x  | grep -c "CALL.ABS.NOINC"   # 256
cuobjdump -sass repro_32x | grep -c "CALL.ABS.NOINC"   # 8
cuobjdump -sass repro_1x  | grep -B7 "LDTM" | head     # LEPC/CALL/WARPSYNC wrap

Environment

  • NVIDIA GB300 (sm_103), driver 580.159.03
  • CUDA 13.2 (nvcc V13.2.78)
  • CUTLASS: reproduced against both 4.2.0 headers and current main
  • Kernel-level numbers: fp16 2048^3, 128x256 tile, 4-stage warp-specialized
    pipeline; CUDA-graph interleaved A/B, 7 reps/arm, zero distribution
    overlap; ncu cross-check (sm__pipe_tensor_cycles_active 20.4% -> 37.9%
    of elapsed)
tmem_load_atom_repro.cu
// Minimal standalone reproducer: ptxas (CUDA 13.2) lowers EVERY
// SM100_TMEM_LOAD_32dp32b1x tcgen05.ld into a per-load
// LEPC + CALL.ABS.NOINC + WARPSYNC convergence-helper subroutine call.
// For a 128x256 fp32 accumulator that is 256 loads/thread = 256 helper
// calls per warp in the unrolled SASS, which dominates a small kernel's
// fixed cost. Widening the atom to SM100_TMEM_LOAD_32dp32b32x (8
// loads/thread) removes them.
//
// This file contains ONLY the TMEM allocate + tmem->register copy +
// register->gmem store epilogue (no mainloop, no MMA), so the SASS of the
// copy is easy to inspect. The TiledMMA is used solely to derive the
// canonical 128x256 TMEM accumulator tensor that make_tmem_copy partitions
// — the same construction the CuTe Blackwell tutorials use.
//
// Build (CUTLASS >= 4.2 headers, CUDA >= 13.x, sm_103a or sm_100a):
//   nvcc -std=c++20 -arch=sm_103a -I$CUTLASS_DIR/include \
//        -DT2R_ATOM=SM100_TMEM_LOAD_32dp32b1x  -o repro_1x  tmem_load_atom_repro.cu
//   nvcc -std=c++20 -arch=sm_103a -I$CUTLASS_DIR/include \
//        -DT2R_ATOM=SM100_TMEM_LOAD_32dp32b32x -o repro_32x tmem_load_atom_repro.cu
//
// SASS evidence:
//   cuobjdump -sass repro_1x  | grep -c "CALL.ABS.NOINC"   # ~256 (one per LDTM)
//   cuobjdump -sass repro_32x | grep -c "CALL.ABS.NOINC"   # ~8
//   cuobjdump -sass repro_1x  | grep -B2 -A2 "LDTM" | head # LEPC/CALL/WARPSYNC wrap
//
// Optional run (any SM100-family GPU): ./repro_1x ; ./repro_32x
// prints CUDA-event kernel time; the 1x build is several times slower for
// the identical (bit-identical destination mapping) data movement.

#include <cuda_runtime.h>

#include <cstdio>
#include <cstdlib>

#include <cutlass/half.h>

#include <cute/tensor.hpp>
#include <cute/arch/mma_sm100_umma.hpp>
#include <cute/arch/tmem_allocator_sm100.hpp>
#include <cute/atom/mma_traits_sm100.hpp>

using namespace cute;

#ifndef T2R_ATOM
#define T2R_ATOM SM100_TMEM_LOAD_32dp32b1x
#endif

// Stringify the atom name for the report.
#define XSTR(s) STR(s)
#define STR(s) #s

using TypeAB = cutlass::half_t;
using Accum = float;

constexpr int kTileM = 128;
constexpr int kTileN = 256;

__global__ void tmem_epilogue_kernel(float* __restrict__ out) {
  // 1-SM 128x256 UMMA atom: used only to derive the canonical TMEM
  // accumulator layout (tmem_ptr tensor) for make_tmem_copy.
  TiledMMA tiled_mma = make_tiled_mma(
      SM100_MMA_F16BF16_SS<TypeAB, TypeAB, Accum, kTileM, kTileN,
                           UMMA::Major::K, UMMA::Major::K>{});
  auto cta_mma = tiled_mma.get_slice(0);

  Tensor mD = make_tensor(make_gmem_ptr(out),
                          make_layout(make_shape(Int<kTileM>{}, Int<kTileN>{}),
                                      make_stride(Int<kTileN>{}, _1{})));
  Tensor tCgD = cta_mma.partition_C(mD);
  Tensor tCtAcc = cta_mma.make_fragment_C(tCgD);  // TMEM-resident accumulator

  __shared__ uint32_t tmem_base_ptr;
  cute::TMEM::Allocator1Sm tmem_allocator{};
  if (threadIdx.x / 32 == 0) {
    tmem_allocator.allocate(
        cute::TMEM::Allocator1Sm::Sm100TmemCapacityColumns, &tmem_base_ptr);
  }
  __syncthreads();
  tCtAcc.data() = tmem_base_ptr;

  // ---- The one line under test: the TMEM_LOAD copy-atom width. ----------
  auto tiled_t2r_copy = make_tmem_copy(T2R_ATOM{}, tCtAcc);
  // -----------------------------------------------------------------------
  auto thr_t2r_copy = tiled_t2r_copy.get_slice(threadIdx.x);

  Tensor tDtAcc = thr_t2r_copy.partition_S(tCtAcc);
  Tensor tDgD = thr_t2r_copy.partition_D(tCgD);
  Tensor tDrAcc = make_tensor<Accum>(shape(tDgD));

  copy(tiled_t2r_copy, tDtAcc, tDrAcc);  // TMEM -> registers (tcgen05.ld)
  copy(tDrAcc, tDgD);                    // registers -> gmem (keeps it live)

  __syncthreads();
  if (threadIdx.x / 32 == 0) {
    tmem_allocator.release_allocation_lock();
    tmem_allocator.free(tmem_base_ptr,
                        cute::TMEM::Allocator1Sm::Sm100TmemCapacityColumns);
  }
}

int main() {
  float* d_out = nullptr;
  if (cudaMalloc(&d_out, sizeof(float) * kTileM * kTileN) != cudaSuccess) {
    std::printf("cudaMalloc failed (no SM100-family GPU?)\n");
    return 1;
  }

  // One warmup + error check.
  tmem_epilogue_kernel<<<1, 128>>>(d_out);
  cudaError_t err = cudaDeviceSynchronize();
  if (err != cudaSuccess) {
    std::printf("kernel failed: %s\n", cudaGetErrorString(err));
    return 1;
  }

  constexpr int kIters = 2000;
  cudaEvent_t beg, end;
  cudaEventCreate(&beg);
  cudaEventCreate(&end);
  cudaEventRecord(beg);
  for (int i = 0; i < kIters; ++i) {
    tmem_epilogue_kernel<<<1, 128>>>(d_out);
  }
  cudaEventRecord(end);
  cudaEventSynchronize(end);
  float ms = 0.0f;
  cudaEventElapsedTime(&ms, beg, end);

  std::printf("atom=%s  kernel time: %.3f us/launch (%d launches)\n",
              XSTR(T2R_ATOM), 1000.0f * ms / kIters, kIters);
  cudaFree(d_out);
  return 0;
}
sass_evidence.txt
# DRAFT — internal review pending
# SASS evidence for tmem_load_atom_repro.cu (generated + verified 2026-06-11)
# Environment: NVIDIA GB300 (sm_103), CUDA 13.2 (nvcc V13.2.78, ptxas 13.2),
#   driver 580.159.03. Compiled against CUTLASS main headers; identical
#   counts reproduced against CUTLASS 4.2.0 headers.
# Build:
#   nvcc -std=c++20 -arch=sm_103a -I$CUTLASS_DIR/include \
#        -DT2R_ATOM=SM100_TMEM_LOAD_32dp32b1x  -o repro_1x  tmem_load_atom_repro.cu
#   nvcc -std=c++20 -arch=sm_103a -I$CUTLASS_DIR/include \
#        -DT2R_ATOM=SM100_TMEM_LOAD_32dp32b32x -o repro_32x tmem_load_atom_repro.cu
#   cuobjdump -sass repro_1x > repro_1x.sass   (and likewise repro_32x)
# Runtime sanity (same GPU, 2000 launches, CUDA events; epilogue-only kernel):
#   atom=SM100_TMEM_LOAD_32dp32b1x   kernel time: 10.906 us/launch
#   atom=SM100_TMEM_LOAD_32dp32b32x  kernel time:  5.861 us/launch

## instruction counts (cuobjdump -sass | grep -c), kernel tmem_epilogue_kernel
[repro_1x]
  LDTM:            256
  CALL.ABS.NOINC:  256
  LEPC:            256
  WARPSYNC:        259
[repro_32x]
  LDTM:            8
  CALL.ABS.NOINC:  8
  LEPC:            8
  WARPSYNC:        11

## repro_1x.sass: one of the 256 per-load helper-call sites (LEPC + CALL.ABS.NOINC + WARPSYNC wrapping each LDTM)
        /*0400*/                   IMAD.U32 R6, RZ, RZ, UR6 ;
        /*0410*/                   IMAD.U32 R7, RZ, RZ, UR7 ;
        /*0420*/                   IMAD.U32 R10, RZ, RZ, UR8 ;
        /*0430*/                   MOV R11, UR9 ;
        /*0440*/                   LEPC R20, 0x460 ;
        /*0450*/                   CALL.ABS.NOINC R2 ;
        /*0460*/                   WARPSYNC.ALL ;
        /*0470*/                   R2UR UR4, R22.reuse ;
        /*0480*/                   VIADD R0, R22, 0x1 ;
        /*0490*/                   SHF.R.U32.HI R0, RZ, 0x15, R0 ;
        /*04a0*/                   LOP3.LUT P0, RZ, R0, 0x3, R23, 0x48, !PT ;
        /*04b0*/                   LDTM R2, tmem[UR4] ;
        /*04c0*/                   STL [R1+0x48], R2 ;
        /*04d0*/              @!P0 BRA 0x5e0 ;
        /*04e0*/                   MOV R0, 0x0 ;

## repro_32x.sass: the same site at x32 width (8 total LDTM.x32 in the kernel)
        /*0400*/                   IMAD.U32 R6, RZ, RZ, UR6 ;
        /*0410*/                   IMAD.U32 R7, RZ, RZ, UR7 ;
        /*0420*/                   IMAD.U32 R10, RZ, RZ, UR8 ;
        /*0430*/                   MOV R11, UR9 ;
        /*0440*/                   LEPC R20, 0x460 ;
        /*0450*/                   CALL.ABS.NOINC R2 ;
        /*0460*/                   WARPSYNC.ALL ;
        /*0470*/                   R2UR UR4, R16.reuse ;
        /*0480*/                   VIADD R0, R16, 0x20 ;
        /*0490*/                   SHF.R.U32.HI R0, RZ, 0x15, R0 ;
        /*04a0*/                   LOP3.LUT P0, RZ, R0, 0x3, R17, 0x48, !PT ;
        /*04b0*/                   LDTM.x32 R20, tmem[UR4] ;
        /*04c0*/                   STL.64 [R1], R20 ;
        /*04d0*/                   STL.64 [R1+0x8], R22 ;
        /*04e0*/                   STL.64 [R1+0x10], R24 ;

…LOAD atoms; prefer wider atoms in t2r epilogues

ptxas (CUDA 13.x) lowers each tcgen05.ld of the 32dp32b1x atom through a
per-load LEPC + CALL.ABS.NOINC + WARPSYNC convergence-helper call: 256
loads/thread for a 128x256 fp32 accumulator = 256 helper calls per warp,
which can dominate a small kernel's fixed cost. Switching one line to
SM100_TMEM_LOAD_32dp32b32x (8 loads/thread) measured 1.49x on a full GEMM
kernel on GB300 (sm_103), bit-identical output; x128 regresses (serialized
register writeback). Comment-only change to the five Blackwell tutorials.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant