diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 841cdf04ca..4a2140718d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1135,9 +1135,11 @@ def grad_output_preprocess( grad_output = grad_output.reshape((-1, grad_output.shape[-1])) grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized # Non-FP8 case: bgrad is fused with wgrad for this case. - if not ctx.fp8 and not ctx.debug: + if not use_fp8_bwd and not ctx.debug: if gather_grad_output: if not ctx.ub_overlap_ag: # Perform NCCL all-gather grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c9ceb714e3..7e6773043d 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -96,6 +96,10 @@ def forward( save_original_input, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + if keep_backward_unquantized: + # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used + save_original_input = True num_gemms = len(m_splits) weights = weights_and_biases[:num_gemms] @@ -286,6 +290,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -318,6 +323,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], origin_weights = saved_tensors[2 * N : 3 * N] biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: @@ -333,7 +340,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) grad_output = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms - if ctx.fp8 and not ctx.debug: + if use_fp8_bwd and not ctx.debug: if ctx.use_bias: grad_output_mats = torch.split(grad_output_view, ctx.m_splits) recipe = ctx.fp8_recipe @@ -384,7 +391,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8 or ctx.debug: + if use_fp8_bwd or ctx.debug: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_gemm_use_split_accumulator = ( @@ -395,13 +402,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) + weights_for_dgrad = weights + if keep_backward_unquantized: + weights_for_dgrad = origin_weights # Make sure weights are available in column-wise format # for dgrad computation. 
- for weight in weights: + for weight in weights_for_dgrad: if isinstance(weight, QuantizedTensorStorage): weight.update_usage(columnwise_usage=True) general_grouped_gemm( - weights, + weights_for_dgrad, grad_output, [dgrad], ctx.grad_input_quantizers, @@ -415,7 +425,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.weights_requires_grad: wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): wgrad_gemm_use_split_accumulator = ( @@ -442,7 +452,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list - if ctx.fp8 and not ctx.debug: + if use_fp8_bwd and not ctx.debug: inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( @@ -516,7 +526,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): if not ctx.use_bias or ( ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() - and not ctx.fp8 + and not use_fp8_bwd ): grad_biases = [None] * ctx.num_gemms diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 702916696b..60c4e1d8b2 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -141,6 +141,7 @@ def forward( symmetric_ar_type, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" @@ -200,7 +201,10 @@ def forward( if fp8: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input and not keep_backward_unquantized, + ) if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data input_quantizer.set_usage(columnwise=False) @@ -213,6 +217,7 @@ def forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not keep_backward_unquantized and not custom # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom() ) @@ -236,6 +241,7 @@ def forward( ln_out_return = None if return_layernorm_output or return_layernorm_output_gathered: ln_out_return = ln_out + ln_out_hp = ln_out if keep_backward_unquantized else None # ------------------------------------------------------ # Prepare GEMM input tensor @@ -409,13 +415,16 @@ def forward( # ------------------------------------------------------ if is_grad_enabled: + ln_out_to_save = ln_out + if keep_backward_unquantized: + ln_out_to_save = ln_out_hp ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) # Input with column-wise usage is needed for wgrad GEMM. - if backward_needs_input: + if backward_needs_input and not keep_backward_unquantized: if isinstance(ln_out, QuantizedTensorStorage): # For sequence parallel in vanilla FP8, rowwise data is # to gather the input. 
For MXFP8, columnwise only data @@ -427,7 +436,7 @@ def forward( ln_out.update_usage(rowwise_usage=False) if cpu_offloading: - mark_activation_offload(inputmat, mu, rsigma, ln_out) + mark_activation_offload(inputmat, mu, rsigma, ln_out_to_save) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -439,7 +448,7 @@ def forward( mu, rsigma, weightmat if fp8 and not is_weight_param_quantized else None, - ln_out if weight.requires_grad else None, + ln_out_to_save if weight.requires_grad else None, ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") @@ -466,7 +475,7 @@ def forward( weight, bias, ln_weight, - ln_out, + ln_out_to_save, mu, rsigma, ) @@ -493,6 +502,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -515,7 +525,11 @@ def forward( ctx.requires_dgrad = inp_requires_grad ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False - if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): + if ( + ctx.fp8 + and not ctx.keep_backward_unquantized + and requires_grad(inp, ln_weight, ln_bias, weight, bias) + ): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): @@ -592,6 +606,15 @@ def backward( if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: origin_weight.main_grad = main_grad + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + if keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -628,7 +651,7 @@ def backward( # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None: + if ctx.grad_output_quantizer is not None and use_fp8_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -665,7 +688,7 @@ def backward( ln_out_total_work = None if ctx.ln_out_needs_gather: quantizer = None - if ctx.input_quantizer is not None: + if ctx.input_quantizer is not None and use_fp8_bwd: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -703,18 +726,22 @@ def backward( # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage): + if ( + use_fp8_bwd + and ctx.weight_quantizer is not None + and isinstance(weight, QuantizedTensorStorage) + ): weight.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, 
"fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None: + if ctx.grad_input_quantizer is not None and use_fp8_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -730,12 +757,15 @@ def backward( # dgrad GEMM # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + weight_for_dgrad = weight + if keep_backward_unquantized: + weight_for_dgrad = origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( - weight, + weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer, + quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -820,14 +850,14 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if ctx.fp8 or ctx.debug: + if use_fp8_bwd or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - if ctx.fp8 or ctx.debug: + if use_fp8_bwd or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -836,7 +866,7 @@ def backward( # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -862,7 +892,7 @@ def backward( "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.grad_weight_quantizer, + "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -870,7 +900,7 @@ def backward( ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not ctx.fp8) else None), + "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index bec6744518..44028aebcc 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -232,6 +232,7 @@ def _forward( debug, recompute_for_bwd, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: @@ -394,6 +395,7 @@ def _forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not keep_backward_unquantized and not custom ) @@ -415,6 +417,7 @@ def _forward( # do not return layernorm output unless 1) no checkpointing or 2) checkpointing but not recomputing if (return_layernorm_output or return_layernorm_output_gathered) and not is_recomputation: ln_out_return = ln_out + ln_out_hp = ln_out if keep_backward_unquantized else None # Prepare GEMM input # Note: Cast to expected dtype and perform tensor-parallel 
communication @@ -611,6 +614,10 @@ def _forward( if fc2_input_quantizer is not None: fc2_input_quantizer.calibrate(act_out) + act_out_hp = act_out + if keep_backward_unquantized and is_grad_enabled and fc1_out is not None: + act_out_hp = activation_func(fc1_out, None, **act_params) + # we want to skip fc2 computation if we are checkpointing and recomputing, # otherwise we compute fc2 if not (is_recomputation and checkpoint): @@ -686,22 +693,33 @@ def _forward( # if we are not checkpointing, then we must save this if grad is enabled if is_grad_enabled and not save_for_checkpoint: + ln_out_to_save = ln_out + act_out_to_save = act_out + if keep_backward_unquantized: + ln_out_to_save = ln_out_hp + act_out_to_save = act_out_hp ctx.fc1_weight_quantizer = fc1_weight_quantizer ctx.fc2_weight_quantizer = fc2_weight_quantizer if not fc1_weight.requires_grad: if not return_layernorm_output: - clear_tensor_data(ln_out) - ln_out = None + clear_tensor_data(ln_out_to_save) + ln_out_to_save = None if not fc2_weight.requires_grad: - clear_tensor_data(act_out) - act_out = None + clear_tensor_data(act_out_to_save) + act_out_to_save = None if not checkpoint: # regular path, no selective activation checkpointing if cpu_offloading: mark_activation_offload( - inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out + inputmat, + mu, + rsigma, + ln_out_to_save, + fc1_out, + fc1_out_without_bias, + act_out_to_save, ) # Scatter intermediate/activation tensors saved for the backward pass @@ -714,9 +732,9 @@ def _forward( fsdp_group, mu, rsigma, - ln_out, + ln_out_to_save, fc1_out_without_bias if bias_gelu_fusion else fc1_out, - act_out, + act_out_to_save, ( fc1_weight_final if fp8 and not isinstance(fc1_weight, Float8Tensor) @@ -744,13 +762,13 @@ def _forward( tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, - ln_out, + ln_out_to_save, fc1_weight_final, fc1_weight, fc1_bias, fc1_out, fc1_out_without_bias, - act_out, + act_out_to_save, fc2_weight_final, fc2_weight, fc2_bias, @@ -798,6 +816,7 @@ def _forward( ctx.activation_params = activation_params ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -996,6 +1015,15 @@ def backward( origin_fc1_weight.main_grad = fc1_weight_main_grad origin_fc2_weight.main_grad = fc2_weight_main_grad + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + if keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # TODO: Fix this # pylint: disable=fixme # Gather saved autograd context tensors when running with FSDP # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -1015,7 +1043,7 @@ def backward( # Choose whether to use GEMM kernel with split accumulator dgrad_use_split_accumulator = _2X_ACC_DGRAD wgrad_use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator @@ -1029,7 +1057,7 @@ def backward( # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires 
row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.fc2_grad_output_quantizer is not None: + if ctx.fc2_grad_output_quantizer is not None and use_fp8_bwd: quantizer = ctx.fc2_grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -1057,7 +1085,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if ctx.fp8 or ctx.debug: + if use_fp8_bwd or ctx.debug: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -1103,7 +1131,7 @@ def backward( # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( - not ctx.fp8 + not use_fp8_bwd and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) and (not ctx.debug) @@ -1112,20 +1140,25 @@ def backward( # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.fc2_weight_quantizer is not None and isinstance( - ctx.fc2_weight, QuantizedTensorStorage + if ( + use_fp8_bwd + and ctx.fc2_weight_quantizer is not None + and isinstance(ctx.fc2_weight, QuantizedTensorStorage) ): ctx.fc2_weight.update_usage(columnwise_usage=True) # Perform GEMM + fc2_weight_for_dgrad = fc2_weight + if keep_backward_unquantized: + fc2_weight_for_dgrad = origin_fc2_weight gemm_output, *_ = general_gemm( - fc2_weight, + fc2_weight_for_dgrad, grad_output, layout="NN", grad=True, quantization_params=( ctx.fc1_grad_input_quantizer - if fc2_dgrad_gemm_gelu_fusion or ctx.debug + if (fc2_dgrad_gemm_gelu_fusion or ctx.debug) and use_fp8_bwd else None ), # high precision to activation out_dtype=ctx.activation_dtype, @@ -1193,14 +1226,14 @@ def backward( # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.fp8 or ctx.debug: + if use_fp8_bwd or ctx.debug: if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - if ctx.fp8 or ctx.debug: + if use_fp8_bwd or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -1209,7 +1242,7 @@ def backward( # Whether to set grad arg in general_gemm grad_arg = True - if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling(): + if use_fp8_bwd and ctx.fp8_recipe.float8_block_scaling(): grad_arg = False # Arguments to include in wgrad GEMM closure @@ -1219,7 +1252,9 @@ def backward( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision + "quantization_params": ( + ctx.fc2_grad_weight_quantizer if use_fp8_bwd else None + ), # wgrad in high precision "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc1_weight, "overwrite_main_grad", False) @@ -1256,7 +1291,7 @@ def fc2_wgrad_gemm( # Update grad bias if needed if fc2_bias_grad is None: if ( - ctx.fp8 + use_fp8_bwd and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None ): @@ -1277,12 +1312,12 @@ def fc2_wgrad_gemm( act_params = ctx.activation_params or {} fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False - if ctx.fc1_grad_output_quantizer is not None: + if 
ctx.fc1_grad_output_quantizer is not None and use_fp8_bwd: ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) if ctx.bias_gelu_fusion: # Fusion: gemm, bias + gelu assert ctx.activation == "gelu" - assert not ctx.fp8 + assert not use_fp8_bwd fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) if ctx.fc1_grad_output_quantizer is not None: dact = ctx.fc1_grad_output_quantizer(dact) @@ -1293,7 +1328,7 @@ def fc2_wgrad_gemm( dact = ctx.fc1_grad_output_quantizer(dact) elif ( _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None - and ctx.fp8 + and use_fp8_bwd ): # Fusion: gemm, bias + gelu + quantize dbias_dact_quantize_func = _act_func( @@ -1315,7 +1350,7 @@ def fc2_wgrad_gemm( fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params ) # activation in high precision - if ctx.fp8: + if use_fp8_bwd: # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now if ( isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer) @@ -1364,8 +1399,10 @@ def fc2_wgrad_gemm( # -------------------------------------------------- # Make sure required data is available - if ctx.fc1_weight_quantizer is not None and isinstance( - ctx.fc1_weight_quantizer, QuantizedTensorStorage + if ( + use_fp8_bwd + and ctx.fc1_weight_quantizer is not None + and isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) ): ctx.fc1_weight.update_usage(columnwise_usage=True) @@ -1380,12 +1417,13 @@ def fc2_wgrad_gemm( gemm_out = ub_obj_fc1_wgrad.get_buffer(local_chunk=False) # dgrad GEMM + fc1_weight_for_dgrad = fc1_weight if use_fp8_bwd else origin_fc1_weight gemm_out, *_, reduce_scatter_out = general_gemm( - fc1_weight, + fc1_weight_for_dgrad, dact, out=gemm_out, out_dtype=ctx.activation_dtype, - quantization_params=ctx.fc1_grad_input_quantizer, + quantization_params=ctx.fc1_grad_input_quantizer if use_fp8_bwd else None, layout="NN", grad=True, use_split_accumulator=dgrad_use_split_accumulator, @@ -1434,7 +1472,7 @@ def fc2_wgrad_gemm( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if ctx.fp8 or ctx.debug: + if use_fp8_bwd or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: @@ -1444,7 +1482,7 @@ def fc2_wgrad_gemm( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.fp8 or ctx.debug: + if use_fp8_bwd or ctx.debug: if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: @@ -1466,7 +1504,7 @@ def fc2_wgrad_gemm( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.fc1_grad_weight_quantizer, + "quantization_params": (ctx.fc1_grad_weight_quantizer if use_fp8_bwd else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc2_weight, "overwrite_main_grad", False) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 23ad8cacb0..535d2e75e5 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -129,6 +129,10 @@ def forward( save_original_input, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + if keep_backward_unquantized: + # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used + save_original_input = True # NVTX label for profiling 
nvtx_label = "transformer_engine._Linear.forward" @@ -443,6 +447,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.input_quantizer = input_quantizer ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_weight_quantizer = grad_weight_quantizer @@ -536,6 +541,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) nvtx_range_pop(f"{nvtx_label}.fsdp_gather") + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + if keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -575,7 +589,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None: + if ctx.grad_output_quantizer is not None and use_fp8_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -594,6 +608,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], not ctx.use_bias and not ctx.requires_wgrad and ctx.grad_output_quantizer is not None + and use_fp8_bwd ): ctx.grad_output_quantizer.set_usage(columnwise=False) @@ -623,7 +638,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - if ctx.fp8 or ctx.debug: + if use_fp8_bwd or ctx.debug: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass @@ -649,7 +664,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat = cast_if_needed(inputmat, ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if ctx.fp8 or ctx.debug: + if use_fp8_bwd or ctx.debug: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -690,20 +705,22 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance( - weight_fp8, QuantizedTensorStorage + if ( + use_fp8_bwd + and ctx.weight_quantizer is not None + and isinstance(weight_fp8, QuantizedTensorStorage) ): weight_fp8.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None: + if ctx.grad_input_quantizer is not None and use_fp8_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -720,12 +737,15 @@ def backward(ctx, grad_output: torch.Tensor) -> 
Tuple[Union[torch.Tensor, None], # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + weight_for_dgrad = weight_fp8 + if keep_backward_unquantized: + weight_for_dgrad = weight gemm_out, *_, reduce_scatter_out = general_gemm( - weight_fp8, + weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer, + quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -774,7 +794,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if ctx.fp8 or ctx.debug: + if use_fp8_bwd or ctx.debug: if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -816,7 +836,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if ctx.fp8 or ctx.debug: + if use_fp8_bwd or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -825,7 +845,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -851,7 +871,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.grad_weight_quantizer, + "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -859,7 +879,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not ctx.fp8) else None), + "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index e640f3ffb1..a9a6895112 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -332,12 +332,14 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: # Note: We cache the quantized input for backward pass, # but discard the quantized weights. 
weight_requires_grad = requires_grad and self.weight.requires_grad + keep_backward_unquantized = FP8GlobalStateManager.keep_backward_unquantized() + columnwise_usage = weight_requires_grad and not keep_backward_unquantized input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) grad_output_quantizer = self.get_quantizer("backward", 0) - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) weight_quantizer.set_usage(rowwise=True, columnwise=False) - grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + grad_output_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) @@ -420,6 +422,7 @@ def _functional_forward( tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, + keep_backward_unquantized: bool = False, input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -459,6 +462,8 @@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = `False` Whether to perform compute with quantized data. + keep_backward_unquantized: bool, default = `False` + Whether to skip quantized backward and use high precision. input_quantizer: Quantizer, optional Builder class for quantized input tensor. weight_quantizer: Quantizer, optional @@ -510,7 +515,10 @@ def _functional_forward( if with_quantized_compute: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage( + rowwise=True, + columnwise=weight_requires_grad and not keep_backward_unquantized, + ) if with_x_all_gather: input_quantizer.set_usage(columnwise=False) x, x_async = gather_along_first_dim( @@ -542,7 +550,10 @@ def _functional_forward( elif with_quantized_compute and not is_quantized_tensor(w): if weight_quantizer is None: raise ValueError("Missing quantizer for weight tensor") - weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + weight_quantizer.set_usage( + rowwise=True, + columnwise=input_requires_grad and not keep_backward_unquantized, + ) w = weight_quantizer(w) # Check output tensor @@ -611,14 +622,23 @@ def _functional_forward( # Prepare weight tensor for backward pass if input_requires_grad: - if w is not weight and with_quantized_compute and is_quantized_tensor(w): + if ( + w is not weight + and with_quantized_compute + and is_quantized_tensor(w) + and not keep_backward_unquantized + ): w.update_usage(rowwise_usage=False, columnwise_usage=True) else: w = None # Prepare input tensor for backward pass if weight_requires_grad: - if with_quantized_compute and is_quantized_tensor(x_local): + if ( + with_quantized_compute + and is_quantized_tensor(x_local) + and not keep_backward_unquantized + ): if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) @@ -968,6 +988,9 @@ def op_forward( grad_output_quantizer = self.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + 
keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -984,6 +1007,7 @@ def op_forward( tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -993,10 +1017,16 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = self.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - ctx.save_for_backward(x_local, w) - ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + ctx.save_for_backward(saved_input, saved_weight) + ctx.with_quantized_compute = with_quantized_compute and not keep_backward_unquantized ctx.input_quantizer = input_quantizer ctx.weight_quantizer = weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index d126b554b5..cc26022d0e 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -57,7 +57,11 @@ def op_forward( # Check if FP8 is enabled fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() quantize_forward = fp8_enabled and self._quantize_forward - quantize_backward = fp8_enabled and self._quantize_backward + quantize_backward = ( + fp8_enabled + and self._quantize_backward + and not FP8GlobalStateManager.keep_backward_unquantized() + ) # Quantize if needed out = input_ diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index dfc11a19e7..2458d4d072 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -92,6 +92,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -109,6 +112,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -118,10 +122,17 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = x_local + saved_weight = w + if keep_backward_unquantized: + saved_input = input_ if input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + 
mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 2dfc0566b7..efa543e555 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -86,6 +86,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -106,6 +109,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -115,10 +119,17 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = x_local + saved_weight = w + if keep_backward_unquantized: + saved_input = input_ if input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index ae4bdd4b19..2804534968 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -65,6 +65,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get extra input tensor for add operation extra_input = basic_op_extra_inputs[2][0] @@ -87,6 +90,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -96,10 +100,17 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = x_local + saved_weight = w + if keep_backward_unquantized: + saved_input = input_ if 
input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index eba547afb0..aab7ed2d1c 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -430,6 +430,15 @@ def with_high_precision_init_val(cls) -> bool: """Should the high precision initial values be stored with FP8 parameters""" return cls.HIGH_PRECISION_INIT_VAL + @classmethod + def keep_backward_unquantized(cls) -> bool: + """Should backward skip FP8 quantization and use high precision""" + recipe = cls.get_fp8_recipe() + if recipe is not None and recipe.delayed(): + # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used + return False + return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) + @classmethod def fp8_graph_capturing(cls) -> bool: """Is CUDA graph capture under way?"""
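
Usage note (not part of the patch): the sketch below is one way to exercise the new NVTE_KEEP_BACKWARD_UNQUANTIZED path end to end. It is a minimal example that assumes the public te.fp8_autocast / te.Linear entry points and the Float8CurrentScaling recipe; the layer sizes and tensor shapes are illustrative only. A non-delayed-scaling recipe is needed, because the FP8GlobalStateManager.keep_backward_unquantized() helper added above returns False whenever the active recipe uses delayed scaling.

    # Minimal sketch: quantized forward GEMM, high-precision backward GEMMs.
    # Assumes Float8CurrentScaling is available; delayed scaling ignores the flag.
    import os

    # Must be set before the forward pass, since the forward reads the flag
    # to decide which tensors to save for backward.
    os.environ["NVTE_KEEP_BACKWARD_UNQUANTIZED"] = "1"

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Float8CurrentScaling

    linear = te.Linear(1024, 1024, bias=True, device="cuda")
    inp = torch.randn(32, 1024, device="cuda", requires_grad=True)

    with te.fp8_autocast(enabled=True, fp8_recipe=Float8CurrentScaling()):
        out = linear(inp)  # forward GEMM still runs quantized

    out.sum().backward()   # dgrad/wgrad GEMMs run in high precision

As the patch shows, enabling the flag also forces save_original_input=True for Linear and GroupedLinear (so the unquantized input, not its FP8 cast, is saved for the wgrad GEMM) and disables Userbuffers comm+GEMM overlap in the backward pass.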