Skip to content

Add NVFP4 per-token quantization recipe#3045

Open
cael-ling wants to merge 17 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-per-token-recipe
Open

Add NVFP4 per-token quantization recipe#3045
cael-ling wants to merge 17 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-per-token-recipe

Conversation

@cael-ling

@cael-ling cael-ling commented May 26, 2026

Copy link
Copy Markdown
Contributor

Description

This PR adds an NVFP4 per-token quantization recipe for model pre-training. The default NVFP4BlockScaling recipe computes a single per-tensor outer amax (s_global) per tensor. The per-token variant instead computes a per-row outer amax (length M) for rowwise data and a per-col outer amax (length K) for columnwise data, giving each token/row its own global scale.

Changes

  • Per-token cast kernels: vector-amax + encode/swizzle producing NVFP4 tensors whose _amax_rowwise / _amax_columnwise are per-row/per-col vectors.
  • CUTLASS GEMM (nvfp4_cutlass_per_token_gemm) that rescales with the per-row/per-col outer-amax vectors inside the epilogue;
  • Forward + backward coverage (dgrad NN / wgrad NT layouts).
  • NVFP4PerTokenBlockScaling recipe (re-exported from transformer_engine.pytorch.fp8), plus an equivalent NVTE_NVFP4_PER_TOKEN=1 env-var switch on a plain NVFP4BlockScaling so frameworks that only build a default recipe (e.g. Megatron-Core) can opt in with no code change.
  • Opt-in RHT / SR (per_token_rht / per_token_sr) — off by default on the per-token path.
  • Opt-in 2D weight quantization (per_token_weight_2d): transposition-invariant 16×16 cast emitted in per-token layout.
  • Docs: API reference entry, NVTE_NVFP4_* env-var docs, and a "Per-token NVFP4" feature section with Megatron-Core launch instructions.
  • Example: examples/pytorch/nvfp4_per_token_megatron — single-GPU MoE example comparing per-token vs per-tensor vs BF16 with identical model/data/seed.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 26, 2026
@cael-ling cael-ling force-pushed the feature/nvfp4-per-token-recipe branch from 6f17fe4 to 928ab1c Compare May 27, 2026 13:09
DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr int dshmem_size = buff_size_aligned_in + TMA_SHMEM_ALIGNMENT; // + align pad

dim3 grid(static_cast<unsigned>(K / CHUNK_DIM_X), static_cast<unsigned>(M / CHUNK_DIM_Y), 1);

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe use DIVUP here to handle the remainder case?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fast path has a hard precondition that M and K are exact multiples of CHUNK_DIM (128): validate() does NVTE_CHECK(M % CHUNK_DIM_Y == 0) / NVTE_CHECK(K % CHUNK_DIM_X == 0), and is_supported() returns false unless both hold — so any non-multiple shape is rejected / routed to the generic per-token fallback before it ever reaches this launcher.

// After all 4 stages, emit one atomicMaxFloat per row slot + one per col slot.
//
// kWithRht=true: col-wise amax over RHT-rotated 16-row strips (per-thread
// FHT with random_sign_mask_t). Row direction never sees RHT.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: Row direction never sees RHT -> Row direction never uses RHT

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

}
}
#else
NVTE_DEVICE_ERROR("Per-token amax kernel requires SM 10.0+ (Blackwell).");

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For these quantization kernel, TMA only require SM 9.0+ only. Is there any other constraints that limit to sm 10.0+?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CUDA_ARCH >= 1000 guard is intentional but not because of a hardware op in this kernel. Two reasons:

  1. The shared TE PTX wrappers it calls — cp_async_bulk_tensor_2d_global_to_shared and mbarrier_wait_parity_acquire_cta_shared_cta in util/ptx.cuh — are themselves guarded to >= 1000 and emit NVTE_DEVICE_ERROR below that. They were authored/validated only for the Blackwell path.
  2. The whole NVFP4 quantize path is host-gated to SM100 anyway (NVTE_ERROR("NVFP4 requires SM100 ...")), since NVFP4 is a Blackwell datatype and the downstream FP4 GEMM that consumes these scales only exists on SM100. So the amax kernel is never launched off <SM100; the per-arch guard just yields a clean error instead of an undefined symbol.

@cael-ling cael-ling marked this pull request as ready for review June 11, 2026 07:58
@greptile-apps

greptile-apps Bot commented Jun 11, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds an NVFP4 per-token quantization recipe (NVFP4PerTokenBlockScaling) for model pre-training. Instead of a single per-tensor global amax, the recipe computes a per-row outer amax (M,) for rowwise data and a per-column outer amax (K,) for columnwise data, enabling finer-grained scaling per token.

  • New cast kernels (quantize_nvfp4_per_token.cu, quantize_nvfp4_per_token_group.cu): K1 (vector amax) + K2 (FP4 encode/swizzle) producing per-token NVFP4 tensors for both single and grouped (MoE) inputs.
  • Fused CUTLASS GEMM (nvfp4_cutlass_gemm.cu, nvfp4_cutlass_grouped_gemm.cu): nvfp4_cutlass_per_token_gemm consumes per-row/per-col amax vectors in the epilogue (EVT), covering TN/NN/NT layouts for fwd/dgrad/wgrad.
  • Recipe and Python integration: NVFP4PerTokenBlockScaling dataclass plus NVTE_NVFP4_PER_TOKEN=1 env-var opt-in on the base recipe, quantizer flag propagation through NVFP4TensorStorage._per_token, and automatic general_gemm routing to the fused EVT kernel.

Confidence Score: 5/5

The new per-token recipe is additive (opt-in via explicit recipe class or env var) and does not modify the default NVFP4BlockScaling code path, so existing users are unaffected.

The three previously flagged issues (null-ptr restore in weight-2D path, silent alpha=1 assumption in per-token GEMM, 0-token split rejection) remain open but are either gated behind an unreachable code path today or are documented limitations. No new blocking defects were found. The dispatch table, amax-vector allocation, FSDP2 metadata propagation, and grouped-GEMM empty-expert handling are all correct.

transformer_engine/pytorch/csrc/quantizer.cpp — the weight-2D amax restore after nvte_compute_amax_with_config uses rowwise_amax_ptr without a null guard; transformer_engine/pytorch/cpp_extensions/gemm.py — the grouped GEMM bias+accumulate guard uses assert.

Important Files Changed

Filename Overview
transformer_engine/pytorch/cpp_extensions/gemm.py Adds per-token GEMM dispatch via _nvfp4_per_token_gemm and _nvfp4_per_token_grouped_gemm; return-type annotation for single_output=True is incorrect and a bias+accumulate mutual-exclusion guard uses assert (disableable with -O).
transformer_engine/pytorch/csrc/quantizer.cpp Per-token and per-token weight-2D quantize_impl paths added; previously-flagged latent null-pointer restore (rowwise_amax_ptr after nvte_compute_amax_with_config) remains; new per-token and split-quantize dispatches are otherwise correct.
transformer_engine/common/recipe/init.py Adds NVFP4PerTokenBlockScaling subclass and nvfp4_per_token() classmethod; mutex knobs are correctly forced in __post_init__ and guarded again in get_quantizer(), so env-var activation after construction is safe.
transformer_engine/pytorch/quantization.py Per-token quantizer factory correctly gates all mutex features (RHT, SR, 2D-quant, 4over6, row-scaled) with and not per_token guards, making env-var opt-in safe even when recipe was constructed with default settings.
transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token_group.py Grouped per-token quantize wrapper; split_sections[i] <= 0 raises unconditionally (previously flagged), making this unusable with zero-token experts from dynamic routing without pre-filtering.
transformer_engine/pytorch/tensor/nvfp4_tensor.py Propagates per_token flag through FSDP2 shard/unshard metadata, __reduce_ex__, _ViewFunc, and _ReshapeFunc consistently; FSDP2 fast-fail guard for per-token mode is correctly placed.
transformer_engine/pytorch/csrc/extensions/nvfp4_cutlass_gemm.cpp Per-token CUTLASS GEMM C++ binding with thorough shape/dtype/contiguity checks; SF-swizzle handled internally or skipped when pre-swizzled; accumulate path correctly restricted to fp32 output.
transformer_engine/pytorch/csrc/extensions/nvfp4_per_token.cpp K1/K2/composite quantize C++ bindings; shape/dtype/alignment checks are thorough and match the kernel constraints; grouped bulk-allocation path correctly handles per-token amax vector sizing.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant Recipe as NVFP4PerTokenBlockScaling
    participant RecipeState as NVFP4BlockScalingRecipeState
    participant Quantizer as NVFP4Quantizer (Python)
    participant QuantizerCPP as NVFP4Quantizer (C++)
    participant KernelK1 as K1: vector amax kernel
    participant KernelK2 as K2: FP4 encode/swizzle kernel
    participant GEMM as nvfp4_cutlass_per_token_gemm (EVT)

    Recipe->>RecipeState: get_quantizer(tensor_type, mode)
    RecipeState->>RecipeState: "per_token = nvfp4_per_token() and tensor_type in scope"
    RecipeState->>Quantizer: "NVFP4Quantizer(per_token=True, ...)"
    Quantizer->>QuantizerCPP: quantize_impl(input, out)

    alt per_token_weight_2d (weight only)
        QuantizerCPP->>KernelK1: nvte_compute_amax_with_config (scalar global amax)
        QuantizerCPP->>KernelK2: nvte_quantize_v2 (2D 16x16 cast)
        QuantizerCPP->>QuantizerCPP: broadcast scalar amax to (M,) rowwise + (K,) colwise vectors
    else standard per_token
        QuantizerCPP->>KernelK1: nvte_nvfp4_per_token_quantize (K1+K2 composite)
        KernelK1-->>QuantizerCPP: amax_rowwise (M,) + amax_columnwise (K,) + FP4 data
    end

    Note over QuantizerCPP: out._per_token = True
    QuantizerCPP-->>Quantizer: NVFP4Tensor with per-token amaxes

    Note over RecipeState,GEMM: general_gemm dispatch
    RecipeState->>GEMM: "_is_nvfp4_per_token_tensor(A/B) = True"
    GEMM->>GEMM: select rowwise/columnwise view per TN/NN/NT layout
    GEMM->>GEMM: tex.nvfp4_cutlass_per_token_gemm(ka_data, kb_data, ka_sf, kb_sf, ka_amax, kb_amax, out)
    Note over GEMM: EVT epilogue: D[i,j] = bf16(alpha_a[i] * alpha_b[j] * (A@B^T)[i,j])
    GEMM-->>RecipeState: bf16 / fp32 output tensor
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant Recipe as NVFP4PerTokenBlockScaling
    participant RecipeState as NVFP4BlockScalingRecipeState
    participant Quantizer as NVFP4Quantizer (Python)
    participant QuantizerCPP as NVFP4Quantizer (C++)
    participant KernelK1 as K1: vector amax kernel
    participant KernelK2 as K2: FP4 encode/swizzle kernel
    participant GEMM as nvfp4_cutlass_per_token_gemm (EVT)

    Recipe->>RecipeState: get_quantizer(tensor_type, mode)
    RecipeState->>RecipeState: "per_token = nvfp4_per_token() and tensor_type in scope"
    RecipeState->>Quantizer: "NVFP4Quantizer(per_token=True, ...)"
    Quantizer->>QuantizerCPP: quantize_impl(input, out)

    alt per_token_weight_2d (weight only)
        QuantizerCPP->>KernelK1: nvte_compute_amax_with_config (scalar global amax)
        QuantizerCPP->>KernelK2: nvte_quantize_v2 (2D 16x16 cast)
        QuantizerCPP->>QuantizerCPP: broadcast scalar amax to (M,) rowwise + (K,) colwise vectors
    else standard per_token
        QuantizerCPP->>KernelK1: nvte_nvfp4_per_token_quantize (K1+K2 composite)
        KernelK1-->>QuantizerCPP: amax_rowwise (M,) + amax_columnwise (K,) + FP4 data
    end

    Note over QuantizerCPP: out._per_token = True
    QuantizerCPP-->>Quantizer: NVFP4Tensor with per-token amaxes

    Note over RecipeState,GEMM: general_gemm dispatch
    RecipeState->>GEMM: "_is_nvfp4_per_token_tensor(A/B) = True"
    GEMM->>GEMM: select rowwise/columnwise view per TN/NN/NT layout
    GEMM->>GEMM: tex.nvfp4_cutlass_per_token_gemm(ka_data, kb_data, ka_sf, kb_sf, ka_amax, kb_amax, out)
    Note over GEMM: EVT epilogue: D[i,j] = bf16(alpha_a[i] * alpha_b[j] * (A@B^T)[i,j])
    GEMM-->>RecipeState: bf16 / fp32 output tensor
Loading

Reviews (7): Last reviewed commit: "Apply pre-commit formatting for NVFP4 pe..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/csrc/quantizer.cpp
Comment on lines +507 to +530
# Per-token NVFP4 dispatches to fused EVT GEMM that consumes per-row
# (M,) and per-col (N,) outer-amax vectors directly. cuBLASLt cannot,
# so this MUST short-circuit before the row-scaled-or-generic fork.
if _is_nvfp4_per_token_tensor(A) or _is_nvfp4_per_token_tensor(B):
if not (_is_nvfp4_per_token_tensor(A) and _is_nvfp4_per_token_tensor(B)):
raise NotImplementedError(
"NVFP4 per-token GEMM requires both A and B to be per-token tensors. "
"Mixing per-token + prod NVFP4 in one GEMM is not supported."
)
out = _nvfp4_per_token_gemm(
A,
B,
transa=transa,
transb=transb,
out=out,
out_dtype=out_dtype,
bias=bias,
grad=grad,
accumulate=accumulate,
gelu=gelu,
quantization_params=quantization_params,
ub=ub,
extra_output=extra_output,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 alpha scalar silently ignored for per-token GEMM

general_gemm validates and stores alpha in kwargs["alpha"], but the per-token short-circuit path dispatches to _nvfp4_per_token_gemm which has no alpha parameter and never forwards the value. The C++ binding nvfp4_cutlass_per_token_gemm also lacks a global scalar alpha argument — only the per-row/per-col alpha_a/alpha_b vectors are supported. For all current TE module call sites alpha=1.0 is the invariant, so numerical output is unaffected today. If a caller ever passes alpha != 1.0 through general_gemm with per-token tensors, the result will be silently wrong instead of raising an error.

Comment on lines +47 to +51
for i, M_i in enumerate(split_sections):
if M_i <= 0:
raise ValueError(f"split_sections[{i}] must be > 0, got {M_i}")
if M_i % _PER_TOKEN_TILE != 0:
raise ValueError(f"split_sections[{i}] = {M_i} must be a multiple of {_PER_TOKEN_TILE}")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Public grouped-quantize API unconditionally rejects 0-token splits

split_sections[i] <= 0 raises ValueError, but in MoE training with dynamic token routing, experts commonly receive zero tokens in a given micro-batch. The general_grouped_gemm per-token loop already handles this by skipping the launch when m_splits[i] == 0, so the GEMM side is fine. If users call this Python wrapper directly (e.g., from bench scripts or custom MoE quantization pipelines), they must pre-filter empty experts. A comment or guard skipping allocation for empty splits would make the API usable in unbalanced-routing scenarios.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@cael-ling cael-ling force-pushed the feature/nvfp4-per-token-recipe branch from 418d8dd to 160d13d Compare June 27, 2026 13:45
cael-ling and others added 17 commits June 27, 2026 06:49
Rewrites the grouped multi-tensor cast as a K1 fused amax + K2 fused cast
pair and ships pytest correctness + sweep benches against the per-tensor
RHT+SR production baseline.

  * common/cast/.../quantize_nvfp4_per_token_group.cu: K1+K2 fused
    grouped kernel, reusing the single-tensor 4-stage TMA pipeline.
  * common/gemm/nvfp4_per_token_post_scale.cu: row-wise post-scale
    kernel for the cuBLASLT NVFP4 dequantize step (maybe updated due
    to 2d quant of W).
  * pytorch/csrc/extensions/nvfp4_per_token.cpp + pybind.cpp: new C++
    grouped bulk binding and per-token GEMM entry; thin pybind layer.
  * pytorch/custom_recipes/{gemm_nvfp4_per_token,
    quantization_nvfp4_per_token_group}.py: Python wrappers.
  * tests/pytorch/nvfp4/test_nvfp4_per_token{,_group}.py: byte-equal
    cast tests + bf16-close GEMM tests.
  * tests/pytorch/nvfp4/bench_nvfp4_per_token{,_group}.py: 6x3 sweep
    over M in {1024..32768} x K in {2048,4096,8192}, eager + CUDA
    Graphs columns, ratio against per-tensor RHT+SR baseline.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
…uped)

Wire `with_rht` / `random_sign_mask_t` through the per-token K1 (amax)
and K2 (encode) kernels for both single-tensor and grouped paths.
with_rht=False is byte-equal to the pre-RHT code path; when true,
applies a 16-pt RHT on the columnwise direction in both K1 and K2
(rowwise stays raw) with outer amax + inner SF self-consistent.

Implementation: per-thread fp32 FHT on CUDA cores, branchless fp32
sign-bit XOR for the +/-1 sign diagonal, 0.25 normalization folded into
block_amax / block_scale (bit-exact).

Tests cover K1, K2, composite + grouped vs a PyTorch fp32 reference and
byte-equality regressions. Benches gain a --rht flag (2-way default,
3-way under --rht).

Perf vs prod NVFP4Quantizer(rht+sr), Graph mode, 18 shapes M up to 32K:
* single tensor : 0.49x-0.77x (no RHT), 0.59x-0.88x (+RHT)
* grouped (N=8) : 0.41x-0.77x (no RHT), 0.50x-0.94x (+RHT)

Also drops unused THREADS_X_TR / THREADS_Y_TR (nvcc warning NVIDIA#177-D).

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Add an optional fused-swizzle path to the NVFP4 per-token K2 encode
kernel: when with_swizzle=True the rowwise scale_inv is emitted directly
in the cuBLAS LT 128Mx4K swizzled tile layout, skipping the downstream
nvte_swizzle_scaling_factors launch. The colwise scale_inv stays in the
compact M-major layout (rowwise-only fusion for now).

The new code path is gated by a kWithSwizzle template parameter on
per_token_encode_kernel. The scatter epilogue uses thread mapping
b=tid&3, ty=tid>>2 to give each warp a coalesced 128-byte gmem store,
and packs two K-tiles into one uint64_t SMEM load (2-way bank conflict
instead of 4-way). Pre-existing code path is byte-equal.

with_swizzle is threaded through nvte_nvfp4_per_token_{quantize,encode},
their PyTorch bindings, and the nvfp4_per_token_{quantize,encode} Python
recipes. nvfp4_per_token_gemm takes new a_sf_swizzled / b_sf_swizzled
flags so the caller opts into the fast path per operand (mirrors prod
NVFP4 GEMM's per-operand swizzle).

Add tex.nvfp4_per_token_swizzle_rowwise_sf -- a thin wrapper around
nvte_swizzle_scaling_factors that does one standalone per-operand
swizzle launch. Bench-only; lets --qs attribute swizzle cost separately
from K1+K2 and from cuBLAS LT GEMM.

Bench (bench_nvfp4_per_token.py): add --qs mode (K1+K2 + standalone
swizzle, no GEMM) with two modifiers -- --pair (2 operands, matches one
prod GEMM call's quant+swizzle pipeline) and --fuse (adds a per-token
(fuse) column for the K2-fused path). The existing --swizzle end-to-end
mode also gains the fused-swizzle column. --pair / --fuse auto-imply
--qs to avoid silent fall-through to the default --composite table.

Tests (test_nvfp4_per_token.py): byte-equality of the fused-swizzle
rowwise SF vs a pure-Python permutation reference, byte-equality of all
other outputs (FP4 data, colwise SF, row/col amax) vs with_swizzle=False,
and numerical equivalence of the end-to-end GEMM via both code paths.

Perf at K=N=4096, Graph mode: fused-swizzle path is ~7-35% faster than
the unfused per-token pipeline (--qs) and reaches up to ~2.6x faster
than per-tensor at small M.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
The per-token cuBLASLt NVFP4 path needs a trailing post-scale kernel
(D *= alpha_a[i] * alpha_b[j]) that is HBM-bound on the M*N output. This
patch ships a forked-CUTLASS NVFP4 GEMM whose EVT epilogue folds the
per-row * per-col rescale into the in-TMEM accumulator -- a single launch
with no separate post-scale, no M*N HBM round-trip.

New C-API entry points (transformer_engine/common/gemm/nvfp4_cutlass_gemm.cu):
  - nvte_nvfp4_cutlass_gemm: scalar (alpha, beta) NVFP4xNVFP4 -> BF16 GEMM
    (CUTLASS analog of the cuBLASLt per-tensor path; used as test ground truth).
  - nvte_nvfp4_cutlass_per_token_gemm: same mainloop, EVT epilogue
    D[i,j] = bf16(NVFP4_DEQUANT_K * alpha_a[i] * alpha_b[j] * acc).
    The outer 1/2688^2 factor (NVFP4 spec) is baked into the EVT explicitly,
    matching the value cuBLASLt auto-folds via its amax slot.

Python bindings (tex.nvfp4_cutlass_gemm / tex.nvfp4_cutlass_per_token_gemm)
plus a/b_sf_swizzled flags for apples-to-apples --gemm-only benching.

Numerical correctness (tests/pytorch/nvfp4/test_nvfp4_cutlass_per_token_gemm.py):
  - fused EVT == cuBLASLt per-token within bf16 ULP (rtol=2e-2), across
    M,N,K = 256..1024.
  - fused EVT with unity alphas == nvfp4_cutlass_gemm(alpha=1/2688^2) BIT-EXACT
    (sanity check that the EVT tree and the baked constant are both correct).

Bench (tests/pytorch/nvfp4/bench_nvfp4_per_token.py --gemm-only) streamlined
to the only comparison that matters for shipping: ct_fused (per-token CUTLASS
fused) vs pten_gemm (prod per-tensor cuBLASLt), with the cf/pten ratio.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Extends tests/pytorch/nvfp4/{bench,test}_nvfp4_cutlass_per_token_gemm
with end-to-end forward and backward coverage that aligns the prod
baseline with NVFP4BlockScaling real-ship defaults (input RHT-1D,
weight 2D no-RHT, grad RHT-cols + SR), so per-token (no RHT/SR) is
measured against an actually-shippable prod recipe rather than a
toy quantizer.

bench_nvfp4_per_token.py:
* --e2e-fwd: per-token quant (with_swizzle=True) + fused-EVT CUTLASS
  GEMM vs NVFP4Quantizer + general_gemm (the real nn.Linear fwd
  dispatch). Quant + GEMM inside the timing loop, N = K. Function
  docstring carries an ASCII kernel-pipeline diagram for both paths
  (per-call launch budget: per-token ~5 vs prod ~10).
* --e2e-bwd: real prod nn.Linear.bwd lifecycle. Timing loop = 1 x dY
  quant + dgrad GEMM + wgrad GEMM; X and W are pre-quantized OUTSIDE
  the loop (mirrors prod's reuse of fwd-saved QuantizedTensorStorage,
  bwd never re-quantizes). pten side uses RHT-cols + SR grad
  quantizer + general_gemm NN (dgrad) / NT (wgrad). Function docstring
  carries an ASCII kernel-pipeline diagram (per-step launch budget:
  per-token ~4 vs prod ~12).
* --gemm-only: 3-way table adds an lt_post column (cuBLASLt NVFP4 +
  bf16 per-row*per-col post-scale, "Route 1") next to the existing
  ct_fused fused-EVT path ("Route 2") and the prod pten_gemm
  baseline. Headline ratio lp/cf decides whether to dispatch
  per-token through cuBLASLt + post_scale or fused EVT; current
  data shows ct_fused wins or ties at every shape we care about.

test_nvfp4_cutlass_per_token_gemm.py:
* Layer 2 fwd: per-token quant + fused-EVT GEMM vs BF16 fp32 ground
  truth (rel_l2 < 0.30, robust to per-shape noise).
* Layer 3 fwd: dual-SNR table comparing per-token vs prod, both
  measured against BF16 ground truth, with a per-token-vs-prod ratio.
* Layer 3 bwd: same dual-SNR pattern for dgrad and wgrad. Prod side
  uses real-ship NVFP4BlockScaling grad quantizer (RHT cols + SR);
  per-token side has no RHT/SR (numerical-floor comparison).
* Sanity micro-test for weight 2D quant plumbing through general_gemm
  (catches breakage cheaper than the broader Layer 3 test).

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Add NN/NT GEMM layout dispatch so the per-token NVFP4 path covers dgrad
and wgrad, and let per-token opt into RHT via
NVFP4PerTokenBlockScaling(per_token_rht=...) while SR/2D stay disabled
(kernels unimplemented at this commit). Extends the per-token CUTLASS
GEMM, the torch NVFP4Quantizer, and the NVFP4Tensor plumbing, plus
dgrad/wgrad numerical tests and a fwd+bwd module smoke test.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Thread a Philox rng_state and a kWithSr template flag through the
per-token encode kernel (rowwise + colwise) and the
nvte_nvfp4_per_token_encode/quantize C-API, mirroring the per-tensor
SR path. Drop the SR mutex check in the torch NVFP4Quantizer and build
the rng_state when stochastic rounding is requested. Add a per_token_sr
recipe flag on NVFP4PerTokenBlockScaling wired through the quantizer
factory, plus statistical tests (SR unbiasedness -- lower RMSE than RN
when averaged -- and RN-determinism / SR-nondeterminism) folded into
test_nvfp4_per_token.py.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Wire with_sr + rng_state through the grouped per-token C-API and cast
dispatch, implement the SR FP4 cast in the grouped kernel, and drop the
"per-token does not support SR" guard. Also fix two comment typos
(sees -> uses) in quantize_nvfp4_per_token.cu per review.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Introduce NVTE_NVFP4_PER_TOKEN_WEIGHT_2D (recipe.per_token_weight_2d),
default off so the per-token path stays byte-equal. When enabled, only the
forward WEIGHT switches to the per-tensor 2D cast (16x16 inner tile + scalar
outer amax) re-dressed in per-token tensor layout: the scalar outer amax is
broadcast across the per-row/col alpha vectors and the inner SF is the same
16-row-replicated 2D tile, so the existing per-token CUTLASS GEMM consumes it
unchanged with no kernel modification. Activation/gradient casts stay
per-token 1D.

Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Document the user-facing surface of the NVFP4 per-token recipe and add a
runnable single-GPU example so the recipe can be exercised end to end.

- docs/api/common.rst: list NVFP4PerTokenBlockScaling in the API reference.
- docs/envvars.rst: document the NVTE_NVFP4_* knobs -- per-token activation
  (NVTE_NVFP4_PER_TOKEN) plus the RHT/SR/weight-2D opt-ins, and the
  per-tensor disable flags.
- docs/features/.../nvfp4.rst: add a "Per-token NVFP4" section explaining the
  per-row/per-col outer-amax cast, its differences from the per-tensor default
  (RHT/SR off by default, forced-off knobs, unfused-norm requirement), and how
  to launch it with Megatron-Core.
- recipe/__init__.py: document the per_token_rht/per_token_sr/per_token_weight_2d
  constructor kwargs and drop the stale "stochastic rounding unsupported" note.
- pytorch/fp8.py: re-export NVFP4PerTokenBlockScaling.
- examples/pytorch/nvfp4_per_token_megatron: single-GPU MoE example (run +
  sbatch + job-chain scripts and README) comparing per-token vs per-tensor vs
  BF16 with identical model/data/seed.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Replace the per-expert Python loop for the plain
D = bf16(alpha_a * alpha_b * (A @ B^T)) path with a single ptr-array
CUTLASS grouped kernel (SM100). The dispatcher in general_grouped_gemm
routes to the native kernel when no accumulate/bias/gelu/output-quant is
requested, and otherwise falls back to the per-expert loop
(NVTE_NVFP4_PER_TOKEN_GROUPED_FALLBACK=1 forces the fallback).

The launcher caches the SM count and reuses persistent device scratch +
workspace buffers across launches to avoid per-call cudaMalloc/Free and
cudaGetDeviceProperties overhead. Parity tests assert the grouped kernel
matches the dense per-token GEMM bit-exact per group.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
Add an fp32-output accumulate variant to the dense and grouped per-token
NVFP4 CUTLASS kernels. The EVT computes D = beta * C + dW, where beta=1
accumulates the weight gradient in place into the fp32 main_grad buffer
(C aliases D) and beta=0 overwrites. This lets te.Linear / GroupedLinear
wgrad accumulate straight into main_grad, mirroring the prod per-tensor
cuBLAS LT path (C == D in place, beta = accumulate ? 1 : 0; output
quantization disables accumulation).

Add dense and grouped parity tests covering fp32 overwrite (matches the
bf16 path cast to fp32) and bit-exact in-place accumulation.

Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
The grouped per-token GEMM uploads ~14 small per-group metadata arrays
(problem shapes, A/B/C/D strides, SFA/SFB layouts, data/SF/D pointers,
alpha_a/alpha_b) to device before each launch. Issuing one cudaMemcpyAsync
per array adds a fixed ~15-20us of host-side overhead that dominates the
launch-bound regime (small / many-expert MoE).

Pack all arrays into a process-persistent pageable host mirror at their
256B-aligned scratch offsets and ship them in a SINGLE H2D copy. The buffer
is intentionally pageable: cudaMemcpyAsync from pageable host memory stages
the source into the driver before returning, so the mirror is safe to
overwrite on the next call even when the host runs ahead of the stream.

Gated by NVTE_NVFP4_GROUPED_BATCHED_H2D (default on); set it to 0 to fall
back to the per-array copies for A/B measurement in the same build.

Pure-GEMM bench (SM100): the native kernel drops a constant ~15-20us per
launch, lifting native-vs-fallback speedup from ~1.5-3.5x to ~2.0-4.0x;
biggest relative win on small/few-expert shapes. Numerics unchanged:
test_nvfp4_cutlass_per_token_gemm.py grouped cases
pass bit-exact with the env on and off.

Signed-off-by: Cael Ling <caell@nvidia.com>
Add an optional per-group bias to the CUTLASS grouped per-token GEMM by
extending the EVT epilogue with a bias node, and thread the bias pointers
through the C-API, the PyTorch extension wrapper, and the general_grouped_gemm
dispatch. Bias is fprop-only (no accumulate) and matches the cuBLASLt path's
fused bias-add. Add a numerical test comparing the fused output against an
fp32 GEMM plus bias reference.

Signed-off-by: Cael Ling <caell@nvidia.com>
Per-token NVFP4's outer amax is a per-row/per-col vector computed by its
amax kernel, so the fused norm/activation/dact + scalar-amax kernels cannot
feed quantize_with_amax. Auto-select the UNFUSED path for per-token in the normalization, activation, and dact+dbias dispatch sites, removing the need for the NVTE_NORM_FWD_USE_CUDNN=1 workaround. This is a no-op for per-tensor NVFP4. Update docs and the Megatron example to drop the env var, and add a LayerNormLinear/LayerNormMLP fwd+bwd regression that runs with the env var unset.

Signed-off-by: Cael Ling <caell@nvidia.com>
Upstream now stores the NVFP4 1x16 block scale_inv as torch.float8_e4m3fn
instead of raw uint8. The per-token CUTLASS GEMM binding requires uint8
scale bytes, so _nvfp4_per_token_select reinterprets the scale as uint8
(zero-copy: e4m3 and uint8 are byte-identical). Fixes the "a_sf/b_sf must
be uint8" runtime error in te.Linear / LayerNorm* under the per-token
recipe. Covers both single and grouped per-token GEMM.

Signed-off-by: Cael Ling <caell@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling cael-ling force-pushed the feature/nvfp4-per-token-recipe branch from 334577e to 690ffea Compare June 27, 2026 14:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants