[Common/PyTorch] Grouped-quantize kernels for 1D and 2D FP8 block-scaling#3135
[Common/PyTorch] Grouped-quantize kernels for 1D and 2D FP8 block-scaling#3135denera wants to merge 1 commit into
Conversation
| constexpr int kThreadsPerBlock = 256; | ||
| constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; | ||
|
|
||
| // Align a dynamic-smem pointer to 128 bytes (TMA requirement). |
There was a problem hiding this comment.
Could we reuse the existing align_smem_ptr_per_TMA_requirements() helper from transformer_engine/cast/core/common.h here?
| size_t total_row_blocks) { | ||
| using namespace transformer_engine::dispatch::mxfp8::swizzle; | ||
| const size_t num_tiles_X = | ||
| (total_row_blocks + GEMM_SWIZZLED_SCALE_TILE_DIM_X - 1) / GEMM_SWIZZLED_SCALE_TILE_DIM_X; |
There was a problem hiding this comment.
We can also reuse the existing DIVUP() helper here (defined in transformer_engin/common/common.h).
|
|
||
| // ---- Tensor-lookup helpers ---------------------------------------------------- | ||
|
|
||
| // Map a global tile-row index to its owning tensor by binary-searching |
There was a problem hiding this comment.
We can also reuse the existing get_current_tensor_id() helper defined in transformer_engine/cast/core/common.cuh
Greptile SummaryAdds fused grouped-tensor FP8 block-scaling quantize/dequantize (1D 1×128 and 2D 128×128) for the Hopper (sm_90) path. A single CUDA kernel launch walks 128×128 tiles across every tensor, resolving per-expert ownership from device-side
Confidence Score: 5/5The change is self-contained and non-breaking: new kernel paths behind new NVTE_BLOCK_SCALING_{1D,2D} dispatch cases, Hopper-only guard in both host dispatcher and kernel body, compact (non-swizzled) scale layout explicitly set. All three kernel variants have their critical paths validated by a comprehensive C++ test comparing against split-quantize reference for all direction/shape/block-dim combinations. Per-expert scale layout math, 2D CW prefix sum, and buffer slack bound are correct. No mismatches found between quantize and dequantize scale-index formulas. No files require special attention. quantize.cuh and test_grouped_tensor.py have minor non-blocking follow-up opportunities noted in the comments. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[group_quantize / bgrad_group_quantize
Python / C API] --> B{scaling_mode}
B -->|NVTE_BLOCK_SCALING_1D| C[group_quantize_blockwise_1d]
B -->|NVTE_BLOCK_SCALING_2D| D[group_quantize_blockwise_2d]
B -->|NVTE_MXFP8_1D| E[mxfp8::group_quantize]
C --> F{use_colwise or dbias?}
F -->|RW-only, no dbias| G[group_block_scaled_1d_rw_kernel
vec-16 gmem reads, no smem cache]
F -->|CW-only / BOTH / RW+dbias| H[group_block_scaled_1d_tma_kernel
TMA bulk-load to smem cache]
H --> H1[RW pass: 8t/row vec-16 from smem]
H --> H2[CW pass: 2-4t/col reg_data smem_T transpose
drain_smem_t_to_gmem]
H --> H3[Optional dbias partial]
D --> I[group_block_scaled_2d_tma_kernel
TMA bulk-load to smem cache]
I --> I1[Pass 1: tile amax stage 8 IVecs in regs]
I --> I2[Pass 2: quantize from regs
RW output + smem_T to CW output]
I --> I3[Optional dbias partial]
H2 --> J[grouped_reduce_dbias
per-expert column sum reduction]
H3 --> J
I2 --> J
I3 --> J
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
A[group_quantize / bgrad_group_quantize
Python / C API] --> B{scaling_mode}
B -->|NVTE_BLOCK_SCALING_1D| C[group_quantize_blockwise_1d]
B -->|NVTE_BLOCK_SCALING_2D| D[group_quantize_blockwise_2d]
B -->|NVTE_MXFP8_1D| E[mxfp8::group_quantize]
C --> F{use_colwise or dbias?}
F -->|RW-only, no dbias| G[group_block_scaled_1d_rw_kernel
vec-16 gmem reads, no smem cache]
F -->|CW-only / BOTH / RW+dbias| H[group_block_scaled_1d_tma_kernel
TMA bulk-load to smem cache]
H --> H1[RW pass: 8t/row vec-16 from smem]
H --> H2[CW pass: 2-4t/col reg_data smem_T transpose
drain_smem_t_to_gmem]
H --> H3[Optional dbias partial]
D --> I[group_block_scaled_2d_tma_kernel
TMA bulk-load to smem cache]
I --> I1[Pass 1: tile amax stage 8 IVecs in regs]
I --> I2[Pass 2: quantize from regs
RW output + smem_T to CW output]
I --> I3[Optional dbias partial]
H2 --> J[grouped_reduce_dbias
per-expert column sum reduction]
H3 --> J
I2 --> J
I3 --> J
Reviews (7): Last reviewed commit: "[Common/PyTorch] Grouped-quantize kernel..." | Re-trigger Greptile |
| } | ||
|
|
||
| CType amax = compute_row_amax<IType, CType, kVec>(in_vec[it]); | ||
| amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 1)); |
There was a problem hiding this comment.
Could we reuse the existing amax warp-reduction helpers (warp_reduce_max() or reduce_max()) from transformer_engine/common/utils.cuh here?
| amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 1)); | ||
| amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 2)); | ||
| amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, 4)); |
There was a problem hiding this comment.
We can also reuse reduce_max() or warp_reduce_max() here.
|
|
||
| // ----- Host-side dispatchers -------------------------------------------------------------------- | ||
|
|
||
| inline size_t align_up_to(size_t x, size_t a) { return ((x + a - 1) / a) * a; } |
There was a problem hiding this comment.
We can reuse DIVUP_TO_MULTIPLE() defined in transformer_engine/common/common.h.
| NVTE_CHECK(info.tensor_offsets_d != nullptr, | ||
| "VARYING_FIRST_DIM requires tensor_offsets to be set on the GroupedTensor."); | ||
| } | ||
| info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim; |
There was a problem hiding this comment.
| info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim; | |
| info.total_row_blocks = DIVUP(info.R_total, kTileDim); |
| "VARYING_FIRST_DIM requires tensor_offsets to be set on the GroupedTensor."); | ||
| } | ||
| info.total_row_blocks = (info.R_total + kTileDim - 1) / kTileDim; | ||
| info.blocks_X = (info.K + kTileDim - 1) / kTileDim; |
There was a problem hiding this comment.
| info.blocks_X = (info.K + kTileDim - 1) / kTileDim; | |
| info.blocks_X = DIVUP(info.K, kTileDim); |
| info.same_both_dims = same_both_dims; | ||
| info.num_tensors = output->num_tensors; | ||
| info.K = output->get_common_last_dim(); | ||
| NVTE_CHECK(info.K % 16 == 0, "Last dim must be multiple of 16 (FP8 alignment)."); |
There was a problem hiding this comment.
If this is a TMA requirement, we can use the TMA_GMEM_ALIGNMENT constant defined in transformer_engine/common/common.h
| const float* noop_ptr = | ||
| (noop != nullptr) ? reinterpret_cast<const float*>(noop->data.dptr) : nullptr; | ||
|
|
||
| const size_t scale_stride_y = align_up_to(info.blocks_X, 4); |
There was a problem hiding this comment.
| const size_t scale_stride_y = align_up_to(info.blocks_X, 4); | |
| const size_t scale_stride_y = DIVUP_TO_MULTIPLE(info.blocks_X, 4); |
| const size_t scale_stride_y = align_up_to(info.blocks_X, 4); | ||
| // CW scales are stored [blocks_X, align4(total_row_blocks)] -- transposed to | ||
| // match the physically-transposed columnwise data the TN cuBLAS GEMM consumes. | ||
| const size_t scale_t_stride_y = align_up_to(info.total_row_blocks, 4); |
There was a problem hiding this comment.
| const size_t scale_t_stride_y = align_up_to(info.total_row_blocks, 4); | |
| const size_t scale_t_stride_y = DIVUP_TO_MULTIPLE(info.total_row_blocks, 4); |
| const float* noop_ptr = | ||
| (noop != nullptr) ? reinterpret_cast<const float*>(noop->data.dptr) : nullptr; | ||
|
|
||
| const size_t scale_stride_aligned_R = align_up_to(info.R_total, 4); |
There was a problem hiding this comment.
| const size_t scale_stride_aligned_R = align_up_to(info.R_total, 4); | |
| const size_t scale_stride_aligned_R = DIVUP_TO_MULTIPLE(info.R_total, 4); |
| (noop != nullptr) ? reinterpret_cast<const float*>(noop->data.dptr) : nullptr; | ||
|
|
||
| const size_t scale_stride_aligned_R = align_up_to(info.R_total, 4); | ||
| const size_t scale_t_stride_aligned_K = align_up_to(info.K, 4); |
There was a problem hiding this comment.
| const size_t scale_t_stride_aligned_K = align_up_to(info.K, 4); | |
| const size_t scale_t_stride_aligned_K = DIVUP_TO_MULTIPLE(info.K, 4); |
- Reuse shared helpers (DIVUP, DIVUP_TO_MULTIPLE, TMA_GMEM_ALIGNMENT, align_smem_ptr_per_TMA_requirements, get_current_tensor_id, subwarp_reduce_max_broadcast) in place of local equivalents. - Add proxy-async fence after mbarrier_init in 2D + 1D TMA kernels. - Enforce per-tensor first_dim % 128 device-side for VARYING_FIRST_DIM (matches MXFP8 grouped quantize behavior). - Fix Hopper SM range wording in 1D dispatcher. - Extend cpp tests to cover with_gemm_swizzled_scales path. Signed-off-by: Alp Dener <adener@nvidia.com>
| // num_tiles_X = DIVUP(total_row_blocks, TILE_DIM_X=4) | ||
| __device__ __forceinline__ size_t swizzled_colwise_scale_idx(size_t i, size_t j, | ||
| size_t total_row_blocks) { | ||
| using namespace transformer_engine::dispatch::mxfp8::swizzle; |
There was a problem hiding this comment.
I think we should rename the namespace for swizzle...given that we use the same constants for mxfp8, nvfp4, fp8 block scaling
| if get_device_compute_capability() < (10, 0): | ||
| return False | ||
| # 4. Output quantization is not supported. | ||
| if any(q is not None for q in output_quantizers): | ||
| return False | ||
| # 5. Filter by quantization recipes. | ||
| if fp8: | ||
| if not (10, 0) <= get_device_compute_capability() <= (11, 0): | ||
| return False | ||
| return all(isinstance(q, MXFP8Quantizer) for q in input_quantizers) or all( | ||
| isinstance(q, NVFP4Quantizer) and q.with_rht for q in input_quantizers | ||
| return ( | ||
| activation_dtype in (torch.bfloat16, torch.float16) | ||
| and all(isinstance(q, MXFP8Quantizer) for q in input_quantizers) | ||
| and all(isinstance(q, MXFP8Quantizer) for q in weight_quantizers) | ||
| and all(q is None or isinstance(q, MXFP8Quantizer) for q in grad_output_quantizers) | ||
| ) | ||
| return activation_dtype in (torch.bfloat16, torch.float16) |
There was a problem hiding this comment.
Hopper BF16/FP16 fused path silently disabled
The old code used two separate capability gates: >= (9,0) for all paths and >= (10,0) for FP8 only, so Hopper (CC 9.0) was valid for BF16/FP16. The new single gate < (10,0) returns False for Hopper unconditionally. Anyone who has NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM=1 on Hopper will silently fall back to the unfused loop with no warning — the docstring that previously advertised Hopper BF16/FP16 support was removed in the same commit.
The corresponding test skip condition was tightened from not (9,0) <= cc <= (11,0) to cc < (10,0), confirming the regression is intentional. If that is the case, a NVTE_WARN or a comment explaining why Hopper is excluded (e.g., cuBLAS limitation, untested, or actively broken) would prevent confusion for users and future contributors.
| if fp8: | ||
| if not (10, 0) <= get_device_compute_capability() <= (11, 0): | ||
| return False | ||
| return all(isinstance(q, MXFP8Quantizer) for q in input_quantizers) or all( | ||
| isinstance(q, NVFP4Quantizer) and q.with_rht for q in input_quantizers | ||
| return ( | ||
| activation_dtype in (torch.bfloat16, torch.float16) | ||
| and all(isinstance(q, MXFP8Quantizer) for q in input_quantizers) | ||
| and all(isinstance(q, MXFP8Quantizer) for q in weight_quantizers) | ||
| and all(q is None or isinstance(q, MXFP8Quantizer) for q in grad_output_quantizers) | ||
| ) |
There was a problem hiding this comment.
NVFP4+RHT fused grouped GEMM path silently removed
The previous fp8 branch accepted NVFP4Quantizer with with_rht=True, and NVFP4Quantizer was explicitly imported. Both the type check and the import are removed here, so Blackwell users who opted into NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM=1 with an NVFP4 recipe will silently fall back to the unfused per-expert loop. The entire test_grouped_linear_grouped_tensor_path_skips_non_rht_nvfp4 test was deleted in the same commit, confirming the removal is deliberate.
If this drop is permanent, a brief note in the PR description or a NVTE_WARN when NVFP4 is detected at this gate would prevent surprise for existing users of that configuration.
ef0ef0a to
bff55f6
Compare
…ling
Implements grouped-tensor quantize for the FP8 1D (1x128) and 2D (128x128)
block-scaling recipes in row-wise (RW), column-wise (CW) and BOTH quantization
directions. A single CUDA kernel launch walks 128x128 tiles across every tensor
in the group, with each CTA decoding its owning tensor from the device-side
GroupedTensor metadata with (N, R, K) shapes. Supports SAME_BOTH_DIMS (all
tensors identical) and VARYING_FIRST_DIM (constant K, varying R) shape
representations.
Three kernels share the dispatcher in group_quantize_blockwise_{1d,2d}:
- group_block_scaled_1d_rw_kernel: RW-only dispatch; 8 threads/row, reads
global memory directly into vec-16 registers; bypasses TMA since the
shared-memory roundtrip and ptx::mbarrier do not buy anything without
re-use in the CW path.
- group_block_scaled_1d_tma_kernel: CW-only and BOTH dispatch. TMA bulk-load
fills shared memory input cache. BOTH runs an RW pass (8 threads/row,
vec-16 read from shared memory) then a CW pass; CW-only skips the RW
pass. The CW pass uses 4 t/col with 32-row reg_data and two column passes
in the BOTH instantiation (keeps the per-thread register footprint under
the sm_90 3-CTAs/SM threshold) and 2 t/col with 64-row reg_data in the
CW-only instantiation (avoids doubling the smem-load bank-conflict
footprint that 4 t/col would introduce).
- group_block_scaled_2d_tma_kernel: RW-only, CW-only and BOTH dispatch. TMA
bulk-load fills shared memory input cache. Pass 1 stages 8 IVecs/thread
in registers while computing the per-tile scalar amax. Pass 2 quantizes
from registers, emits row-wise output, stages column-wise output to the
shared memory transpose staging buffer, then drains smem_T to global
memory.
Per-expert scale offsets:
- 1D RW: closed-form O(1) for both SAME_BOTH_DIMS and VARYING_FIRST_DIM
(each M_i is a multiple of kTileDim=128, hence of kScaleColAlign=4, so
DIVUP_TO_MULTIPLE collapses and the prefix sum reduces to a single
tensor_offsets_ptr[tensor_id]/K load).
- 2D CW: closed-form O(1) for SAME_BOTH_DIMS; CTA-cooperative warp-shuffle
prefix sum for VARYING_FIRST_DIM (non-linear DIVUP_TO_MULTIPLE on
blocks_y_t prevents a closed form). The cooperative reduction uses the
existing warp_allreduce_sum helper from common/utils.cuh.
Dequantize and bias-gradient (bgrad):
- group_dequantize_fp8_blockwise.cuh: kernels for all four modes
(1D/2D x rowwise/columnwise), inverting the per-expert layouts the
quantize kernels write.
- bgrad_group_quantize accepts Float8Block quantizers and computes dbias
per-tile column-partial in-kernel (mirroring MXFP8); reduced per expert
via the existing common::grouped_reduce_dbias.
Scale constraints: the fused grouped FP8BS path supports only unconstrained
FP32 scales (Float8BlockQuantizer::create_grouped_tensor rejects
force_pow_2_scales=True). Power-of-2 scales remain available on the
non-grouped/unfused split-quantize path used for Blackwell MXFP8 emulation.
Tests: existing parametrized grouped quantize / dequantize / bgrad tests
in test_grouped_tensor.py cover MXFP8, NVFP4, FP8 current scaling and the
newly-added FP8 block scaling recipe. tests/cpp/operator/
test_cast_float8blockwise_grouped.cu adds 72 C++ unit-test cases over
uniform/jagged shapes, all four (BD x direction) modes, K in {128, 256,
512}, and CUDA-graph capture coverage.
Kernels are gated to Hopper (sm_90) at the host dispatcher (cuBlasLt
grouped GEMM supports FP8 block-scaling only on Hopper).
JAX integration is intentionally left out of scope and deferred to a
follow-up PR.
Resolves NVIDIA#2525
Signed-off-by: Alp Dener <adener@nvidia.com>
e153c9b to
dc998bd
Compare
Description
Implements grouped-tensor quantize for the FP8 1D (1x128) and 2D (128x128) block-scaling recipes in row-wise (RW), column-wise (CW) and BOTH quantization directions. A single CUDA kernel launch walks 128x128 tiles across every tensor in the group, with each CTA decoding its owning tensor from the device-side GroupedTensor metadata with (N, R, K) shapes. Supports
SAME_BOTH_DIMS(all tensors identical) andVARYING_FIRST_DIM(constant K, varying R) shape representations.Three kernels share the dispatcher in
group_quantize_blockwise_{1d,2d}:group_block_scaled_1d_rw_kernel— RW-only dispatch; 8 threads/row, reads global memory directly into vec-16 registers; bypasses TMA because the shared memory roundtrip andptx::mbarrierdoes not buy anything without re-use in CW path.group_block_scaled_1d_tma_kernel— CW-only and BOTH dispatch; TMA bulk-load fills shared memory input cache. BOTH runs RW pass first (8 threads/row, vec-16 read from shared memory) then CW pass (2 threads/column, 64-row register stage); CW-only skips the RW pass. CW path writes the transposed-FP8 tile to a shared memory transpose staging buffer, then drains to global memory.group_block_scaled_2d_tma_kernel— RW-only, CW-only and BOTH dispatch; TMA bulk-load fills shared memory input cache. Pass 1 stages 8 IVecs/thread in registers while computing the per-tile scalar amax. Pass 2 quantizes from registers, emits row-wise output, stages column-wise output to shared memory transpose staging buffer, then drains to global memory.Kernels are gated to Hopper (sm_90) at the host dispatcher (cuBlasLt grouped GEMM supports FP8 block-scaling only on Hopper).
PR includes PyTorch integration into te.GroupedTensor only.
PyTorch integration into te.GroupedLinear and JAX integration deferred to a follow-up PRs.
Partially resolves #2525
Performance
Benchmark on H200 with a sweep of grouped tensors in (N, M, K) shapes:
Two shape families per config:
SAME_BOTH_DIMS): all experts share the (M, K) shapeVARYING_FIRST_DIM): per-expert M drawn from an imbalanced routing, common KBuckets:
Bucket medians across 3 reps. Speedup is grouped vs the split-quantized fallback that loops over the grouped tensor and quantizes each constituent sequentially. % mono is grouped throughput relative to a single non-grouped FP8 block-scaling quantize on the equivalent monolithic (N·M, K) tensor.
Notes
Known Sub-Optimalities
1D CW load bank conflicts on ~35% of load wavefronts (reading from the shared memory input-cache)
CU_TENSOR_MAP_SWIZZLE_128Bhas the right pattern but caps FP16/BF16 at 64-elements; does not fit the 128-element tile for FP8 block-scaling without doubling per-tile launch overhead (quadrupling for FP32).1D BOTH reads the shared memory input-cache twice
2D CW/BOTH has bank conflicts on ~16% of store wavefronts (when writing to the shared memory transpose buffer)
No TMA-store
Type of change
Checklist: