4 changes: 3 additions & 1 deletion transformer_engine/pytorch/module/base.py
@@ -1135,9 +1135,11 @@ def grad_output_preprocess(
grad_output = grad_output.reshape((-1, grad_output.shape[-1]))
grad_output = grad_output.contiguous()
gather_grad_output = row_parallel_mode and ctx.sequence_parallel
keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False)
use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized

# Non-FP8 case: bgrad is fused with wgrad for this case.
if not ctx.fp8 and not ctx.debug:
if not use_fp8_bwd and not ctx.debug:
if gather_grad_output:
if not ctx.ub_overlap_ag: # Perform NCCL all-gather
grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
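For illustration, a minimal standalone sketch of the routing this hunk adds to grad_output_preprocess (the helper name and return strings are placeholders, not the Transformer Engine API):

def _grad_output_path(fp8: bool, keep_backward_unquantized: bool, debug: bool = False) -> str:
    # Mirrors the gate above: an FP8 module whose backward is kept unquantized
    # takes the non-FP8 branch, where bgrad is fused with wgrad.
    use_fp8_bwd = fp8 and not keep_backward_unquantized
    if not use_fp8_bwd and not debug:
        return "unquantized backward (bgrad fused with wgrad)"
    return "quantized or debug backward"

assert _grad_output_path(fp8=True, keep_backward_unquantized=True).startswith("unquantized")
assert _grad_output_path(fp8=True, keep_backward_unquantized=False).startswith("quantized")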
30 changes: 22 additions & 8 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -96,6 +96,10 @@ def forward(
save_original_input,
debug,
) = non_tensor_args
keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized()
if keep_backward_unquantized:
# Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used
save_original_input = True

num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms]
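As a hedged sketch of where this flag could come from, the snippet below reads an environment toggle directly; the actual lookup goes through FP8GlobalStateManager.keep_backward_unquantized(), and the exact parsing here is an assumption:

import os

def _keep_backward_unquantized_from_env() -> bool:
    # Hypothetical reader for the NVTE_KEEP_BACKWARD_UNQUANTIZED toggle mentioned
    # in the comment above; not the Transformer Engine implementation.
    return os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1"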
@@ -286,6 +290,7 @@ def forward(
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.keep_backward_unquantized = keep_backward_unquantized
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
@@ -294,7 +299,11 @@ def forward(
ctx.inp_shape = inp.shape
ctx.requires_dgrad = inp.requires_grad
ctx.reduce_and_update_bwd_fp8_tensors = False
if ctx.fp8 and requires_grad(inp, weights[0], biases[0]):
if (
ctx.fp8
and not ctx.keep_backward_unquantized
[Collaborator comment] same comment with linear.py, this seems to be delayed scaling only, can revert/ignore
and requires_grad(inp, weights[0], biases[0])
):
ctx.reduce_and_update_bwd_fp8_tensors = (
ctx.reduce_and_update_bwd_fp8_tensors
or FP8GlobalStateManager.is_first_fp8_module()
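The new condition reduces to a small predicate; a sketch with plain booleans rather than the autograd ctx:

def _schedule_bwd_amax_reduction(fp8: bool, keep_backward_unquantized: bool, any_requires_grad: bool) -> bool:
    # Backward amax reduction/update is only scheduled when the backward GEMMs
    # will actually run quantized.
    return fp8 and not keep_backward_unquantized and any_requires_grad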
@@ -318,6 +327,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
origin_weights = saved_tensors[2 * N : 3 * N]
biases = saved_tensors[3 * N : 4 * N]
main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False)
use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized

if ctx.cpu_offloading:
if ctx.grad_added_to_main_grad:
@@ -333,7 +344,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
grad_output = [None] * ctx.num_gemms
grad_biases = [None] * ctx.num_gemms
if ctx.fp8 and not ctx.debug:
if use_fp8_bwd and not ctx.debug:
if ctx.use_bias:
grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
recipe = ctx.fp8_recipe
@@ -384,7 +395,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],

if ctx.requires_dgrad:
dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
if ctx.fp8 or ctx.debug:
if use_fp8_bwd or ctx.debug:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_dgrad"):
dgrad_gemm_use_split_accumulator = (
@@ -395,13 +406,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
dtype=ctx.activation_dtype,
device=ctx.device,
)
weights_for_dgrad = weights
if keep_backward_unquantized:
weights_for_dgrad = origin_weights
# Make sure weights are available in column-wise format
# for dgrad computation.
for weight in weights:
for weight in weights_for_dgrad:
if isinstance(weight, QuantizedTensorStorage):
weight.update_usage(columnwise_usage=True)
general_grouped_gemm(
weights,
weights_for_dgrad,
grad_output,
[dgrad],
ctx.grad_input_quantizers,
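A sketch of the weight selection feeding the dgrad grouped GEMM above (plain lists stand in for the TE tensor classes):

def _pick_dgrad_weights(weights, origin_weights, keep_backward_unquantized: bool):
    # With an unquantized backward, dgrad consumes the original high-precision
    # weights; otherwise it uses the (possibly FP8) weights from the forward pass.
    return origin_weights if keep_backward_unquantized else weights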
@@ -415,7 +429,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],

if ctx.weights_requires_grad:
wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD
if ctx.fp8:
if use_fp8_bwd:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_wgrad"):
wgrad_gemm_use_split_accumulator = (
@@ -442,7 +456,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
else:
input_quantizer.set_usage(rowwise=False, columnwise=True)
inputmats: list
if ctx.fp8 and not ctx.debug:
if use_fp8_bwd and not ctx.debug:
inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers)
elif ctx.debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
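For context, a toy analogue of the per-GEMM input split used for wgrad; in the unquantized path the splits would stay in high precision instead of going through tex.split_quantize (shapes below are made up):

import torch

inp_view = torch.randn(12, 8)                         # flattened tokens x hidden
m_splits = [3, 4, 5]                                  # tokens routed to each GEMM
inputmats = list(torch.split(inp_view, m_splits, dim=0))
assert [m.shape[0] for m in inputmats] == m_splits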
@@ -516,7 +530,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad):
if not ctx.use_bias or (
ctx.wgrad_store is not None
and ctx.wgrad_store.delay_wgrad_compute()
and not ctx.fp8
and not use_fp8_bwd
):
grad_biases = [None] * ctx.num_gemms

74 changes: 54 additions & 20 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -141,6 +141,7 @@ def forward(
symmetric_ar_type,
debug,
) = non_tensor_args
keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized()

# NVTX label for profiling
nvtx_label = "transformer_engine._LayerNormLinear.forward"
@@ -200,7 +201,10 @@
if fp8:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
input_quantizer.set_usage(
rowwise=True,
columnwise=backward_needs_input and not keep_backward_unquantized,
)
if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather():
# All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(columnwise=False)
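Read as a truth table, the new usage setting is straightforward; a plain-boolean sketch (not a TE Quantizer):

def _input_usage(backward_needs_input: bool, keep_backward_unquantized: bool) -> dict:
    # Row-wise data feeds the forward GEMM; column-wise data is only kept when the
    # quantized input will be reused for the wgrad GEMM in backward.
    return {
        "rowwise": True,
        "columnwise": backward_needs_input and not keep_backward_unquantized,
    }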
@@ -213,6 +217,7 @@
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
and not keep_backward_unquantized
and not custom # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom()
)

@@ -236,6 +241,7 @@
ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
ln_out_return = ln_out
[Contributor comment on lines 241 to 243] storing both ln_out (quantized) and ln_out_hp (high precision) significantly increases memory usage; verify this memory overhead is acceptable for large models, especially during training
ln_out_hp = ln_out if keep_backward_unquantized else None
[Contributor comment] storing both ln_out (quantized) and ln_out_hp (high precision) doubles the memory footprint for this activation; verify this memory overhead is acceptable for your target models, especially during training with large batch sizes or long sequences
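A rough back-of-envelope for the overhead flagged in the comments above, assuming bf16 activations, 1-byte FP8 data, and ignoring scale/transpose metadata (numbers are illustrative only):

tokens, hidden = 8192, 4096
fp8_ln_out_bytes = tokens * hidden * 1       # ln_out kept for the forward GEMM
bf16_ln_out_hp_bytes = tokens * hidden * 2   # ln_out_hp additionally saved for backward
print(f"~{bf16_ln_out_hp_bytes / 2**20:.0f} MiB extra per layer on top of ~{fp8_ln_out_bytes / 2**20:.0f} MiB")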


# ------------------------------------------------------
# Prepare GEMM input tensor
@@ -409,13 +415,16 @@ def forward(
# ------------------------------------------------------

if is_grad_enabled:
ln_out_to_save = ln_out
if keep_backward_unquantized:
ln_out_to_save = ln_out_hp
ctx.weight_quantizer = weight_quantizer
ctx.ln_out_needs_gather = (
weight.requires_grad and parallel_mode == "column" and sequence_parallel
)

# Input with column-wise usage is needed for wgrad GEMM.
if backward_needs_input:
if backward_needs_input and not keep_backward_unquantized:
if isinstance(ln_out, QuantizedTensorStorage):
# For sequence parallel in vanilla FP8, rowwise data is
# to gather the input. For MXFP8, columnwise only data
@@ -427,7 +436,7 @@
ln_out.update_usage(rowwise_usage=False)

if cpu_offloading:
mark_activation_offload(inputmat, mu, rsigma, ln_out)
mark_activation_offload(inputmat, mu, rsigma, ln_out_to_save)

# Scatter intermediate/activation tensors saved for the backward pass
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
@@ -439,7 +448,7 @@
mu,
rsigma,
weightmat if fp8 and not is_weight_param_quantized else None,
ln_out if weight.requires_grad else None,
ln_out_to_save if weight.requires_grad else None,
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

@@ -466,7 +475,7 @@
weight,
bias,
ln_weight,
ln_out,
ln_out_to_save,
mu,
rsigma,
)
@@ -493,6 +502,7 @@
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.keep_backward_unquantized = keep_backward_unquantized
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
@@ -515,7 +525,11 @@
ctx.requires_dgrad = inp_requires_grad
ctx.normalization = normalization
ctx.reduce_and_update_bwd_fp8_tensors = False
if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
if (
ctx.fp8
and not ctx.keep_backward_unquantized
[Collaborator comment] same comment
and requires_grad(inp, ln_weight, ln_bias, weight, bias)
):
_first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
if in_fp8_activation_recompute_phase():
@@ -592,6 +606,15 @@ def backward(
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
origin_weight.main_grad = main_grad

keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False)
use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized
if keep_backward_unquantized:
[Collaborator comment, @zhongbozhu, Feb 3, 2026] this shouldn't be related? Edit: disabling user-buffer when mixing fp8 & bf16 in one layer makes sense here
# Disable Userbuffers communication for backward pass when keep_backward_unquantized is True
ctx.ub_overlap_ag = False
ctx.ub_overlap_rs_dgrad = False
ctx.ub_bulk_dgrad = False
ctx.ub_bulk_wgrad = False

# Configure Userbuffers communication (comm+GEMM overlap)
ctx.ub_obj_gradout = None
ub_obj_dgrad = None
@@ -628,7 +651,7 @@ def backward(
# Configure quantizer for grad output tensor
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.grad_output_quantizer is not None:
if ctx.grad_output_quantizer is not None and use_fp8_bwd:
[Collaborator comment] this seems redundant too if we skip quant in grad_output_preprocess
quantizer = ctx.grad_output_quantizer
quantizer.set_usage(rowwise=True, columnwise=True)
if ctx.ub_overlap_ag:
@@ -665,7 +688,7 @@ def backward(
ln_out_total_work = None
if ctx.ln_out_needs_gather:
quantizer = None
if ctx.input_quantizer is not None:
if ctx.input_quantizer is not None and use_fp8_bwd:
quantizer = ctx.input_quantizer
if quantizer.supports_only_rowwise_all_gather():
# If data is in FP8, we compute FP8 transposes manually
@@ -703,18 +726,22 @@ def backward(
# Make sure required data is available
if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(rowwise_usage=True)
if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage):
if (
use_fp8_bwd
and ctx.weight_quantizer is not None
and isinstance(weight, QuantizedTensorStorage)
):
weight.update_usage(columnwise_usage=True)

# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator = _2X_ACC_DGRAD
if ctx.fp8:
if use_fp8_bwd:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_dgrad"):
use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator

# Update grad input quantizer
if ctx.grad_input_quantizer is not None:
if ctx.grad_input_quantizer is not None and use_fp8_bwd:
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

# Output buffers for Userbuffers reduce-scatter
@@ -730,12 +757,15 @@
# dgrad GEMM
# Note: dx = dy * w
nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
weight_for_dgrad = weight
if keep_backward_unquantized:
weight_for_dgrad = origin_weight
gemm_out, *_, reduce_scatter_out = general_gemm(
weight,
weight_for_dgrad,
grad_output,
layout="NN",
grad=True,
quantization_params=ctx.grad_input_quantizer,
quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None,
out=gemm_out,
out_dtype=ctx.activation_dtype,
use_split_accumulator=use_split_accumulator,
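For reference, once quantization_params is None the unquantized dgrad path is a plain high-precision GEMM; a toy torch version with made-up shapes:

import torch

dy = torch.randn(8, 16)   # grad_output: [tokens, out_features]
w = torch.randn(16, 32)   # weight (original precision): [out_features, in_features]
dx = dy @ w               # dgrad, "NN" layout: [tokens, in_features]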
@@ -782,7 +812,11 @@ def backward(
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
if (
use_fp8_bwd
[Collaborator comment] since we already disabled ub above, this should also be redundant?
and ctx.ub_overlap_ag
and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer)
):
# UB does not support pipelined overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
@@ -794,7 +828,7 @@
dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()

# This object is separate from the ub_obj_wgrad object which is passed to the GEMM
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd)
[Collaborator comment] same

ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)

@@ -820,14 +854,14 @@
if ln_out_total_work is not None:
ln_out_total_work.wait()
ln_out_total_work = None
if ctx.fp8 or ctx.debug:
if use_fp8_bwd or ctx.debug:
if isinstance(ln_out_total, QuantizedTensorStorage):
ln_out_total.update_usage(columnwise_usage=True)
else:
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.input_quantizer(ln_out_total)

if ctx.fp8 or ctx.debug:
if use_fp8_bwd or ctx.debug:
if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(columnwise_usage=True)
else:
@@ -836,7 +870,7 @@

# Figure out whether to use split accumulator
use_split_accumulator = _2X_ACC_WGRAD
if ctx.fp8:
if use_fp8_bwd:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_wgrad"):
use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
@@ -862,15 +896,15 @@
"out_dtype": (
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
),
"quantization_params": ctx.grad_weight_quantizer,
"quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None),
"accumulate": (
accumulate_wgrad_into_param_main_grad
if not getattr(weight, "overwrite_main_grad", False)
else False
),
"layout": "NT",
"out": main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": (bias if (grad_bias is None and not ctx.fp8) else None),
"bias": (bias if (grad_bias is None and not use_fp8_bwd) else None),
"use_split_accumulator": use_split_accumulator,
"grad": True,
"ub": ub_obj_wgrad,