Add NVTE_KEEP_BACKWARD_UNQUANTIZED #2644
Conversation
Signed-off-by: Ziang Li <ziangli@umich.edu>
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
This PR adds a new environment variable, NVTE_KEEP_BACKWARD_UNQUANTIZED, which keeps the forward pass quantized while running the backward (dgrad/wgrad) computations in high precision.
Key changes:
Trade-offs:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant Env as Environment Variable
participant FP8Manager as FP8GlobalStateManager
participant Module as Linear/LayerNormLinear/MLP
participant Forward as Forward Pass
participant Backward as Backward Pass
User->>Env: Set NVTE_KEEP_BACKWARD_UNQUANTIZED=1
User->>Module: Forward pass with FP8 enabled
Module->>FP8Manager: keep_backward_unquantized()
FP8Manager->>FP8Manager: Check recipe.delayed()
alt Delayed Scaling Recipe
FP8Manager-->>Module: Return False (ignore env var)
else Other Recipes
FP8Manager-->>Module: Return True
end
alt keep_backward_unquantized=True
Module->>Forward: Quantize input to FP8
Forward->>Forward: Compute with FP8
Forward->>Forward: Save high-precision copy (ln_out_hp, act_out_hp)
Forward->>Forward: Disable columnwise quantization
Forward-->>Module: FP8 output
Module->>Backward: Start backward pass
Backward->>Backward: Use high-precision saved tensors
Backward->>Backward: Compute dgrad/wgrad without FP8 quantization
Backward->>Backward: Disable Userbuffers communication
Backward-->>Module: High-precision gradients
else keep_backward_unquantized=False
Module->>Forward: Quantize input to FP8
Forward->>Forward: Compute with FP8
Forward->>Forward: Save FP8 quantized tensors
Forward-->>Module: FP8 output
Module->>Backward: Start backward pass
Backward->>Backward: Use FP8 quantized tensors
Backward->>Backward: Compute dgrad/wgrad with FP8
Backward-->>Module: FP8 gradients
end
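A minimal Python sketch of the gating logic shown in the diagram, assuming the flag is read from the environment and that recipe.delayed() identifies delayed-scaling recipes (as in the recipe check quoted later in this review); the accepted value ("1") and the stubbed get_fp8_recipe are assumptions, not the PR's exact implementation:

```python
import os


class FP8GlobalStateManager:
    """Illustrative subset of the real manager; only the new check is sketched."""

    _fp8_recipe = None  # stub; the real manager tracks the active fp8_autocast recipe

    @classmethod
    def get_fp8_recipe(cls):
        return cls._fp8_recipe

    @classmethod
    def keep_backward_unquantized(cls) -> bool:
        # Opt in via environment variable (accepted value "1" is an assumption).
        if os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") != "1":
            return False
        # Delayed scaling relies on amax reduction/update during backward,
        # so the flag is ignored for that recipe.
        recipe = cls.get_fp8_recipe()
        if recipe is not None and recipe.delayed():
            return False
        return True
```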
6 files reviewed, no comments
6 files reviewed, no comments
I'll work on potential unit test breakage.
    FP8GlobalStateManager.is_fp8_enabled()
    and FP8GlobalStateManager.keep_backward_unquantized()
)
if func_ctx.is_first_module and not keep_backward_unquantized and not _is_graph_capturing():
the reduce_and_update_fp8_tensors stuff is for delayed scaling, in which case we can just ignore it
Signed-off-by: Ziang Li <ziangli@umich.edu>
keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False)
use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized
use_quantized_bwd = use_fp8_bwd or ctx.debug
if keep_backward_unquantized:
this shouldn't be related?
edit: disabling user-buffer when mixing fp8 & bf16 in one layer makes sense here
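A hedged sketch of what that could look like in the backward entry point, using a SimpleNamespace stand-in for the autograd ctx; the Userbuffers flags shown (ub_overlap_ag, ub_overlap_rs_dgrad) are taken from other hunks in this review, and the exact set the PR clears may differ:

```python
from types import SimpleNamespace

# Stand-in for the autograd ctx saved during forward (illustrative fields only).
ctx = SimpleNamespace(
    fp8=True,
    keep_backward_unquantized=True,
    ub_overlap_ag=True,
    ub_overlap_rs_dgrad=True,
)

keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False)
use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized
if keep_backward_unquantized:
    # Userbuffers comm/GEMM overlap assumes a uniformly quantized layer, so it
    # is disabled when the backward GEMMs run in high precision.
    ctx.ub_overlap_ag = False
    ctx.ub_overlap_rs_dgrad = False
```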
5 files reviewed, no comments
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
5 files reviewed, 4 comments
ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
    ln_out_return = ln_out
Storing both ln_out (quantized) and ln_out_hp (high precision) significantly increases memory usage.
Verify that this memory overhead is acceptable for large models, especially during training.
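A rough back-of-the-envelope illustration of that overhead, using assumed shapes and a BF16 high-precision copy (per-tensor or per-block scaling factors are not counted):

```python
# Assumed example: sequence 4096, micro-batch 2, hidden 8192.
seq, batch, hidden = 4096, 2, 8192
elems = seq * batch * hidden

fp8_bytes = elems        # ~1 byte per element for the quantized ln_out
bf16_bytes = 2 * elems   # 2 bytes per element for the extra ln_out_hp copy

print(f"quantized ln_out:     {fp8_bytes / 2**20:.0f} MiB")   # 64 MiB
print(f"additional ln_out_hp: {bf16_bytes / 2**20:.0f} MiB")  # 128 MiB per layer
```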
# Prepare GEMM input
# Note: Cast to expected dtype and perform tensor-parallel communication
Recomputing the activation (activation_func(fc1_out, None, **act_params)) adds compute overhead.
Consider whether this is necessary, or whether the quantized act_out could be reused in some cases.
Signed-off-by: Ziang Li <ziangli@umich.edu>
for more information, see https://pre-commit.ci
4 files reviewed, 1 comment
Signed-off-by: Ziang Li <ziangli@umich.edu>
5 files reviewed, 1 comment
… is used Signed-off-by: Ziang Li <ziangli@umich.edu>
5 files reviewed, 2 comments
if keep_backward_unquantized and is_grad_enabled and fc1_out is not None:
    act_out_hp = activation_func(fc1_out, None, **act_params)
Recomputing activation_func(fc1_out, None, **act_params) adds compute overhead that could be significant for expensive activations like GELU.
Since fc1_out is already available (which it is on line 620), storing act_out in high precision during the initial computation would avoid the redundant activation calculation.
This part of the logic looks a bit off to me too: the activation function has already quantized its output in the fusion, so act_out is already FP8 and we cannot save a high-precision copy of it for backward.
You're correct. Looking at the flow:
- When fc1_out_quantizer is set, activation_func returns quantized FP8 output in act_out
- At line 620, act_out is already FP8, so there's nothing high-precision to save
- The recomputation on line 621 is the only way to get the high-precision activation values
The logic is actually correct - you need the recomputation because the original forward pass discarded the high-precision values when it quantized. The variable name act_out_hp makes this clearer (hp = high precision).
However, the overhead concern from my previous comment stands. A potential optimization would be to conditionally store high-precision activations during the initial forward pass when keep_backward_unquantized=True, rather than always quantizing and then recomputing. This would trade memory for compute - but that might be acceptable since this feature is opt-in via environment variable.
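A compact sketch of the two strategies being compared, using toy stand-ins so it runs outside Transformer Engine; act_quantizer is a placeholder for whichever quantizer the fused activation actually receives in layernorm_mlp.py, and only the control flow is meant to mirror the discussion:

```python
import torch
import torch.nn.functional as F

fc1_out = torch.randn(8, 1024)
act_params = {}


def activation_func(x, quantizer, **kwargs):
    # Toy fused activation: quantizes its output when a quantizer is given.
    out = F.gelu(x)
    return quantizer(out) if quantizer is not None else out


def act_quantizer(x):
    # Placeholder "quantize for the FC2 GEMM" step; the real code produces a
    # Transformer Engine quantized tensor, not a plain float8 cast.
    return x.to(torch.float8_e4m3fn)


# Strategy in the PR: the fused activation quantizes act_out, then the
# high-precision values are recomputed from fc1_out for the backward pass.
act_out = activation_func(fc1_out, act_quantizer, **act_params)
act_out_hp = activation_func(fc1_out, None, **act_params)  # extra compute

# Alternative suggested above: compute in high precision once, then quantize
# separately for the FC2 GEMM (extra memory instead of extra compute).
act_out_hp = activation_func(fc1_out, None, **act_params)
act_out = act_quantizer(act_out_hp)
```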
ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
    ln_out_return = ln_out
ln_out_hp = ln_out if keep_backward_unquantized else None
Storing both ln_out (quantized) and ln_out_hp (high precision) doubles the memory footprint for this activation.
Verify that this memory overhead is acceptable for your target models, especially during training with large batch sizes or long sequences.
Signed-off-by: Ziang Li <ziangli@umich.edu>
4 files reviewed, no comments
5 files reviewed, no comments
Signed-off-by: Ziang Li <ziangli@umich.edu>
5 files reviewed, no comments
Signed-off-by: Ziang Li <ziangli@umich.edu>
5 files reviewed, no comments
ctx.owns_input = saved_inputmat is not inp
- if ctx.fp8 and requires_grad(inp, weight, bias):
+ if ctx.fp8 and not ctx.keep_backward_unquantized and requires_grad(inp, weight, bias):
Nit: we can remove this line since it's about delayed scaling?
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
- if ctx.grad_output_quantizer is not None:
+ if ctx.grad_output_quantizer is not None and use_fp8_bwd:
this line seems redundant since you already skip the quantization step in base.py grad_output_preprocess?
not ctx.use_bias
and not ctx.requires_wgrad
and ctx.grad_output_quantizer is not None
and use_fp8_bwd
same comment as above
- if ctx.fp8 and requires_grad(inp, weights[0], biases[0]):
+ if (
+     ctx.fp8
+     and not ctx.keep_backward_unquantized
Same comment as for linear.py: this seems to be delayed-scaling only, so it can be reverted/ignored.
recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
    # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
    return False
Maybe it's better to assert an error for delayed scaling? Okay with both.
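A sketch of the stricter variant being suggested, written as a standalone function for illustration (keep_backward_unquantized_strict and its get_fp8_recipe parameter are hypothetical names, not the PR's API):

```python
import os


def keep_backward_unquantized_strict(get_fp8_recipe) -> bool:
    """Hypothetical strict variant of the check: fail loudly on delayed scaling."""
    if os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") != "1":
        return False
    recipe = get_fp8_recipe()
    if recipe is not None and recipe.delayed():
        raise RuntimeError(
            "NVTE_KEEP_BACKWARD_UNQUANTIZED is not supported with delayed scaling"
        )
    return True
```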
- if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
+ if (
+     ctx.fp8
+     and not ctx.keep_backward_unquantized
same comment
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
- if ctx.grad_output_quantizer is not None:
+ if ctx.grad_output_quantizer is not None and use_fp8_bwd:
this seems redundant too if we skip quant in grad_output_preprocess
# make sure required data is available
- if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
+ if (
+     use_fp8_bwd
since we already disabled ub above, this should also be redundant?
# This object is separate from the ub_obj_wgrad object which is passed to the GEMM
- ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
+ ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd)
same
# or 2) doing the recomputation with checkpointing
- backwards_needs_fc1_input = fc1_weight.requires_grad and (
-     (is_grad_enabled and not checkpoint) or is_recomputation
+ backwards_needs_fc1_input = (
backwards_needs_fc1_input should be orthogonal to keep_backward_unquantized?
    inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias
if (
    ctx.fp8
    and not ctx.keep_backward_unquantized
same old delayed scaling case
keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False)
use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized
fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None
Nit: this fp8_recipe_bwd shouldn't be needed; the use_fp8_bwd flag is enough.
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
- if ctx.fc2_grad_output_quantizer is not None:
+ if ctx.fc2_grad_output_quantizer is not None and use_fp8_bwd:
Same nit: grad_output_preprocess already skips the quantization.
# Whether to set grad arg in general_gemm
grad_arg = True
- if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling():
+ if use_fp8_bwd and fp8_recipe_bwd.float8_block_scaling():
Just using ctx.fp8_recipe should be better (no strong opinion on this one).
if ctx.ub_overlap_rs_dgrad:
    # Overlap DGRAD+RS
-   ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
+   ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd)
since ub is already disabled, this part should also be redundant
Description
Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high precision wgrad & dgrad.
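A hedged usage sketch: the module sizes, the recipe choice, and the accepted env-var value are assumptions, and per the recipe check above any recipe other than delayed scaling should qualify:

```python
import os

# Opt in before running TE modules; "1" is an assumed accepted value.
os.environ["NVTE_KEEP_BACKWARD_UNQUANTIZED"] = "1"

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling

layer = te.Linear(1024, 1024, bias=True).cuda()
inp = torch.randn(8, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=MXFP8BlockScaling()):
    out = layer(inp)          # forward GEMM runs with quantized inputs/weights
out.float().sum().backward()  # dgrad/wgrad run unquantized with the flag set
```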
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: