Add NVTE_KEEP_BACKWARD_UNQUANTIZED #2644
base: main
Changes from all commits: 3afce1f, 72149be, 3e6eb64, 927d482, cc85b60, fe24f95, 5ca3615, 5ba7674, 02b7b2a, 01a7de0, bf904aa, b449fc4, de3acaf, fe65d34
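
The PR threads a `keep_backward_unquantized` flag through `_LayerNormLinear.forward` and `backward`: when FP8 is active in the forward pass and `FP8GlobalStateManager.keep_backward_unquantized()` returns True, the high-precision layernorm output is saved for the backward pass and the dgrad/wgrad GEMMs run unquantized. As a minimal sketch of how such an environment-variable switch is usually surfaced (the helper below and its parsing are assumptions for illustration, not the code added by this PR):

```python
import os

def keep_backward_unquantized_from_env() -> bool:
    """Hypothetical helper: read NVTE_KEEP_BACKWARD_UNQUANTIZED from the environment.

    The PR exposes the flag through FP8GlobalStateManager.keep_backward_unquantized();
    treating "1" as enabled here is an assumption for illustration only.
    """
    return os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1"

# Assumed usage: set the variable before launching training, e.g.
#   NVTE_KEEP_BACKWARD_UNQUANTIZED=1 python train.py
```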
```diff
@@ -141,6 +141,7 @@ def forward(
             symmetric_ar_type,
             debug,
         ) = non_tensor_args
+        keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized()

         # NVTX label for profiling
         nvtx_label = "transformer_engine._LayerNormLinear.forward"
```
```diff
@@ -200,7 +201,10 @@
         if fp8:
             if input_quantizer is None:
                 raise ValueError("Missing quantizer for input tensor")
-            input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
+            input_quantizer.set_usage(
+                rowwise=True,
+                columnwise=backward_needs_input and not keep_backward_unquantized,
+            )
             if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather():
                 # All-gather is not supported with FP8 column-wise data
                 input_quantizer.set_usage(columnwise=False)
```
```diff
@@ -213,6 +217,7 @@
             and not debug
             and not return_layernorm_output
             and not return_layernorm_output_gathered
+            and not keep_backward_unquantized
             and not custom  # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom()
         )

```
```diff
@@ -236,6 +241,7 @@
         ln_out_return = None
         if return_layernorm_output or return_layernorm_output_gathered:
             ln_out_return = ln_out
+        ln_out_hp = ln_out if keep_backward_unquantized else None

         # ------------------------------------------------------
         # Prepare GEMM input tensor
```

Contributor (on lines 241 to 243): Storing both means extra memory; verify this overhead is acceptable for your target models, especially large ones, during training with large batch sizes or long sequences.

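To make the reviewer's memory concern concrete, here is a rough back-of-the-envelope sketch; the shapes and the bf16 activation dtype are illustrative assumptions, not values from the PR:

```python
# Extra activation memory from keeping ln_out unquantized for the backward pass.
# Shapes and dtypes below are assumptions for illustration.
batch, seq, hidden = 8, 4096, 8192
elems = batch * seq * hidden

bf16_bytes = 2 * elems  # high-precision ln_out_hp saved for backward
fp8_bytes = 1 * elems   # row-wise FP8 ln_out still needed for the forward GEMM

print(f"ln_out_hp: {bf16_bytes / 2**30:.2f} GiB, FP8 ln_out: {fp8_bytes / 2**30:.2f} GiB per layer")
# -> roughly 0.50 GiB of extra bf16 data per layer for these shapes, partially offset
#    because the column-wise FP8 copy for the wgrad GEMM is no longer produced.
```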
```diff
@@ -409,13 +415,16 @@
        # ------------------------------------------------------

        if is_grad_enabled:
+            ln_out_to_save = ln_out
+            if keep_backward_unquantized:
+                ln_out_to_save = ln_out_hp
            ctx.weight_quantizer = weight_quantizer
            ctx.ln_out_needs_gather = (
                weight.requires_grad and parallel_mode == "column" and sequence_parallel
            )

            # Input with column-wise usage is needed for wgrad GEMM.
-            if backward_needs_input:
+            if backward_needs_input and not keep_backward_unquantized:
                if isinstance(ln_out, QuantizedTensorStorage):
                    # For sequence parallel in vanilla FP8, rowwise data is
                    # to gather the input. For MXFP8, columnwise only data
```
```diff
@@ -427,7 +436,7 @@
                    ln_out.update_usage(rowwise_usage=False)

            if cpu_offloading:
-                mark_activation_offload(inputmat, mu, rsigma, ln_out)
+                mark_activation_offload(inputmat, mu, rsigma, ln_out_to_save)

            # Scatter intermediate/activation tensors saved for the backward pass
            # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
```
```diff
@@ -439,7 +448,7 @@
                    mu,
                    rsigma,
                    weightmat if fp8 and not is_weight_param_quantized else None,
-                    ln_out if weight.requires_grad else None,
+                    ln_out_to_save if weight.requires_grad else None,
                )
                nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

```
```diff
@@ -466,7 +475,7 @@
                weight,
                bias,
                ln_weight,
-                ln_out,
+                ln_out_to_save,
                mu,
                rsigma,
            )
```
```diff
@@ -493,6 +502,7 @@
            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
+            ctx.keep_backward_unquantized = keep_backward_unquantized
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
```
```diff
@@ -515,7 +525,11 @@
            ctx.requires_dgrad = inp_requires_grad
            ctx.normalization = normalization
            ctx.reduce_and_update_bwd_fp8_tensors = False
-            if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
+            if (
+                ctx.fp8
+                and not ctx.keep_backward_unquantized
+                and requires_grad(inp, ln_weight, ln_bias, weight, bias)
+            ):
                _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
                ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
                if in_fp8_activation_recompute_phase():
```

Collaborator (on `and not ctx.keep_backward_unquantized`): Same comment.

```diff
@@ -592,6 +606,15 @@ def backward(
        if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
            origin_weight.main_grad = main_grad

+        keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False)
+        use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized
+        if keep_backward_unquantized:
+            # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True
+            ctx.ub_overlap_ag = False
+            ctx.ub_overlap_rs_dgrad = False
+            ctx.ub_bulk_dgrad = False
+            ctx.ub_bulk_wgrad = False
+
        # Configure Userbuffers communication (comm+GEMM overlap)
        ctx.ub_obj_gradout = None
        ub_obj_dgrad = None
```

Collaborator: This shouldn't be related? Edit: disabling Userbuffers when mixing FP8 and BF16 in one layer makes sense here.

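The `use_fp8_bwd` predicate derived here gates every quantizer-related branch below. A self-contained sketch of the pattern, using a stand-in `ctx` object (the dataclass is purely illustrative, not module code):

```python
from dataclasses import dataclass

@dataclass
class Ctx:
    """Stand-in for the autograd ctx; only the fields relevant to this pattern."""
    fp8: bool = True
    keep_backward_unquantized: bool = True
    ub_overlap_ag: bool = True
    ub_overlap_rs_dgrad: bool = True
    ub_bulk_dgrad: bool = True
    ub_bulk_wgrad: bool = True

ctx = Ctx()

# getattr() keeps this robust for ctx objects created before the flag existed.
keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False)
use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized

if keep_backward_unquantized:
    # Per the review discussion: disable Userbuffers overlap when mixing an FP8
    # forward with a high-precision backward in the same layer.
    ctx.ub_overlap_ag = False
    ctx.ub_overlap_rs_dgrad = False
    ctx.ub_bulk_dgrad = False
    ctx.ub_bulk_wgrad = False

print(use_fp8_bwd)  # False: backward stays unquantized even though forward used FP8
```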
```diff
@@ -628,7 +651,7 @@
            # Configure quantizer for grad output tensor
            # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
            # requires column-wise usage
-            if ctx.grad_output_quantizer is not None:
+            if ctx.grad_output_quantizer is not None and use_fp8_bwd:
                quantizer = ctx.grad_output_quantizer
                quantizer.set_usage(rowwise=True, columnwise=True)
                if ctx.ub_overlap_ag:
```

Collaborator: This also seems redundant if we skip quantization in grad_output_preprocess.

```diff
@@ -665,7 +688,7 @@
            ln_out_total_work = None
            if ctx.ln_out_needs_gather:
                quantizer = None
-                if ctx.input_quantizer is not None:
+                if ctx.input_quantizer is not None and use_fp8_bwd:
                    quantizer = ctx.input_quantizer
                    if quantizer.supports_only_rowwise_all_gather():
                        # If data is in FP8, we compute FP8 transposes manually
```
```diff
@@ -703,18 +726,22 @@
                # Make sure required data is available
                if isinstance(grad_output, QuantizedTensorStorage):
                    grad_output.update_usage(rowwise_usage=True)
-                if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage):
+                if (
+                    use_fp8_bwd
+                    and ctx.weight_quantizer is not None
+                    and isinstance(weight, QuantizedTensorStorage)
+                ):
                    weight.update_usage(columnwise_usage=True)

                # Choose whether to use GEMM kernel with split accumulator
                use_split_accumulator = _2X_ACC_DGRAD
-                if ctx.fp8:
+                if use_fp8_bwd:
                    recipe = ctx.fp8_recipe
                    if hasattr(recipe, "fp8_gemm_dgrad"):
                        use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator

                # Update grad input quantizer
-                if ctx.grad_input_quantizer is not None:
+                if ctx.grad_input_quantizer is not None and use_fp8_bwd:
                    ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

                # Output buffers for Userbuffers reduce-scatter
```
```diff
@@ -730,12 +757,15 @@
                # dgrad GEMM
                # Note: dx = dy * w
                nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
+                weight_for_dgrad = weight
+                if keep_backward_unquantized:
+                    weight_for_dgrad = origin_weight
                gemm_out, *_, reduce_scatter_out = general_gemm(
-                    weight,
+                    weight_for_dgrad,
                    grad_output,
                    layout="NN",
                    grad=True,
-                    quantization_params=ctx.grad_input_quantizer,
+                    quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None,
                    out=gemm_out,
                    out_dtype=ctx.activation_dtype,
                    use_split_accumulator=use_split_accumulator,
```
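With the flag enabled, dgrad is computed from the original unquantized weight (`origin_weight`) and the high-precision grad output, and wgrad later uses the saved unquantized `ln_out`. A minimal PyTorch sketch of the underlying math; the shapes, layout convention, and bf16 dtype are illustrative assumptions:

```python
import torch

# Assumed forward convention: y = x @ w.t() with w of shape [out_features, in_features].
tokens, in_features, out_features = 32, 512, 1024
x = torch.randn(tokens, in_features, dtype=torch.bfloat16)        # saved unquantized ln_out
w = torch.randn(out_features, in_features, dtype=torch.bfloat16)  # origin_weight, unquantized
dy = torch.randn(tokens, out_features, dtype=torch.bfloat16)      # grad_output

dx = dy @ w      # dgrad, matching the "dx = dy * w" note above: [tokens, in_features]
dw = dy.t() @ x  # wgrad: [out_features, in_features]

# Both GEMMs stay in bf16; no FP8 quantizers, scales, or amax updates are involved.
```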
```diff
@@ -782,7 +812,11 @@
                # Prepare grad output tensor
                # Note: Synchronize tensor-parallel communication and
                # make sure required data is available
-                if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
+                if (
+                    use_fp8_bwd
+                    and ctx.ub_overlap_ag
+                    and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer)
+                ):
                    # UB does not support pipelined overlapping grad output
                    # all-gather with wgrad GEMM. Also, we can't
                    # convert row-scaled MXFP8 to column-scaled, so we
```

Collaborator: Since we already disabled UB above, this should also be redundant?

```diff
@@ -794,7 +828,7 @@
                    dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()

                    # This object is separate from the ub_obj_wgrad object which is passed to the GEMM
-                    ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
+                    ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd)

                    ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)

```

Collaborator: Same.

```diff
@@ -820,14 +854,14 @@
                if ln_out_total_work is not None:
                    ln_out_total_work.wait()
                    ln_out_total_work = None
-                if ctx.fp8 or ctx.debug:
+                if use_fp8_bwd or ctx.debug:
                    if isinstance(ln_out_total, QuantizedTensorStorage):
                        ln_out_total.update_usage(columnwise_usage=True)
                    else:
                        ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
                        ln_out_total = ctx.input_quantizer(ln_out_total)

-                if ctx.fp8 or ctx.debug:
+                if use_fp8_bwd or ctx.debug:
                    if isinstance(grad_output, QuantizedTensorStorage):
                        grad_output.update_usage(columnwise_usage=True)
                    else:
```
```diff
@@ -836,7 +870,7 @@

                # Figure out whether to use split accumulator
                use_split_accumulator = _2X_ACC_WGRAD
-                if ctx.fp8:
+                if use_fp8_bwd:
                    recipe = ctx.fp8_recipe
                    if hasattr(recipe, "fp8_gemm_wgrad"):
                        use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
```
```diff
@@ -862,15 +896,15 @@
                    "out_dtype": (
                        main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                    ),
-                    "quantization_params": ctx.grad_weight_quantizer,
+                    "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None),
                    "accumulate": (
                        accumulate_wgrad_into_param_main_grad
                        if not getattr(weight, "overwrite_main_grad", False)
                        else False
                    ),
                    "layout": "NT",
                    "out": main_grad if ctx.fuse_wgrad_accumulation else None,
-                    "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
+                    "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None),
                    "use_split_accumulator": use_split_accumulator,
                    "grad": True,
                    "ub": ub_obj_wgrad,
```
Same comment as with linear.py: this seems to apply to delayed scaling only, can revert/ignore.