Fix FP8 block scaling with sequence parallel #2637
Conversation
Signed-off-by: Chen Cui <chcui@nvidia.com>
Greptile Summary: This PR fixes an assertion failure when Float8BlockQuantizer is used with sequence parallel and the local tensor dimensions are not divisible by 128.
How It Works: gather_along_first_dim now falls back to a high-precision all-gather (quantizing afterwards) whenever the local tensor cannot be quantized blockwise. This allows sequence parallel to work with FP8 block scaling regardless of the local tensor shape.
Confidence Score: 4/5
Sequence Diagram
sequenceDiagram
participant Module as LayerNorm/Linear Module
participant Dist as gather_along_first_dim
participant FP8Block as _start_all_gather_fp8_blockwise
participant NCCL as torch.distributed
Note over Module: Before: assert_dim_for_all_gather<br/>would fail if dims not divisible by 128
Module->>Dist: gather_along_first_dim(inp, quantizer)
Dist->>FP8Block: _start_all_gather_fp8_blockwise()
alt Input not quantizable (dims % 128 != 0)
FP8Block->>FP8Block: warnings.warn("Cannot quantize...")
alt Input already quantized
FP8Block->>FP8Block: inp = inp.dequantize()
end
FP8Block->>NCCL: all_gather_into_tensor(high-precision)
NCCL-->>FP8Block: gathered tensor
FP8Block->>FP8Block: out = quantizer(out)
FP8Block-->>Dist: return quantized output
else Input is quantizable
FP8Block->>FP8Block: Quantize input if needed
FP8Block->>NCCL: all_gather_into_tensor(FP8 blockwise)
NCCL-->>FP8Block: gathered FP8 tensor
FP8Block-->>Dist: return FP8 output
end
Dist-->>Module: return gathered tensor
Note over Module: Now works with sequence parallel<br/>even when dims not divisible by 128
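A minimal Python sketch of the high-precision fallback branch shown above (function and argument names are illustrative, not the exact Transformer Engine code; `quantizer` is assumed to be callable on a high-precision tensor):

```python
import warnings

import torch
import torch.distributed as dist


def _gather_high_precision_fallback(inp, quantizer, process_group):
    """Fallback sketch: dequantize if needed, all-gather in high precision, then quantize."""
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")

    # The input may already be a quantized local tensor (e.g. produced by a
    # fused norm + cast); bring it back to high precision before gathering.
    if hasattr(inp, "dequantize"):
        inp = inp.dequantize()

    world_size = dist.get_world_size(process_group)
    out_shape = [inp.size(0) * world_size] + list(inp.shape[1:])
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
    dist.all_gather_into_tensor(out, inp.contiguous(), group=process_group)

    # Quantize after the gather, once the full first dimension is available.
    return quantizer(out)
```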
Perform all-gather in high-precision if the input tensor is too small to quantize. Signed-off-by: Tim Moon <tmoon@nvidia.com>
2 files reviewed, no comments
timmoon10 left a comment:
I've tweaked the module logic so that it no longer checks that the local tensor can be quantized. In addition to FP8 block scaling, MXFP8 and NVFP4 will also support high-precision all-gathers. That said, this may not be fully robust for LayerNormLinear and LayerNormMLP, since the norm kernel may still expect to output to a quantized local tensor.
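A hypothetical sketch of that robustness concern (the helper and its arguments are illustrative, not the actual module code): a LayerNormLinear-style module could decide up front whether to ask the norm kernel for a quantized local output, so a fused cast is not wasted when the gather will fall back to high precision anyway.

```python
def _choose_norm_output_quantizer(local_shape, quantizer, with_sequence_parallel):
    """Return the quantizer to hand to the norm kernel, or None for high precision."""
    if not with_sequence_parallel:
        return quantizer
    # If the local chunk is not blockwise-quantizable (128-element blocks),
    # the all-gather will run in high precision anyway, so skipping the fused
    # norm + cast avoids a dequantize immediately after the norm.
    if any(dim % 128 != 0 for dim in local_shape):
        return None
    return quantizer
```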
/te-ci pytorch L1
2 files reviewed, 3 comments
out = torch.empty(out_shape, dtype=dtype, device=device)
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
    inp = inp.dequantize()  # Dequantize if needed
The dequantize() method defaults to dtype=torch.float32. Consider whether this is always appropriate for the fallback path, especially when the original tensor might have been in a different precision (e.g., bfloat16).
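A small sketch of this suggestion for the quoted fallback branch (assuming `dequantize()` accepts a `dtype` keyword, as the float32 default implies; the attribute lookup is illustrative):

```python
if isinstance(inp, QuantizedTensorStorage):
    # Preserve the tensor's nominal precision (e.g. bfloat16) rather than
    # widening to the float32 default; fall back to float32 if unknown.
    target_dtype = getattr(inp, "dtype", torch.float32)
    inp = inp.dequantize(dtype=target_dtype)
```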
):
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if isinstance(inp, QuantizedTensorStorage):
        inp = inp.dequantize()  # Dequantize if needed
Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.
):
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if isinstance(inp, QuantizedTensorStorage):
        inp = inp.dequantize()  # Dequantize if needed
Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.
Description
Problem
Using Float8BlockQuantizer with sequence parallel fails with AssertionError: All-gather requires quantizable tensor for quantizer Float8BlockQuantizer when local tensor dimensions aren't divisible by 128.
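A hedged repro sketch of the failing scenario (assumes a torchrun launch on 2+ GPUs, the Float8BlockScaling recipe name, and te.Linear's sequence-parallel arguments; details may differ from the actual setup):

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8BlockScaling

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

layer = te.Linear(
    1024,
    1024,
    params_dtype=torch.bfloat16,
    sequence_parallel=True,  # local sequence chunk is all-gathered before the GEMM
    tp_group=tp_group,
    tp_size=dist.get_world_size(),
    parallel_mode="column",
)

# 96 tokens per rank: the first dim is not divisible by 128, which previously
# tripped the assert_dim_for_all_gather check for Float8BlockQuantizer.
x = torch.randn(96, 1, 1024, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=Float8BlockScaling()):
    y = layer(x)  # with this PR, the all-gather falls back to high precision
```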
Solution
- Skip the assert_dim_for_all_gather check for Float8BlockQuantizer, since gather_along_first_dim already has a fallback path.
- Fix the fallback in _start_all_gather_fp8_blockwise to handle already-quantized inputs by dequantizing before the high-precision all-gather.
Note
The fallback path (high-precision all-gather → quantize) may increase the communication overhead.
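As a rough back-of-envelope estimate (assuming bf16 inputs): a bf16 all-gather moves 2 bytes per element, while an FP8 block-scaled all-gather moves about 1 byte per element plus roughly one scale per 128-element block, so the fallback approximately doubles the gathered volume, plus the cost of quantizing the full tensor after the gather instead of the local chunk before it.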
Verification
The code change does not alter convergence behavior

When SP is True, the previous code did not run at all (it failed the assertion); when SP is False, this change doesn't affect anything.
