Skip to content

[Common/PyTorch] Grouped-quantize kernels for 1D and 2D FP8 block-scaling#3135

Open
denera wants to merge 1 commit into
NVIDIA:mainfrom
denera:common/fp8-block-scaling-grouped-quantize
Open

[Common/PyTorch] Grouped-quantize kernels for 1D and 2D FP8 block-scaling#3135
denera wants to merge 1 commit into
NVIDIA:mainfrom
denera:common/fp8-block-scaling-grouped-quantize

Conversation

@denera

@denera denera commented Jun 17, 2026

Copy link
Copy Markdown
Collaborator

Description

Implements grouped-tensor quantize for the FP8 1D (1x128) and 2D (128x128) block-scaling recipes in row-wise (RW), column-wise (CW) and BOTH quantization directions. A single CUDA kernel launch walks 128x128 tiles across every tensor in the group, with each CTA decoding its owning tensor from the device-side GroupedTensor metadata with (N, R, K) shapes. Supports SAME_BOTH_DIMS (all tensors identical) and VARYING_FIRST_DIM (constant K, varying R) shape representations.

Three kernels share the dispatcher in group_quantize_blockwise_{1d,2d}:

  • group_block_scaled_1d_rw_kernel — RW-only dispatch; 8 threads/row, reads global memory directly into vec-16 registers; bypasses TMA because the shared memory roundtrip and ptx::mbarrier does not buy anything without re-use in CW path.
  • group_block_scaled_1d_tma_kernel — CW-only and BOTH dispatch; TMA bulk-load fills shared memory input cache. BOTH runs RW pass first (8 threads/row, vec-16 read from shared memory) then CW pass (2 threads/column, 64-row register stage); CW-only skips the RW pass. CW path writes the transposed-FP8 tile to a shared memory transpose staging buffer, then drains to global memory.
  • group_block_scaled_2d_tma_kernel — RW-only, CW-only and BOTH dispatch; TMA bulk-load fills shared memory input cache. Pass 1 stages 8 IVecs/thread in registers while computing the per-tile scalar amax. Pass 2 quantizes from registers, emits row-wise output, stages column-wise output to shared memory transpose staging buffer, then drains to global memory.

Kernels are gated to Hopper (sm_90) at the host dispatcher (cuBlasLt grouped GEMM supports FP8 block-scaling only on Hopper).

PR includes PyTorch integration into te.GroupedTensor only.

PyTorch integration into te.GroupedLinear and JAX integration deferred to a follow-up PRs.

Partially resolves #2525

Performance

Benchmark on H200 with a sweep of grouped tensors in (N, M, K) shapes:

  • N ∈ {4, 8, 16, 32, 64, 128} (# of device-local experts)
  • M = 4096 @ N = 4 → M = 128 @ N = 128 (# of tokens/expert, scaling inversely with # experts)
  • K ∈ {1024, 1792, 2048, 3584, 4096, 7168} (device-local shard of TP-hidden/intermediate-FFN dim)

Two shape families per config:

  • U_MoE (uniform, SAME_BOTH_DIMS): all experts share the (M, K) shape
  • J_MoE (jagged, VARYING_FIRST_DIM): per-expert M drawn from an imbalanced routing, common K

Buckets:

  • Small/Unsaturated (S): R·K ≤ 32M elements (< 2048 tiles and < 15 waves on H200's 132 SMs)
  • Large/Saturated (L): R·K > 32M elements (> 2048 tiles, SMs busy across many waves)

Bucket medians across 3 reps. Speedup is grouped vs the split-quantized fallback that loops over the grouped tensor and quantizes each constituent sequentially. % mono is grouped throughput relative to a single non-grouped FP8 block-scaling quantize on the equivalent monolithic (N·M, K) tensor.

Bucket Path Grouped (ms) Split (ms) Speedup % memcpy tput % mono tput
S 1D RW 0.019 0.084 4.50× 64.9 % 114.9 %
S 1D CW 0.022 0.090 4.17× 56.4 % 112.4 %
S 1D BOTH 0.033 0.116 3.57× 50.2 % 102.8 %
S 2D RW 0.018 0.076 4.19× 65.9 % 100.9 %
S 2D CW 0.020 0.089 4.64× 62.6 % 126.3 %
S 2D BOTH 0.027 0.089 3.68× 66.3 % 98.8 %
L 1D RW 0.058 0.198 2.06× 87.1 % 118.7 %
L 1D CW 0.064 0.213 2.05× 77.3 % 118.5 %
L 1D BOTH 0.098 0.282 1.77× 66.7 % 108.4 %
L 2D RW 0.054 0.178 1.99× 87.7 % 100.1 %
L 2D CW 0.058 0.213 2.20× 85.3 % 135.9 %
L 2D BOTH 0.078 0.213 1.62× 85.0 % 102.4 %
# experts (N) S bucket L bucket
4 1.74× 1.42×
8 2.40× 1.45×
16 4.18× 1.89×
32 5.50× 2.81×
64 10.43× 7.51×
128 19.81× 8.72×

Notes

  • % of mono throughput is roughly consistent across buckets for every path, confirming no per-expert overhead in the new kernels.
  • Greater-than-100% mono throughput cases come from TMA bulk-loads, register staging, and vec-16 reads that the non-grouped FP8 block-scaling kernels do not have.
  • Speedup over split-quantize scales as expected with # of experts (roughly linearly in the unsaturated regime).
  • S-bucket % memcpy is lower than L because launch and per-CTA setup are not amortized over a long bandwidth-bound steady state; absolute kernel times are still small (< 35 µs).

Known Sub-Optimalities

1D CW load bank conflicts on ~35% of load wavefronts (reading from the shared memory input-cache)

  • No possible stride padding or XOR swizzle to alleviate.
  • TMA hardware swizzle with CU_TENSOR_MAP_SWIZZLE_128B has the right pattern but caps FP16/BF16 at 64-elements; does not fit the 128-element tile for FP8 block-scaling without doubling per-tile launch overhead (quadrupling for FP32).
  • CW-only stays at 2 t/col. Going to 4 t/col would double the bank-conflict footprint (4 lanes per column at the same row stride instead of 2), and CW-only is not occupancy-bound (3 CTAs/SM regardless), so the restructure costs more than it saves.
  • 1D BOTH uses 4 t/col with a 32-row reg_data per thread and two column passes per CTA. The RW pass's per-expert scale-offset arithmetic plus a 64-row reg_data crossed the 85-reg / 3-CTA threshold on sm_90; halving reg_data restores 4 CTAs/SM. The doubled column pass and extra XOR-reduce stage are cheap relative to the occupancy gain.

1D BOTH reads the shared memory input-cache twice

  • The RW (8 threads/row) and CW (2 threads/column) passes have different threading.
  • Attempted to unify with 8 threads/row for both RW and CW. Caused bank conflicts on ~76% of store wavefronts (writing to the shared memory transpose buffer), reduced to ~43% with a XOR swizzle but not enough to beat separate RW/CW passes.
  • Did not pursue the 2 threads/column unification; costs 40x more shfl ops than 8 threads/row attempt, plus a shared memory partial buffer and sync.

2D CW/BOTH has bank conflicts on ~16% of store wavefronts (when writing to the shared memory transpose buffer)

  • Already reduced from ~75% via a XOR swizzle, further reduction was not possible.
  • Minimal impact (< 5%) on kernel time.

No TMA-store

  • MXFP8 grouped quantize kernel leverages this by decomposing a 128x128 tile into 32-row sub-stages that each have their own independent 32x1 or 1x32 scale; shared memory footprint is based on a single sub-stage; can be quantized and TMA-stored independently; hides TMA-store of one stage under the compute of next stage.
  • FP8 block-scaling 128-element scale-block spans the entire 128-row tile. Cannot decompose into independent sub-stages and pipeline the TMA-stores. Single non-pipelined TMA-store requires holding the transposed staging buffer for the entire tile until all work on tile is finished, blows up shared memory footprint, collapses occupancy to 2CTA/SM. The recipe itself is the roadblock.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@denera denera requested review from ptrendx and vthumbe1503 June 17, 2026 13:01
@denera denera self-assigned this Jun 17, 2026
@denera denera added performance Performance issues FP8 MoE labels Jun 17, 2026
constexpr int kThreadsPerBlock = 256;
constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;

// Align a dynamic-smem pointer to 128 bytes (TMA requirement).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we reuse the existing align_smem_ptr_per_TMA_requirements() helper from transformer_engine/cast/core/common.h here?

size_t total_row_blocks) {
using namespace transformer_engine::dispatch::mxfp8::swizzle;
const size_t num_tiles_X =
(total_row_blocks + GEMM_SWIZZLED_SCALE_TILE_DIM_X - 1) / GEMM_SWIZZLED_SCALE_TILE_DIM_X;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also reuse the existing DIVUP() helper here (defined in transformer_engin/common/common.h).


// ---- Tensor-lookup helpers ----------------------------------------------------

// Map a global tile-row index to its owning tensor by binary-searching

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also reuse the existing get_current_tensor_id() helper defined in transformer_engine/cast/core/common.cuh

@greptile-apps

greptile-apps Bot commented Jun 17, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

Adds fused grouped-tensor FP8 block-scaling quantize/dequantize (1D 1×128 and 2D 128×128) for the Hopper (sm_90) path. A single CUDA kernel launch walks 128×128 tiles across every tensor, resolving per-expert ownership from device-side GroupedTensor metadata. Three kernel variants cover RW-only, CW-only, and BOTH directions for each block dimension, with TMA bulk-load used on the CW and BOTH paths. PyTorch integration lands in te.GroupedTensor; GroupedLinear and JAX integration are deferred.

  • group_quantize_fp8_blockwise.cuh / group_dequantize_fp8_blockwise.cuh: new 1D/2D block-scaling kernels with per-expert compact-layout scale buffering that matches cuBLAS grouped FP8 block-scaling GEMM's per-expert sub-block expectation.
  • swizzle.cuh: moved from mxfp8/ to cast/ so the 128×4 swizzle helper is shared by MXFP8, NVFP4, and the new FP8 block-scaling path; namespace updated in cublaslt_grouped_gemm.cu and the MXFP8/NVFP4 headers.
  • ptx.cuh: mbarrier_* guards lowered from sm_100+ to sm_90+; new cp_async_bulk_tensor_2d_global_to_shared_cta (cluster-size-1 variant) added for Hopper.

Confidence Score: 5/5

The change is self-contained and non-breaking: new kernel paths behind new NVTE_BLOCK_SCALING_{1D,2D} dispatch cases, Hopper-only guard in both host dispatcher and kernel body, compact (non-swizzled) scale layout explicitly set.

All three kernel variants have their critical paths validated by a comprehensive C++ test comparing against split-quantize reference for all direction/shape/block-dim combinations. Per-expert scale layout math, 2D CW prefix sum, and buffer slack bound are correct. No mismatches found between quantize and dequantize scale-index formulas.

No files require special attention. quantize.cuh and test_grouped_tensor.py have minor non-blocking follow-up opportunities noted in the comments.

Important Files Changed

Filename Overview
transformer_engine/common/cast/fp8_blockwise/group_quantize_fp8_blockwise.cuh Core 1D/2D block-scaling grouped quantize kernels and host dispatchers; scale layout helpers, TMA bulk-load, per-expert prefix sums, and transpose staging are correct.
transformer_engine/common/cast/fp8_blockwise/group_dequantize_fp8_blockwise.cuh Dequantize kernels mirror the quantize layouts for all four {1D,2D}×{RW,CW} combinations; cooperative prefix-sum syncthreads ordering and in-bounds early returns look correct.
transformer_engine/pytorch/csrc/quantizer.cpp Scale-buffer sizing uses the tight bound blocks_X * 4 * (num_tensors-1) for 2D CW rounding slack, which is mathematically correct; force_pow_2_scales rejection and with_gemm_swizzled_scales(false) are both properly set.
transformer_engine/pytorch/csrc/extensions/cast.cpp FP8 blockwise mode routed in group_quantize and bgrad_group_quantize; scale dtype correctly selected as kFloat32 for blockwise in both rowwise and columnwise dequantize paths.
transformer_engine/common/cast/dispatch/quantize.cuh Forward and backward group-quantize helpers dispatch to the new blockwise kernels; force_pow_2_scales from the config is not checked at the C API level (P2 comment added).
transformer_engine/common/util/ptx.cuh Mbarrier PTX intrinsics correctly lowered to sm_90+; new cp_async_bulk_tensor_2d_global_to_shared_cta (shared::cta space, no cluster) properly gated at sm_90+.
transformer_engine/common/cast/swizzle.cuh Clean refactor: gemm_swizzled_scale_idx moved from mxfp8/swizzle.cuh to cast/swizzle.cuh under the shared dispatch::swizzle namespace.
tests/cpp/operator/test_cast_float8blockwise_grouped.cu C++ test validates all 18 combinations against per-tensor split-quantize reference; scale layout indexing matches the kernel's per-expert compact layout precisely.
tests/pytorch/test_grouped_tensor.py Python tests cover quantize+dbias and dequantize round-trips but only for rowwise=True, columnwise=False; CW and BOTH directions are exercised exclusively in the C++ test suite.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[group_quantize / bgrad_group_quantize
Python / C API] --> B{scaling_mode}
    B -->|NVTE_BLOCK_SCALING_1D| C[group_quantize_blockwise_1d]
    B -->|NVTE_BLOCK_SCALING_2D| D[group_quantize_blockwise_2d]
    B -->|NVTE_MXFP8_1D| E[mxfp8::group_quantize]
    C --> F{use_colwise or dbias?}
    F -->|RW-only, no dbias| G[group_block_scaled_1d_rw_kernel
vec-16 gmem reads, no smem cache]
    F -->|CW-only / BOTH / RW+dbias| H[group_block_scaled_1d_tma_kernel
TMA bulk-load to smem cache]
    H --> H1[RW pass: 8t/row vec-16 from smem]
    H --> H2[CW pass: 2-4t/col reg_data smem_T transpose
drain_smem_t_to_gmem]
    H --> H3[Optional dbias partial]
    D --> I[group_block_scaled_2d_tma_kernel
TMA bulk-load to smem cache]
    I --> I1[Pass 1: tile amax stage 8 IVecs in regs]
    I --> I2[Pass 2: quantize from regs
RW output + smem_T to CW output]
    I --> I3[Optional dbias partial]
    H2 --> J[grouped_reduce_dbias
per-expert column sum reduction]
    H3 --> J
    I2 --> J
    I3 --> J
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A[group_quantize / bgrad_group_quantize
Python / C API] --> B{scaling_mode}
    B -->|NVTE_BLOCK_SCALING_1D| C[group_quantize_blockwise_1d]
    B -->|NVTE_BLOCK_SCALING_2D| D[group_quantize_blockwise_2d]
    B -->|NVTE_MXFP8_1D| E[mxfp8::group_quantize]
    C --> F{use_colwise or dbias?}
    F -->|RW-only, no dbias| G[group_block_scaled_1d_rw_kernel
vec-16 gmem reads, no smem cache]
    F -->|CW-only / BOTH / RW+dbias| H[group_block_scaled_1d_tma_kernel
TMA bulk-load to smem cache]
    H --> H1[RW pass: 8t/row vec-16 from smem]
    H --> H2[CW pass: 2-4t/col reg_data smem_T transpose
drain_smem_t_to_gmem]
    H --> H3[Optional dbias partial]
    D --> I[group_block_scaled_2d_tma_kernel
TMA bulk-load to smem cache]
    I --> I1[Pass 1: tile amax stage 8 IVecs in regs]
    I --> I2[Pass 2: quantize from regs
RW output + smem_T to CW output]
    I --> I3[Optional dbias partial]
    H2 --> J[grouped_reduce_dbias
per-expert column sum reduction]
    H3 --> J
    I2 --> J
    I3 --> J
Loading

Reviews (7): Last reviewed commit: "[Common/PyTorch] Grouped-quantize kernel..." | Re-trigger Greptile

Comment thread tests/cpp/operator/test_cast_float8blockwise_grouped.cu
}

CType amax = compute_row_amax<IType, CType, kVec>(in_vec[it]);
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 1));

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we reuse the existing amax warp-reduction helpers (warp_reduce_max() or reduce_max()) from transformer_engine/common/utils.cuh here?

Comment on lines +535 to +537
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 1));
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 2));
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 4));

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also reuse reduce_max() or warp_reduce_max() here.


// ----- Host-side dispatchers --------------------------------------------------------------------

inline size_t align_up_to(size_t x, size_t a) { return ((x + a - 1) / a) * a; }

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can reuse DIVUP_TO_MULTIPLE() defined in transformer_engine/common/common.h.

NVTE_CHECK(info.tensor_offsets_d != nullptr,
"VARYING_FIRST_DIM requires tensor_offsets to be set on the GroupedTensor.");
}
info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim;
info.total_row_blocks = DIVUP(info.R_total, kTileDim);

"VARYING_FIRST_DIM requires tensor_offsets to be set on the GroupedTensor.");
}
info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim;
info.blocks_X = (info.K + kTileDim - 1) / kTileDim;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
info.blocks_X = (info.K + kTileDim - 1) / kTileDim;
info.blocks_X = DIVUP(info.K, kTileDim);

info.same_both_dims = same_both_dims;
info.num_tensors = output->num_tensors;
info.K = output->get_common_last_dim();
NVTE_CHECK(info.K % 16 == 0, "Last dim must be multiple of 16 (FP8 alignment).");

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is a TMA requirement, we can use the TMA_GMEM_ALIGNMENT constant defined in transformer_engine/common/common.h

const float* noop_ptr =
(noop != nullptr) ? reinterpret_cast<const float*>(noop->data.dptr) : nullptr;

const size_t scale_stride_y = align_up_to(info.blocks_X, 4);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const size_t scale_stride_y = align_up_to(info.blocks_X, 4);
const size_t scale_stride_y = DIVUP_TO_MULTIPLE(info.blocks_X, 4);

const size_t scale_stride_y = align_up_to(info.blocks_X, 4);
// CW scales are stored [blocks_X, align4(total_row_blocks)] -- transposed to
// match the physically-transposed columnwise data the TN cuBLAS GEMM consumes.
const size_t scale_t_stride_y = align_up_to(info.total_row_blocks, 4);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const size_t scale_t_stride_y = align_up_to(info.total_row_blocks, 4);
const size_t scale_t_stride_y = DIVUP_TO_MULTIPLE(info.total_row_blocks, 4);

const float* noop_ptr =
(noop != nullptr) ? reinterpret_cast<const float*>(noop->data.dptr) : nullptr;

const size_t scale_stride_aligned_R = align_up_to(info.R_total, 4);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const size_t scale_stride_aligned_R = align_up_to(info.R_total, 4);
const size_t scale_stride_aligned_R = DIVUP_TO_MULTIPLE(info.R_total, 4);

(noop != nullptr) ? reinterpret_cast<const float*>(noop->data.dptr) : nullptr;

const size_t scale_stride_aligned_R = align_up_to(info.R_total, 4);
const size_t scale_t_stride_aligned_K = align_up_to(info.K, 4);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const size_t scale_t_stride_aligned_K = align_up_to(info.K, 4);
const size_t scale_t_stride_aligned_K = DIVUP_TO_MULTIPLE(info.K, 4);

denera added a commit to denera/TransformerEngine that referenced this pull request Jun 22, 2026
- Reuse shared helpers (DIVUP, DIVUP_TO_MULTIPLE, TMA_GMEM_ALIGNMENT,
  align_smem_ptr_per_TMA_requirements, get_current_tensor_id,
  subwarp_reduce_max_broadcast) in place of local equivalents.
- Add proxy-async fence after mbarrier_init in 2D + 1D TMA kernels.
- Enforce per-tensor first_dim % 128 device-side for VARYING_FIRST_DIM
  (matches MXFP8 grouped quantize behavior).
- Fix Hopper SM range wording in 1D dispatcher.
- Extend cpp tests to cover with_gemm_swizzled_scales path.

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera requested a review from Oleg-Goncharov June 22, 2026 23:06
// num_tiles_X = DIVUP(total_row_blocks, TILE_DIM_X=4)
__device__ __forceinline__ size_t swizzled_colwise_scale_idx(size_t i, size_t j,
size_t total_row_blocks) {
using namespace transformer_engine::dispatch::mxfp8::swizzle;

@vthumbe1503 vthumbe1503 Jun 22, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should rename the namespace for swizzle...given that we use the same constants for mxfp8, nvfp4, fp8 block scaling

Comment on lines 119 to 130
if get_device_compute_capability() < (10, 0):
return False
# 4. Output quantization is not supported.
if any(q is not None for q in output_quantizers):
return False
# 5. Filter by quantization recipes.
if fp8:
if not (10, 0) <= get_device_compute_capability() <= (11, 0):
return False
return all(isinstance(q, MXFP8Quantizer) for q in input_quantizers) or all(
isinstance(q, NVFP4Quantizer) and q.with_rht for q in input_quantizers
return (
activation_dtype in (torch.bfloat16, torch.float16)
and all(isinstance(q, MXFP8Quantizer) for q in input_quantizers)
and all(isinstance(q, MXFP8Quantizer) for q in weight_quantizers)
and all(q is None or isinstance(q, MXFP8Quantizer) for q in grad_output_quantizers)
)
return activation_dtype in (torch.bfloat16, torch.float16)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Hopper BF16/FP16 fused path silently disabled

The old code used two separate capability gates: >= (9,0) for all paths and >= (10,0) for FP8 only, so Hopper (CC 9.0) was valid for BF16/FP16. The new single gate < (10,0) returns False for Hopper unconditionally. Anyone who has NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM=1 on Hopper will silently fall back to the unfused loop with no warning — the docstring that previously advertised Hopper BF16/FP16 support was removed in the same commit.

The corresponding test skip condition was tightened from not (9,0) <= cc <= (11,0) to cc < (10,0), confirming the regression is intentional. If that is the case, a NVTE_WARN or a comment explaining why Hopper is excluded (e.g., cuBLAS limitation, untested, or actively broken) would prevent confusion for users and future contributors.

Comment on lines 123 to 129
if fp8:
if not (10, 0) <= get_device_compute_capability() <= (11, 0):
return False
return all(isinstance(q, MXFP8Quantizer) for q in input_quantizers) or all(
isinstance(q, NVFP4Quantizer) and q.with_rht for q in input_quantizers
return (
activation_dtype in (torch.bfloat16, torch.float16)
and all(isinstance(q, MXFP8Quantizer) for q in input_quantizers)
and all(isinstance(q, MXFP8Quantizer) for q in weight_quantizers)
and all(q is None or isinstance(q, MXFP8Quantizer) for q in grad_output_quantizers)
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 NVFP4+RHT fused grouped GEMM path silently removed

The previous fp8 branch accepted NVFP4Quantizer with with_rht=True, and NVFP4Quantizer was explicitly imported. Both the type check and the import are removed here, so Blackwell users who opted into NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM=1 with an NVFP4 recipe will silently fall back to the unfused per-expert loop. The entire test_grouped_linear_grouped_tensor_path_skips_non_rht_nvfp4 test was deleted in the same commit, confirming the removal is deliberate.

If this drop is permanent, a brief note in the PR description or a NVTE_WARN when NVFP4 is detected at this gate would prevent surprise for existing users of that configuration.

@denera denera force-pushed the common/fp8-block-scaling-grouped-quantize branch from ef0ef0a to bff55f6 Compare June 27, 2026 09:59
…ling

Implements grouped-tensor quantize for the FP8 1D (1x128) and 2D (128x128)
block-scaling recipes in row-wise (RW), column-wise (CW) and BOTH quantization
directions. A single CUDA kernel launch walks 128x128 tiles across every tensor
in the group, with each CTA decoding its owning tensor from the device-side
GroupedTensor metadata with (N, R, K) shapes. Supports SAME_BOTH_DIMS (all
tensors identical) and VARYING_FIRST_DIM (constant K, varying R) shape
representations.

Three kernels share the dispatcher in group_quantize_blockwise_{1d,2d}:
- group_block_scaled_1d_rw_kernel: RW-only dispatch; 8 threads/row, reads
  global memory directly into vec-16 registers; bypasses TMA since the
  shared-memory roundtrip and ptx::mbarrier do not buy anything without
  re-use in the CW path.
- group_block_scaled_1d_tma_kernel: CW-only and BOTH dispatch. TMA bulk-load
  fills shared memory input cache. BOTH runs an RW pass (8 threads/row,
  vec-16 read from shared memory) then a CW pass; CW-only skips the RW
  pass. The CW pass uses 4 t/col with 32-row reg_data and two column passes
  in the BOTH instantiation (keeps the per-thread register footprint under
  the sm_90 3-CTAs/SM threshold) and 2 t/col with 64-row reg_data in the
  CW-only instantiation (avoids doubling the smem-load bank-conflict
  footprint that 4 t/col would introduce).
- group_block_scaled_2d_tma_kernel: RW-only, CW-only and BOTH dispatch. TMA
  bulk-load fills shared memory input cache. Pass 1 stages 8 IVecs/thread
  in registers while computing the per-tile scalar amax. Pass 2 quantizes
  from registers, emits row-wise output, stages column-wise output to the
  shared memory transpose staging buffer, then drains smem_T to global
  memory.

Per-expert scale offsets:
- 1D RW: closed-form O(1) for both SAME_BOTH_DIMS and VARYING_FIRST_DIM
  (each M_i is a multiple of kTileDim=128, hence of kScaleColAlign=4, so
  DIVUP_TO_MULTIPLE collapses and the prefix sum reduces to a single
  tensor_offsets_ptr[tensor_id]/K load).
- 2D CW: closed-form O(1) for SAME_BOTH_DIMS; CTA-cooperative warp-shuffle
  prefix sum for VARYING_FIRST_DIM (non-linear DIVUP_TO_MULTIPLE on
  blocks_y_t prevents a closed form). The cooperative reduction uses the
  existing warp_allreduce_sum helper from common/utils.cuh.

Dequantize and bias-gradient (bgrad):
- group_dequantize_fp8_blockwise.cuh: kernels for all four modes
  (1D/2D x rowwise/columnwise), inverting the per-expert layouts the
  quantize kernels write.
- bgrad_group_quantize accepts Float8Block quantizers and computes dbias
  per-tile column-partial in-kernel (mirroring MXFP8); reduced per expert
  via the existing common::grouped_reduce_dbias.

Scale constraints: the fused grouped FP8BS path supports only unconstrained
FP32 scales (Float8BlockQuantizer::create_grouped_tensor rejects
force_pow_2_scales=True). Power-of-2 scales remain available on the
non-grouped/unfused split-quantize path used for Blackwell MXFP8 emulation.

Tests: existing parametrized grouped quantize / dequantize / bgrad tests
in test_grouped_tensor.py cover MXFP8, NVFP4, FP8 current scaling and the
newly-added FP8 block scaling recipe. tests/cpp/operator/
test_cast_float8blockwise_grouped.cu adds 72 C++ unit-test cases over
uniform/jagged shapes, all four (BD x direction) modes, K in {128, 256,
512}, and CUDA-graph capture coverage.

Kernels are gated to Hopper (sm_90) at the host dispatcher (cuBlasLt
grouped GEMM supports FP8 block-scaling only on Hopper).

JAX integration is intentionally left out of scope and deferred to a
follow-up PR.

Resolves NVIDIA#2525

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/fp8-block-scaling-grouped-quantize branch from e153c9b to dc998bd Compare June 27, 2026 13:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

FP8 MoE performance Performance issues

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Blockwise (1x128 and 128x128) FP8 grouped quantization

3 participants