
Conversation


@cuichenx cuichenx commented Jan 31, 2026

Description

Problem

Using Float8BlockQuantizer with sequence parallel fails with AssertionError: All-gather requires quantizable tensor for quantizer Float8BlockQuantizer when local tensor dimensions aren't divisible by 128.
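
For illustration, a minimal sketch of the divisibility constraint (the helper below is hypothetical and assumes the 128-element block size mentioned above; it is not TE code):

import torch

def first_dim_quantizable(t: torch.Tensor, block: int = 128) -> bool:
    # Blockwise FP8 scaling works on whole blocks, so the dimension being
    # gathered must be a multiple of the block size to quantize the local shard.
    return t.shape[0] % block == 0

print(first_dim_quantizable(torch.empty(4096, 5120)))  # True  -> quantized all-gather is possible
print(first_dim_quantizable(torch.empty(1000, 5120)))  # False -> previously hit the assertion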

Solution

  • Skip the assert_dim_for_all_gather check for Float8BlockQuantizer, since gather_along_first_dim already has a fallback path.
  • Fix the fallback in _start_all_gather_fp8_blockwise to handle already-quantized inputs by dequantizing before the high-precision all-gather.

Note
The fallback path (high-precision all-gather → quantize) may increase communication overhead; see the rough estimate below.
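
A back-of-the-envelope comparison of that overhead (a sketch with example shapes; it ignores FP8 scale factors and assumes a bf16 high-precision input):

def allgather_payload_bytes(numel: int, bytes_per_elem: int, world_size: int) -> int:
    # Each rank receives the shards of the other (world_size - 1) ranks.
    return numel * bytes_per_elem * (world_size - 1)

local_numel = 1000 * 5120  # example local shard whose first dim is not divisible by 128
print(allgather_payload_bytes(local_numel, 1, 8))  # FP8 payload: 1 byte per element
print(allgather_payload_bytes(local_numel, 2, 8))  # bf16 fallback: roughly 2x the traffic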

Verification

The code change does not alter convergence behavior.
[screenshot]

When SP is True, the previous code did not run at all (it failed the assertion); when SP is False, this change does not affect anything.
[screenshot]

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Skip the assert_dim_for_all_gather check for Float8BlockQuantizer so the fallback in gather_along_first_dim can be used
  • Handle already-quantized inputs in the _start_all_gather_fp8_blockwise fallback by dequantizing before the high-precision all-gather

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Chen Cui <chcui@nvidia.com>

greptile-apps bot commented Jan 31, 2026

Greptile Overview

Greptile Summary

This PR fixes an AssertionError that occurred when using Float8BlockQuantizer with sequence parallel and tensor dimensions that aren't divisible by 128.

Key Changes:

  • Removed the assert_dim_for_all_gather function from utils.py, which was blocking execution when tensor dimensions weren't quantizable
  • Removed calls to assert_dim_for_all_gather from layernorm_linear.py, layernorm_mlp.py, and linear.py modules
  • Enhanced fallback logic in distributed.py for _start_all_gather_fp8_blockwise, _all_gather_nvfp4, and _all_gather_mxfp8 to handle already-quantized inputs by dequantizing before high-precision all-gather
  • Changed fallback path to use inp.dtype and inp.device after dequantization, ensuring correct dtype handling

How It Works:
When tensor dimensions aren't quantizable (not divisible by 128), the code now:

  1. Warns the user about falling back to high-precision all-gather
  2. Dequantizes the input if it's already in a quantized format
  3. Performs all-gather in high precision
  4. Re-quantizes the gathered output

This allows sequence parallel to work with Float8BlockQuantizer even when local tensor dimensions don't meet the block size requirements, though with increased communication overhead as noted in the PR description.
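
A condensed sketch of that fallback path (illustrative only; the function name, the hasattr check, and the callable quantizer are simplified stand-ins for the actual TE internals):

import warnings
import torch
import torch.distributed as dist

def gather_first_dim_with_fallback(inp, quantizer, group):
    # Step 1: warn that the quantized all-gather path cannot be used.
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    # Step 2: if the local shard is already quantized, return to high precision
    # so a plain NCCL all-gather can handle it.
    if hasattr(inp, "dequantize"):
        inp = inp.dequantize()
    # Step 3: all-gather along the first dimension in high precision.
    world_size = dist.get_world_size(group)
    out = torch.empty((inp.shape[0] * world_size, *inp.shape[1:]),
                      dtype=inp.dtype, device=inp.device)
    dist.all_gather_into_tensor(out, inp.contiguous(), group=group)
    # Step 4: re-quantize the gathered tensor (the global first dim is assumed
    # to satisfy the block-size requirement).
    return quantizer(out)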

Confidence Score: 4/5

  • Safe to merge with minor consideration for the dtype handling in fallback path
  • The changes correctly address the issue by removing blocking assertions and adding proper fallback logic. The dequantization handling is sound, and the use of inp.dtype and inp.device after dequantization is more correct than using pre-extracted values. The only minor concern is that the default dequantization dtype (float32) might differ from the original high-precision dtype in some edge cases, but this appears intentional and acceptable for the fallback path.
  • transformer_engine/pytorch/distributed.py - verify dtype handling after dequantization matches expectations

Important Files Changed

  • transformer_engine/pytorch/distributed.py — Added fallback logic for handling already-quantized inputs when dimensions aren't quantizable, with warnings and proper dequantization
  • transformer_engine/pytorch/utils.py — Removed assert_dim_for_all_gather function since fallback paths in gather_along_first_dim handle non-quantizable dimensions

Sequence Diagram

sequenceDiagram
    participant Module as LayerNorm/Linear Module
    participant Dist as gather_along_first_dim
    participant FP8Block as _start_all_gather_fp8_blockwise
    participant NCCL as torch.distributed
    
    Note over Module: Before: assert_dim_for_all_gather<br/>would fail if dims not divisible by 128
    Module->>Dist: gather_along_first_dim(inp, quantizer)
    Dist->>FP8Block: _start_all_gather_fp8_blockwise()
    
    alt Input not quantizable (dims % 128 != 0)
        FP8Block->>FP8Block: warnings.warn("Cannot quantize...")
        alt Input already quantized
            FP8Block->>FP8Block: inp = inp.dequantize()
        end
        FP8Block->>NCCL: all_gather_into_tensor(high-precision)
        NCCL-->>FP8Block: gathered tensor
        FP8Block->>FP8Block: out = quantizer(out)
        FP8Block-->>Dist: return quantized output
    else Input is quantizable
        FP8Block->>FP8Block: Quantize input if needed
        FP8Block->>NCCL: all_gather_into_tensor(FP8 blockwise)
        NCCL-->>FP8Block: gathered FP8 tensor
        FP8Block-->>Dist: return FP8 output
    end
    
    Dist-->>Module: return gathered tensor
    Note over Module: Now works with sequence parallel<br/>even when dims not divisible by 128


@cyanguwa cyanguwa requested a review from timmoon10 February 2, 2026 18:48

@timmoon10 timmoon10 self-requested a review February 2, 2026 19:38
Perform all-gather in high-precision if the input tensor is too small to quantize.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

@greptile-apps greptile-apps bot left a comment


2 files reviewed, no comments



@timmoon10 timmoon10 left a comment


I've tweaked the module logic so that it no longer checks that the local tensor can be quantized. In addition to FP8 block scaling, MXFP8 and NVFP4 will also support high-precision all-gathers. That said, this may not be fully robust for LayerNormLinear and LayerNormMLP, since the norm kernel may still expect to output to a quantized local tensor.


@timmoon10

/te-ci pytorch L1


@greptile-apps greptile-apps bot left a comment


2 files reviewed, 3 comments


out = torch.empty(out_shape, dtype=dtype, device=device)
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
    inp = inp.dequantize()  # Dequantize if needed

The dequantize() method defaults to dtype=torch.float32. Consider whether this is always appropriate for the fallback path, especially when the original tensor might have been in a different precision (e.g., bfloat16).
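
One possible way to preserve precision, sketched under the assumptions that dequantize() accepts a dtype keyword (implied by the default mentioned above) and that the original dtype is recorded before quantization:

import torch

def dequantize_for_fallback(inp, orig_dtype: torch.dtype):
    # Hypothetical helper: dequantize back to the recorded high-precision dtype
    # (e.g. bfloat16) instead of relying on the float32 default.
    if hasattr(inp, "dequantize"):
        return inp.dequantize(dtype=orig_dtype)  # assumed keyword, per the note above
    return inp.to(dtype=orig_dtype)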

):
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
    inp = inp.dequantize()  # Dequantize if needed

Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.

):
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
    inp = inp.dequantize()  # Dequantize if needed

Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.

