From 5175aad7172294244848dcc26a1b0bcc339c9671 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 7 Jan 2026 00:15:10 +0000 Subject: [PATCH 01/45] Naive implementation of grouped linear op Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 142 ++++++ .../pytorch/ops/basic/__init__.py | 1 + .../pytorch/ops/basic/grouped_linear.py | 450 ++++++++++++++++++ 3 files changed, 593 insertions(+) create mode 100644 transformer_engine/pytorch/ops/basic/grouped_linear.py diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index ce15dd1421..d2c84403c4 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -7,6 +7,7 @@ from collections.abc import Iterable import io import math +import random from typing import Optional import pytest @@ -1924,6 +1925,147 @@ def test_dropout( abs(z_score) < 2.5758 ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})" + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("input_requires_grad", (False, True)) + @pytest.mark.parametrize("weight_requires_grad", (False, True)) + def test_grouped_linear( + self, + *, + group_size: int = 4, + bias: bool, + weight_shape: tuple[int, int] = (32, 32), + split_alignment: int = 32, + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + quantization: Optional[str] = None, + quantized_compute: bool = False, + quantized_weight: bool = False, + input_requires_grad: bool, + weight_requires_grad: bool, + ) -> None: + """Grouped GEMM""" + + # Split sizes + split_sizes = [split_alignment * i for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device="cpu") + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = (split_sizes.sum().item(), in_features) + out_shape = (in_shape[0], out_features) + + # Skip invalid configurations + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + maybe_skip_quantization(quantization, dims=out_shape) + if quantization is None and (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not specified") + if quantization is not None and not (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not used") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=input_requires_grad, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + ws_ref, ws_test = [], [] + bs_ref, bs_test = [], [] + for _ in range(group_size): + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=weight_requires_grad, + ) + b_ref, b_test = None, None + if bias: + b_ref, b_test = make_reference_and_test_tensors( + out_features, + test_dtype=dtype, + test_device=device, + requires_grad=weight_requires_grad, + ) + ws_ref.append(w_ref) + ws_test.append(w_test) + bs_ref.append(b_ref) + bs_test.append(b_test) + + # Plain PyTorch implementation + xs_ref = torch.split(x_ref, split_sizes.tolist()) + ys_ref = [] + for x, w, b in zip(xs_ref, ws_ref, bs_ref): + ys_ref.append(torch.nn.functional.linear(x, w, bias=b)) + y_ref = torch.cat(ys_ref) + if input_requires_grad or 
weight_requires_grad: + y_ref.backward(dy_ref) + + # Construct fusible operation + recipe = make_recipe(quantization) + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): + op = te_ops.GroupedLinear( + group_size, + in_features, + out_features, + bias=bias, + device=device, + dtype=dtype, + ) + with torch.no_grad(): + for group_idx in range(group_size): + getattr(op, f"weight{group_idx}").copy_(ws_test[group_idx]) + if bias: + getattr(op, f"bias{group_idx}").copy_(bs_test[group_idx]) + del ws_test, bs_test + for param in op.parameters(): + param.requires_grad_(requires_grad=weight_requires_grad) + + # Forward and backward pass with op + with te.autocast(enabled=quantized_compute, recipe=recipe): + y_test = op(x_test, split_sizes) + if input_requires_grad or weight_requires_grad: + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantized_compute: + tols = quantization_tols(quantization) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + if input_requires_grad: + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + else: + assert x_test.grad is None + for group_idx in range(group_size): + w_test = getattr(op, f"weight{group_idx}") + if weight_requires_grad: + dw_test = w_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dw_test, ws_ref[group_idx].grad, **tols) + else: + assert w_test.grad is None + if bias: + b_test = getattr(op, f"bias{group_idx}") + if weight_requires_grad: + db_test = b_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db_test, bs_ref[group_idx].grad, **tols) + else: + assert b_test.grad is None + class TestFusedOps: """Tests for fused operations""" diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 665ffe359c..a74f02e3a0 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -24,6 +24,7 @@ from .bias import Bias from .constant_scale import ConstantScale from .dropout import Dropout +from .grouped_linear import GroupedLinear from .identity import Identity from .l2normalization import L2Normalization from .layer_norm import LayerNorm diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py new file mode 100644 index 0000000000..e03710189f --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -0,0 +1,450 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
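# Usage sketch (mirrors test_grouped_linear in tests/pytorch/test_fusible_ops.py;
# the sizes below are illustrative, not defaults):
#   op = GroupedLinear(group_size=4, in_features=128, out_features=128, bias=True)
#   split_sizes = torch.tensor([s0, s1, s2, s3], dtype=torch.int, device="cpu")
#   y = op(x, split_sizes)  # x is split along its first dim, one chunk per group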
+ +"""Fusible operation for bias.""" + +from __future__ import annotations +from collections.abc import Iterable +import contextlib +import math +from typing import Any, Optional + +import torch + +import transformer_engine_torch as tex +from ...module.base import get_dummy_wgrad +from ...quantization import FP8GlobalStateManager +from ...tensor import Quantizer +from ...utils import ( + canonicalize_device, + canonicalize_dtype, + clear_tensor_data, + devices_match, +) +from .._common import is_quantized_tensor +from ..op import BasicOperation, OperationContext + + +class GroupedLinear(BasicOperation): + + # Operation expects input split sizes + num_extra_inputs: int = 1 + + def __init__( + self, + group_size: int, + in_features: int, + out_features: int, + *, + bias: bool = True, + device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype] = None, + rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, + accumulate_into_main_grad: bool = False, + ) -> None: + super().__init__() + + # Weight tensor dimensions + self.group_size: int = group_size + self.in_features: int = in_features + self.out_features: int = out_features + if self.group_size <= 0: + raise ValueError(f"Invalid group size ({self.group_size})") + if self.in_features <= 0: + raise ValueError(f"Invalid input size ({self.in_features})") + if self.out_features <= 0: + raise ValueError(f"Invalid output size ({self.out_features})") + + # Weight tensor attributes + device = canonicalize_device(device) + dtype = canonicalize_dtype(dtype) + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") + + # Initialize recipe state if needed for natively quantized weight + self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters() + if self._with_quantized_weight: + self.reset_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe()) + + # RNG state tracker + self._rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] + self._rng_state_tracker_function = rng_state_tracker_function + + # Register weights + self.weight0: torch.nn.Parameter + for group_idx in range(self.group_size): + weight_tensor = torch.empty( + self.out_features, + self.in_features, + device=device, + dtype=dtype, + ) + self.register_parameter( + f"weight{group_idx}", + torch.nn.Parameter(weight_tensor), + ) + + # Register biases + self.bias0: Optional[torch.nn.Parameter] + for group_idx in range(self.group_size): + bias_tensor = None + if bias: + bias_tensor = torch.empty( + self.out_features, + device=device, + dtype=dtype, + ) + bias_tensor = torch.nn.Parameter(bias_tensor) + self.register_parameter(f"bias{group_idx}", bias_tensor) + + # Initialize weights if needed + if device.type != "meta": + self.reset_parameters() + + # Whether to accumulate weight gradient into main_grad + self._accumulate_into_main_grad: bool = accumulate_into_main_grad + + def num_quantizers(self, mode: str) -> int: + if mode == "forward": + return 2 * self.group_size + if mode == "backward": + return self.group_size + return 0 + + @property + def has_bias(self) -> bool: + return self.bias0 is not None + + @torch.no_grad + def reset_parameters(self) -> None: + """Initialize parameter buffers and values""" + + for group_idx in range(self.group_size): + + # Parameters + weight = getattr(self, f"weight{group_idx}") + bias = getattr(self, f"bias{group_idx}") + + # Parameter device + device = weight.device + if device.type 
== "meta": + device = canonicalize_device(None) + + # Allocate buffers if needed + if is_quantized_tensor(weight): + weight = torch.empty( + weight.size(), + dtype=weight.dtype, + device=device, + ) + elif not devices_match(weight.device, device): + weight = torch.empty_like(weight, device=device) + if bias is not None and not devices_match(bias.device, device): + bias = torch.empty_like(bias, device=device) + + # Initialize values + init_context = contextlib.nullcontext() + if self._rng_state_tracker_function is not None: + init_context = self._rng_state_tracker_function().fork() + with init_context: + torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) + if bias is not None: + bias.zero_() + + # Quantize weight if needed + if self._with_quantized_weight: + quantizer = self.get_quantizer("forward", 1) + if quantizer is None: + raise RuntimeError( + "Tried to quantize weight with deferred initialization " + "due to meta device, but no quantizer was available. " + "This is most likely because the weight was initialized " + "within quantized_model_init, but the forward pass was not " + "performed within autocast." + ) + quantizer.set_usage( + rowwise=True, + columnwise=torch.is_grad_enabled(), + ) + quantizer.internal = False + with torch.no_grad(): + weight = quantizer(weight) + + # Save updated parameters + if not isinstance(weight, torch.nn.Parameter): + weight = torch.nn.Parameter(weight) + setattr(self, f"weight{group_idx}", weight) + if bias is not None: + if not isinstance(bias, torch.nn.Parameter): + bias = torch.nn.Parameter(bias) + setattr(self, f"bias{group_idx}", bias) + + def pre_first_fuser_forward(self) -> None: + super().pre_first_fuser_forward() + + # Initialize params if needed + if any(param.device.type == "meta" for param in self.parameters()): + self.reset_parameters() + + # Check that weights are consistent + dtype = self.weight0.dtype + device = self.weight0.device + weight_requires_grad = self.weight0.requires_grad + weight_tensor_type = type(self.weight0.data) + for group_idx in range(self.group_size): + weight = getattr(self, f"weight{group_idx}") + if weight.dtype != dtype: + raise RuntimeError( + f"Weight {group_idx} has invalid dtype " + f"(expected {dtype}, got {weight.dtype})." + ) + if not devices_match(weight.device, device): + raise RuntimeError( + f"Weight {group_idx} has invalid device " + f"(expected {device}, got {weight.device})." + ) + if weight.requires_grad != weight_requires_grad: + raise RuntimeError( + f"Weight {group_idx} has requires_grad={weight.requires_grad}, " + f"but expected requires_grad={weight_requires_grad}." + ) + if type(weight.data) != weight_tensor_type: + raise RuntimeError( + f"Weight {group_idx} has invalid tensor type " + f"(expected {weight_tensor_type.__name__}, " + f"got {type(weight.data).__name__})." + ) + + # Check that biases are consistent + for group_idx in range(self.group_size): + bias = getattr(self, f"bias{group_idx}") + if self.has_bias: + if bias is None: + raise RuntimeError( + f"Expected biases, but bias {group_idx} is uninitialized" + ) + if bias.dtype != dtype: + raise RuntimeError( + f"Bias {group_idx} has invalid dtype " + f"(expected {dtype}, got {bias.dtype})." + ) + if not devices_match(bias.device, device): + raise RuntimeError( + f"Bias {group_idx} has invalid device " + f"(expected {device}, got {bias.device})." 
+ ) + if bias.requires_grad != weight_requires_grad: + raise RuntimeError( + f"Bias {group_idx} has requires_grad={bias.requires_grad}, " + f"but expected requires_grad={weight_requires_grad}." + ) + else: + if bias is not None: + raise RuntimeError( + f"Expected no biases, but bias {group_idx} is initialized" + ) + + def pre_fuser_forward(self, *, requires_grad: bool) -> None: + super().pre_fuser_forward(requires_grad=requires_grad) + if FP8GlobalStateManager.is_fp8_enabled(): + # Assume weights have consistent grad requirement + weight_requires_grad = requires_grad and self.weight0.requires_grad + + # Configure quantizer usages + # Note: We cache the quantized input for backward pass, + # but discard the quantized weights. + for group_idx in range(self.group_size): + input_quantizer = self.get_quantizer("forward", 2 * group_idx) + weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) + grad_output_quantizer = self.get_quantizer("backward", group_idx) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + weight_quantizer.set_usage(rowwise=True, columnwise=False) + grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + + def op_forward(self, *args, **kwargs): + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs): + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_backward` instead of `op_backward`." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + + # Check which grads are required + ctx = basic_op_ctxs[0] + input_requires_grad = ctx.requires_grad + weight_requires_grad = ctx.requires_grad and self.weight0.requires_grad + + # Quantizers + input_quantizers = None + weight_quantizers = None + grad_output_quantizers = None + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + if with_quantized_compute: + input_quantizers = [] + weight_quantizers = [] + grad_output_quantizers = [] + for group_idx in range(self.group_size): + input_quantizers.append(self.get_quantizer("forward", 2 * group_idx)) + weight_quantizers.append(self.get_quantizer("forward", 2 * group_idx + 1)) + grad_output_quantizers.append(self.get_quantizer("backward", group_idx)) + + # Get autocast dtype if needed + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = self.weight0.dtype + + # Extract split sizes from extra input + # TODO Support splits on GPU + split_sizes = basic_op_extra_inputs[0][0] + split_sizes_int = [int(s) for s in split_sizes.tolist()] + if len(split_sizes_int) != self.group_size: + raise ValueError( + f"Expected {self.group_size} splits, but got {len(split_sizes_int)}." 
+ ) + + # Extract params + weights = [] + biases = [] + for group_idx in range(self.group_size): + weights.append(getattr(self, f"weight{group_idx}")) + biases.append(getattr(self, f"bias{group_idx}")) + + # Perform GEMMs + # TODO: Fused impl, quantization + xs = torch.split(input_, split_sizes_int) + ys = [] + for x, w, b in zip(xs, weights, biases): + y = torch.nn.functional.linear(x, w, bias=b) + ys.append(y) + out = torch.cat(ys) + + # Save state for backward pass + if ctx.requires_grad: + ctx.save_for_backward(split_sizes, *xs, *weights) + ctx.with_quantized_compute = with_quantized_compute + ctx.input_quantizers = input_quantizers + ctx.weight_quantizers = weight_quantizers + ctx.grad_output_quantizers = grad_output_quantizers + ctx.grad_input_quantizers = None + ctx.dtype = dtype + ctx.input_requires_grad = input_requires_grad + ctx.weight_requires_grad = weight_requires_grad + + return out, [()] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + group_size = self.group_size + has_bias = self.has_bias + + # Saved tensors from forward pass + ctx = basic_op_ctxs[0] + saved_tensors = ctx.saved_tensors + split_sizes, saved_tensors = saved_tensors[0], saved_tensors[1:] + xs, saved_tensors = saved_tensors[:group_size], saved_tensors[group_size:] + weights, saved_tensors = saved_tensors[:group_size], saved_tensors[group_size:] + + # Split grad output tensor + # TODO Support splits on GPU + split_sizes_int = [int(s) for s in split_sizes.tolist()] + dys = torch.split(grad_output, split_sizes_int) + + # Megatron-LM wgrad fusion + # Note: Get grad tensors from params so we can accumulate + # directly into it. 
+ accumulate_into_main_grad = self._accumulate_into_main_grad + grad_weights = [None] * group_size + if ctx.weight_requires_grad and accumulate_into_main_grad: + for group_idx in range(group_size): + weight_param = getattr(self, f"weight{group_idx}") + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) + if not hasattr(weight_param, "main_grad"): + raise RuntimeError( + "GroupLinear op is configured with " + "accumulate_into_main_grad=True, " + "but weight parameter does not have main_grad attribute" + ) + grad_weights[group_idx] = weight_param.main_grad.detach() + else: + accumulate_into_main_grad = False + + # Compute grad biases + # TODO: Fuse with quantization + grad_biases = [None] * group_size + if ctx.weight_requires_grad and has_bias: + for group_idx in range(group_size): + dy = dys[group_idx] + grad_biases[group_idx] = dy.reshape(-1, dy.size(-1)).sum(0) + + # Perform GEMMs + # TODO: Fused impl, quantization + grad_input = None + if ctx.input_requires_grad: + dxs = [] + for group_idx in range(group_size): + dy_shape = list(dys[group_idx].size()) + dx = torch.matmul( + dys[group_idx].reshape(-1, dy_shape[-1]), + weights[group_idx], + ) + dxs.append(dx.reshape(dy_shape[:-1] + [dx.size(-1)])) + grad_input = torch.cat(dxs) + if ctx.weight_requires_grad: + for group_idx in range(group_size): + grad_weights[group_idx] = torch.matmul( + dys[group_idx].reshape(-1, dys[group_idx].size(-1)).T, + xs[group_idx].reshape(-1, xs[group_idx].size(-1)), + out=grad_weights[group_idx], + ) + + # Clear input tensors if possible + clear_tensor_data(*xs) + + # Megatron-LM wgrad fusion + # Note: Return dummy tensor for grad weight if needed. 
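# The wgrad results were already written into each param's main_grad above,
# so hand autograd a cheap placeholder with the right shape/dtype and set
# grad_added_to_main_grad so downstream code knows where the real gradient lives.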
+ if accumulate_into_main_grad: + grad_weights = [None] * group_size + for group_idx in range(group_size): + weight_param = getattr(self, f"weight{group_idx}") + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weights[group_idx] = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + + grad_params = grad_weights + grad_biases if has_bias else grad_weights + return grad_input, [grad_params], [(None,)] From 5ffd57e74e16835fc3b1039bb66d6c63fdbadab2 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 7 Jan 2026 02:09:25 +0000 Subject: [PATCH 02/45] Use grouped GEMM tex functions Signed-off-by: Tim Moon --- .../pytorch/ops/basic/grouped_linear.py | 149 +++++++++++------- 1 file changed, 95 insertions(+), 54 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index e03710189f..325db168ce 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -13,7 +13,13 @@ import torch import transformer_engine_torch as tex -from ...module.base import get_dummy_wgrad +from ...cpp_extensions import general_grouped_gemm +from ...module.base import ( + _2X_ACC_FPROP, + _2X_ACC_DGRAD, + _2X_ACC_WGRAD, + get_dummy_wgrad, +) from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ...utils import ( @@ -288,6 +294,8 @@ def fuser_forward( next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + group_size = self.group_size + has_bias = self.has_bias # Check which grads are required ctx = basic_op_ctxs[0] @@ -303,7 +311,7 @@ def fuser_forward( input_quantizers = [] weight_quantizers = [] grad_output_quantizers = [] - for group_idx in range(self.group_size): + for group_idx in range(group_size): input_quantizers.append(self.get_quantizer("forward", 2 * group_idx)) weight_quantizers.append(self.get_quantizer("forward", 2 * group_idx + 1)) grad_output_quantizers.append(self.get_quantizer("backward", group_idx)) @@ -318,26 +326,40 @@ def fuser_forward( # TODO Support splits on GPU split_sizes = basic_op_extra_inputs[0][0] split_sizes_int = [int(s) for s in split_sizes.tolist()] - if len(split_sizes_int) != self.group_size: + if len(split_sizes_int) != group_size: raise ValueError( - f"Expected {self.group_size} splits, but got {len(split_sizes_int)}." + f"Expected {group_size} splits, but got {len(split_sizes_int)}." 
) # Extract params weights = [] - biases = [] - for group_idx in range(self.group_size): + biases = [] if has_bias else None + for group_idx in range(group_size): weights.append(getattr(self, f"weight{group_idx}")) - biases.append(getattr(self, f"bias{group_idx}")) + if has_bias: + biases.append(getattr(self, f"bias{group_idx}")) - # Perform GEMMs - # TODO: Fused impl, quantization + # Split input tensor xs = torch.split(input_, split_sizes_int) - ys = [] - for x, w, b in zip(xs, weights, biases): - y = torch.nn.functional.linear(x, w, bias=b) - ys.append(y) - out = torch.cat(ys) + + # Allocate output tensor + in_shape = list(input_.size()) + out_shape = in_shape[:-1] + [self.out_features] + out = torch.empty(out_shape, dtype=dtype, device=input_.device) + + # Perform GEMMs + general_grouped_gemm( + weights, + xs, + [out], + [None] * group_size, # quantization_params + dtype, + m_splits=split_sizes_int, + bias=biases, + use_bias=has_bias, + use_split_accumulator=_2X_ACC_FPROP, + single_output=True, + ) # Save state for backward pass if ctx.requires_grad: @@ -379,55 +401,74 @@ def fuser_backward( split_sizes_int = [int(s) for s in split_sizes.tolist()] dys = torch.split(grad_output, split_sizes_int) - # Megatron-LM wgrad fusion - # Note: Get grad tensors from params so we can accumulate - # directly into it. accumulate_into_main_grad = self._accumulate_into_main_grad grad_weights = [None] * group_size - if ctx.weight_requires_grad and accumulate_into_main_grad: - for group_idx in range(group_size): - weight_param = getattr(self, f"weight{group_idx}") - if hasattr(weight_param, "__fsdp_param__"): - weight_param.main_grad = weight_param.get_main_grad() - accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) - if not hasattr(weight_param, "main_grad"): - raise RuntimeError( - "GroupLinear op is configured with " - "accumulate_into_main_grad=True, " - "but weight parameter does not have main_grad attribute" + if ctx.weight_requires_grad: + if accumulate_into_main_grad: + # Megatron-LM wgrad fusion + # Note: Get grad tensors from params so we can + # accumulate directly into it. 
+ for group_idx in range(group_size): + weight_param = getattr(self, f"weight{group_idx}") + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) + if not hasattr(weight_param, "main_grad"): + raise RuntimeError( + "GroupLinear op is configured with " + "accumulate_into_main_grad=True, " + "but weight parameter does not have main_grad attribute" + ) + else: + weight_shape = weights[0].size() + device = weights[0].device + for group_idx in range(group_size): + grad_weights[group_idx] = torch.empty( + weight_shape, + dtype=ctx.dtype, + device=device, ) - grad_weights[group_idx] = weight_param.main_grad.detach() else: accumulate_into_main_grad = False - # Compute grad biases - # TODO: Fuse with quantization - grad_biases = [None] * group_size - if ctx.weight_requires_grad and has_bias: - for group_idx in range(group_size): - dy = dys[group_idx] - grad_biases[group_idx] = dy.reshape(-1, dy.size(-1)).sum(0) - - # Perform GEMMs - # TODO: Fused impl, quantization + # Perform dgrad GEMMs grad_input = None if ctx.input_requires_grad: - dxs = [] - for group_idx in range(group_size): - dy_shape = list(dys[group_idx].size()) - dx = torch.matmul( - dys[group_idx].reshape(-1, dy_shape[-1]), - weights[group_idx], - ) - dxs.append(dx.reshape(dy_shape[:-1] + [dx.size(-1)])) - grad_input = torch.cat(dxs) + out_shape = list(grad_output.size()) + in_shape = out_shape[:-1] + [self.in_features] + grad_input = torch.empty( + in_shape, + dtype=ctx.dtype, + device=grad_output.device, + ) + general_grouped_gemm( + weights, + dys, + [grad_input], + [None] * group_size, # quantization_params + ctx.dtype, + layout="NN", + m_splits=split_sizes_int, + use_split_accumulator=_2X_ACC_DGRAD, + single_output=True, + ) + + # Perform wgrad GEMMs + grad_biases = [None] * group_size if ctx.weight_requires_grad: - for group_idx in range(group_size): - grad_weights[group_idx] = torch.matmul( - dys[group_idx].reshape(-1, dys[group_idx].size(-1)).T, - xs[group_idx].reshape(-1, xs[group_idx].size(-1)), - out=grad_weights[group_idx], - ) + _, grad_biases, _ = general_grouped_gemm( + xs, + dys, + grad_weights, + [None] * group_size, # quantization_params + ctx.dtype, + layout="NT", + m_splits=split_sizes_int, + grad=True, + use_bias=has_bias, + use_split_accumulator=_2X_ACC_WGRAD, + accumulate=accumulate_into_main_grad, + ) # Clear input tensors if possible clear_tensor_data(*xs) From 2ee42da1bdc4cfcd326a9e55e785828b0eb19e79 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 8 Jan 2026 05:50:36 +0000 Subject: [PATCH 03/45] Support quantized compute Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 18 +- .../pytorch/ops/basic/grouped_linear.py | 156 ++++++++++++++---- 2 files changed, 136 insertions(+), 38 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index d2c84403c4..e7af692098 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1926,6 +1926,10 @@ def test_dropout( ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})" @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantized_compute", (False, True)) + @pytest.mark.parametrize("quantized_weight", (False, True)) @pytest.mark.parametrize("input_requires_grad", (False, True)) 
@pytest.mark.parametrize("weight_requires_grad", (False, True)) def test_grouped_linear( @@ -1933,13 +1937,13 @@ def test_grouped_linear( *, group_size: int = 4, bias: bool, - weight_shape: tuple[int, int] = (32, 32), - split_alignment: int = 32, - dtype: torch.dtype = torch.float32, + weight_shape: tuple[int, int] = (128, 128), + split_alignment: int = 128, + dtype: torch.dtype, device: torch.device = "cuda", - quantization: Optional[str] = None, - quantized_compute: bool = False, - quantized_weight: bool = False, + quantization: Optional[str], + quantized_compute: bool, + quantized_weight: bool, input_requires_grad: bool, weight_requires_grad: bool, ) -> None: @@ -1962,6 +1966,8 @@ def test_grouped_linear( pytest.skip("Quantization scheme is not specified") if quantization is not None and not (quantized_compute or quantized_weight): pytest.skip("Quantization scheme is not used") + if quantization is not None and dtype not in (torch.bfloat16, torch.float16): + pytest.skip("Quantized group GEMM is only supported with BF16/FP16") # Random data x_ref, x_test = make_reference_and_test_tensors( diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 325db168ce..3851b6e3ef 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -28,7 +28,7 @@ clear_tensor_data, devices_match, ) -from .._common import is_quantized_tensor +from .._common import is_quantized_tensor, maybe_dequantize from ..op import BasicOperation, OperationContext @@ -268,6 +268,49 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: weight_quantizer.set_usage(rowwise=True, columnwise=False) grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: + super().reset_recipe_state(recipe=recipe) + + for group_idx in range(self.group_size): + # Input/grad output quantizers use internal tensors + input_quantizer = self.get_quantizer("forward", 2 * group_idx) + grad_output_quantizer = self.get_quantizer("backward", group_idx) + if input_quantizer is not None: + input_quantizer.internal = True + if grad_output_quantizer is not None: + grad_output_quantizer.internal = True + + # Handle weight quantizer + # Note: This function may be called in base class constructor, + # before any basic linear attrs have been set. + weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) + if weight_quantizer is None: + pass + elif is_quantized_tensor(getattr(self, "weight", None)): + # Make sure weight param has correct quantizer + weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) + weight_quantizer.internal = False + self.weight.update_quantizer(weight_quantizer.copy()) + else: + # Use internal tensors if quantized weights will not be + # exposed externally + weight_quantizer.internal = ( + not FP8GlobalStateManager.with_fp8_parameters() + and not getattr(self, "_with_quantized_weight", False) + ) + + # Recipe-specific configuration + # Note: This function may be called in base class constructor, + # before any basic linear attrs have been set. 
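# For FP8 current-scaling recipes, copy the per-tensor scaling options
# (power-of-two scale forcing and amax epsilon) from the recipe onto the
# input, weight, and grad-output quantizers.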
+ if recipe is not None: + if recipe.float8_current_scaling(): + input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon + grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon + def op_forward(self, *args, **kwargs): raise RuntimeError( "{self.__class__.__name__} operation has " @@ -303,18 +346,15 @@ def fuser_forward( weight_requires_grad = ctx.requires_grad and self.weight0.requires_grad # Quantizers - input_quantizers = None - weight_quantizers = None - grad_output_quantizers = None + input_quantizers = [None] * group_size + weight_quantizers = [None] * group_size + grad_output_quantizers = [None] * group_size with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() if with_quantized_compute: - input_quantizers = [] - weight_quantizers = [] - grad_output_quantizers = [] for group_idx in range(group_size): - input_quantizers.append(self.get_quantizer("forward", 2 * group_idx)) - weight_quantizers.append(self.get_quantizer("forward", 2 * group_idx + 1)) - grad_output_quantizers.append(self.get_quantizer("backward", group_idx)) + input_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx) + weight_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx + 1) + grad_output_quantizers[group_idx] = self.get_quantizer("backward", group_idx) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -332,15 +372,34 @@ def fuser_forward( ) # Extract params - weights = [] - biases = [] if has_bias else None - for group_idx in range(group_size): - weights.append(getattr(self, f"weight{group_idx}")) - if has_bias: - biases.append(getattr(self, f"bias{group_idx}")) - - # Split input tensor - xs = torch.split(input_, split_sizes_int) + weights = [getattr(self, f"weight{idx}") for idx in range(group_size)] + bs = None + if has_bias: + bs = [ + maybe_dequantize(getattr(self, f"bias{idx}"), dtype) + for idx in range(group_size) + ] + + # Convert weight dtype if needed + ws = [] + for w, quantizer in zip(weights, weight_quantizers): + if not with_quantized_compute: + w = maybe_dequantize(w, dtype) + elif with_quantized_compute and not is_quantized_tensor(w): + quantizer = weight_quantizers[group_idx] + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + w = quantizer(w) + ws.append(w) + + # Split input tensor and convert dtypes if needed + x = maybe_dequantize(input_, dtype) + xs = None + if with_quantized_compute: + for quantizer in input_quantizers: + quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + xs = tex.split_quantize(x, split_sizes_int, input_quantizers) + else: + xs = torch.split(x, split_sizes_int) # Allocate output tensor in_shape = list(input_.size()) @@ -349,21 +408,36 @@ def fuser_forward( # Perform GEMMs general_grouped_gemm( - weights, + ws, xs, [out], [None] * group_size, # quantization_params dtype, m_splits=split_sizes_int, - bias=biases, + bias=bs, use_bias=has_bias, use_split_accumulator=_2X_ACC_FPROP, single_output=True, ) + # Prepare weight tensors for backward pass + if not input_requires_grad: + ws = [None] * group_size + elif with_quantized_compute: + for w, weight_param in zip(ws, weights): + if w is not weight_param: + 
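# This weight was quantized on the fly for this forward pass (it is not the
# stored parameter), so drop its row-wise data and keep only the column-wise
# usage needed by the dgrad GEMM in the backward pass.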
w.update_usage(rowwise_usage=False, columnwise_usage=True) + + # Prepare input tensor for backward pass + if not weight_requires_grad: + xs = [None] * group_size + elif with_quantized_compute: + for x in xs: + x.update_usage(rowwise_usage=False, columnwise_usage=True) + # Save state for backward pass if ctx.requires_grad: - ctx.save_for_backward(split_sizes, *xs, *weights) + ctx.save_for_backward(split_sizes, *xs, *ws) ctx.with_quantized_compute = with_quantized_compute ctx.input_quantizers = input_quantizers ctx.weight_quantizers = weight_quantizers @@ -394,13 +468,34 @@ def fuser_backward( saved_tensors = ctx.saved_tensors split_sizes, saved_tensors = saved_tensors[0], saved_tensors[1:] xs, saved_tensors = saved_tensors[:group_size], saved_tensors[group_size:] - weights, saved_tensors = saved_tensors[:group_size], saved_tensors[group_size:] + ws, saved_tensors = saved_tensors[:group_size], saved_tensors[group_size:] - # Split grad output tensor + # Split grad output tensor and convert dtypes if needed # TODO Support splits on GPU split_sizes_int = [int(s) for s in split_sizes.tolist()] - dys = torch.split(grad_output, split_sizes_int) + dy = maybe_dequantize(grad_output, ctx.dtype) + dys = None + grad_biases = [None] * group_size + if ctx.with_quantized_compute: + for quantizer in ctx.grad_output_quantizers: + quantizer.set_usage( + rowwise=ctx.input_requires_grad, + columnwise=ctx.weight_requires_grad, + ) + dys = tex.split_quantize(dy, split_sizes_int, ctx.grad_output_quantizers) + if has_bias: + grad_biases = [ + dy.reshape(-1, dy.size(-1)).sum(dim=0) + for dy in torch.split(grad_output, split_sizes_int) + ] + else: + dys = torch.split(grad_output, split_sizes_int) + if has_bias: + grad_biases = [ + dy.reshape(-1, dy.size(-1)).sum(dim=0) for dy in dys + ] + # Initialize grad weight grads accumulate_into_main_grad = self._accumulate_into_main_grad grad_weights = [None] * group_size if ctx.weight_requires_grad: @@ -420,8 +515,8 @@ def fuser_backward( "but weight parameter does not have main_grad attribute" ) else: - weight_shape = weights[0].size() - device = weights[0].device + weight_shape = ws[0].size() + device = grad_output.device for group_idx in range(group_size): grad_weights[group_idx] = torch.empty( weight_shape, @@ -442,7 +537,7 @@ def fuser_backward( device=grad_output.device, ) general_grouped_gemm( - weights, + ws, dys, [grad_input], [None] * group_size, # quantization_params @@ -454,9 +549,8 @@ def fuser_backward( ) # Perform wgrad GEMMs - grad_biases = [None] * group_size if ctx.weight_requires_grad: - _, grad_biases, _ = general_grouped_gemm( + general_grouped_gemm( xs, dys, grad_weights, @@ -464,8 +558,6 @@ def fuser_backward( ctx.dtype, layout="NT", m_splits=split_sizes_int, - grad=True, - use_bias=has_bias, use_split_accumulator=_2X_ACC_WGRAD, accumulate=accumulate_into_main_grad, ) From 93e71df5dbe073a17ec46a1c08557beb3dbf9d92 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 8 Jan 2026 06:06:12 +0000 Subject: [PATCH 04/45] Debug test failures with MXFP8 or NVFP4 params Signed-off-by: Tim Moon --- .../pytorch/ops/basic/grouped_linear.py | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 3851b6e3ef..1c71e4de73 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -122,20 +122,17 @@ def num_quantizers(self, mode: str) -> int: def 
has_bias(self) -> bool: return self.bias0 is not None - @torch.no_grad def reset_parameters(self) -> None: """Initialize parameter buffers and values""" - for group_idx in range(self.group_size): + # Parameter device + device = self.weight0.device + if device.type == "meta": + device = canonicalize_device(None) - # Parameters + # Initialize weights + for group_idx in range(self.group_size): weight = getattr(self, f"weight{group_idx}") - bias = getattr(self, f"bias{group_idx}") - - # Parameter device - device = weight.device - if device.type == "meta": - device = canonicalize_device(None) # Allocate buffers if needed if is_quantized_tensor(weight): @@ -146,8 +143,6 @@ def reset_parameters(self) -> None: ) elif not devices_match(weight.device, device): weight = torch.empty_like(weight, device=device) - if bias is not None and not devices_match(bias.device, device): - bias = torch.empty_like(bias, device=device) # Initialize values init_context = contextlib.nullcontext() @@ -155,12 +150,10 @@ def reset_parameters(self) -> None: init_context = self._rng_state_tracker_function().fork() with init_context: torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) - if bias is not None: - bias.zero_() # Quantize weight if needed if self._with_quantized_weight: - quantizer = self.get_quantizer("forward", 1) + quantizer = self.get_quantizer("forward", 2 * group_idx + 1) if quantizer is None: raise RuntimeError( "Tried to quantize weight with deferred initialization " @@ -181,10 +174,18 @@ def reset_parameters(self) -> None: if not isinstance(weight, torch.nn.Parameter): weight = torch.nn.Parameter(weight) setattr(self, f"weight{group_idx}", weight) - if bias is not None: - if not isinstance(bias, torch.nn.Parameter): - bias = torch.nn.Parameter(bias) - setattr(self, f"bias{group_idx}", bias) + + # Initialize biases if needed + if self.bias0 is not None: + with torch.no_grad(): + for group_idx in range(self.group_size): + bias = getattr(self, f"bias{group_idx}") + if not devices_match(bias.device, device): + bias = torch.empty_like(bias, device=device) + bias.zero_() + if not isinstance(bias, torch.nn.Parameter): + bias = torch.nn.Parameter(bias) + setattr(self, f"bias{group_idx}", bias) def pre_first_fuser_forward(self) -> None: super().pre_first_fuser_forward() From fdddc479482a91a20e74feb47368046a3c3f1725 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 10 Jan 2026 00:12:52 +0000 Subject: [PATCH 05/45] Add multiply op Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 64 ++++++++ .../pytorch/ops/basic/__init__.py | 1 + .../pytorch/ops/basic/multiply_extra_input.py | 152 ++++++++++++++++++ 3 files changed, 217 insertions(+) create mode 100644 transformer_engine/pytorch/ops/basic/multiply_extra_input.py diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index e7af692098..dfb92c6863 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2072,6 +2072,70 @@ def test_grouped_linear( else: assert b_test.grad is None + @pytest.mark.parametrize( + "input_shape,extra_input_shape", + ( + ((3,4,5), (3,4,5)), + ((6,7), ()), + ((), (8,9)), + ((10,11,12), (11,1)), + ((1,15), (13,14,15)), + ) + ) + @pytest.mark.parametrize("input_requires_grad", (False, True)) + @pytest.mark.parametrize("extra_input_requires_grad", (False, True)) + def test_multiply_extra_input( + self, + *, + input_shape: Iterable[int], + extra_input_shape: Iterable[int], + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + input_requires_grad: 
bool, + extra_input_requires_grad: bool, + ) -> None: + """Multiply two tensors""" + + # Random data + x1_ref, x1_test = make_reference_and_test_tensors( + input_shape, + test_dtype=dtype, + test_device=device, + requires_grad=input_requires_grad, + ) + x2_ref, x2_test = make_reference_and_test_tensors( + extra_input_shape, + test_dtype=dtype, + test_device=device, + requires_grad=extra_input_requires_grad, + ) + + # Plain PyTorch implementation + y_ref = x1_ref * x2_ref + if input_requires_grad or extra_input_requires_grad: + torch.square(y_ref).sum().backward() + + # Implementation with fusible operation + op = te_ops.MultiplyExtraInput() + y_test = op(x1_test, x2_test) + if input_requires_grad or extra_input_requires_grad: + torch.square(y_test).sum().backward() + + # Check results + tols = dtype_tols(dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + if input_requires_grad: + dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dx1_test, x1_ref.grad, **tols) + else: + assert x1_test.grad is None + if extra_input_requires_grad: + dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dx2_test, x2_ref.grad, **tols) + else: + assert x2_test.grad is None + class TestFusedOps: """Tests for fused operations""" diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index a74f02e3a0..c119682151 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -29,6 +29,7 @@ from .l2normalization import L2Normalization from .layer_norm import LayerNorm from .make_extra_output import MakeExtraOutput +from .multiply_extra_input import MultiplyExtraInput from .quantize import Quantize from .reduce_scatter import ReduceScatter from .reshape import Reshape diff --git a/transformer_engine/pytorch/ops/basic/multiply_extra_input.py b/transformer_engine/pytorch/ops/basic/multiply_extra_input.py new file mode 100644 index 0000000000..b4a763bde1 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/multiply_extra_input.py @@ -0,0 +1,152 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for multiplying with extra input tensor.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +from ..op import BasicOperation, OperationContext +from .._common import maybe_dequantize + + +def _reduce_broadcast_dims( + x: torch.Tensor, + target_shape: Iterable[int], +) -> torch.Tensor: + """Reduce a tensor down to a target shape. + + The input tensor shape and target shape are assumed to be + broadcast-compatible. In other words, a tensor with the target + shape can be broadcast to match the input tensor shape. + + """ + shape = tuple(x.size()) + target_shape = tuple(target_shape) + + # Return immediately if tensor already has correct shape + if shape == target_shape: + return x + + # Determine reduction dimensions + reduce_dims = [] + if len(shape) < len(target_shape): + raise ValueError( + "Invalid target shape " + f"(shape={shape} cannot be broadcast to shape={target_shape})." 
+ ) + elif len(shape) > len(target_shape): + reduce_dims.extend(range(len(shape) - len(target_shape))) + for idx in range(-len(target_shape), 0): + if shape[idx] == target_shape[idx]: + pass + elif target_shape[idx] != 1: + raise ValueError( + "Invalid target shape " + f"(shape={shape} cannot be broadcast to shape={target_shape})." + ) + else: + reduce_dims.append(idx) + + # Perform reduction + return x.sum(reduce_dims).reshape(target_shape) + + +class MultiplyExtraInput(BasicOperation): + """Multiply with extra input tensor. + + If the tensor shapes do not match, they will follow NumPy + broadcasting semantics. + + """ + + # Operation expects extra input tensor + num_extra_inputs: int = 1 + + def op_forward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_backward` instead of `op_backward`." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + extra_input = basic_op_extra_inputs[0][0] + + # Determine compute dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + elif isinstance(input_, torch.Tensor): + dtype = input_.dtype + else: + dtype = extra_input.dtype + + # Perform multiplication + x1 = maybe_dequantize(input_, dtype) + x2 = maybe_dequantize(extra_input, dtype) + output = input_ * extra_input + + # Save state for backward pass + ctx = basic_op_ctxs[0] + if ctx.requires_grad: + ctx.input_shape = x1.size() + ctx.extra_input_shape = extra_input.size() + ctx.input_requires_grad = True + if isinstance(input_, torch.Tensor): + ctx.input_requires_grad = input_.requires_grad + ctx.extra_input_requires_grad = extra_input.requires_grad + ctx.save_for_backward( + x1 if ctx.extra_input_requires_grad else None, + x2 if ctx.input_requires_grad else None, + ) + + return output, [()] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + ctx = basic_op_ctxs[0] + input_, extra_input = ctx.saved_tensors + grad_input = None + if ctx.input_requires_grad: + grad_input = _reduce_broadcast_dims( + grad_output * extra_input, + ctx.input_shape, + ) + grad_extra_input = None + if ctx.extra_input_requires_grad: + grad_extra_input = _reduce_broadcast_dims( + grad_output * input_, + ctx.extra_input_shape, + ) + return grad_input, [()], [(grad_extra_input,)] From b448a17d2be841eaa02ade6975dd0b4f401fa6fc Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 10 Jan 2026 03:21:49 +0000 Subject: [PATCH 06/45] Bug fixes Signed-off-by: Tim Moon --- transformer_engine/pytorch/ops/basic/grouped_linear.py | 9 ++------- .../pytorch/ops/basic/multiply_extra_input.py | 2 -- 2 files 
changed, 2 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 1c71e4de73..d2a4b379e5 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -490,7 +490,7 @@ def fuser_backward( for dy in torch.split(grad_output, split_sizes_int) ] else: - dys = torch.split(grad_output, split_sizes_int) + dys = torch.split(dy, split_sizes_int) if has_bias: grad_biases = [ dy.reshape(-1, dy.size(-1)).sum(dim=0) for dy in dys @@ -509,12 +509,7 @@ def fuser_backward( if hasattr(weight_param, "__fsdp_param__"): weight_param.main_grad = weight_param.get_main_grad() accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) - if not hasattr(weight_param, "main_grad"): - raise RuntimeError( - "GroupLinear op is configured with " - "accumulate_into_main_grad=True, " - "but weight parameter does not have main_grad attribute" - ) + grad_weights[group_idx] = weight_param.main_grad else: weight_shape = ws[0].size() device = grad_output.device diff --git a/transformer_engine/pytorch/ops/basic/multiply_extra_input.py b/transformer_engine/pytorch/ops/basic/multiply_extra_input.py index b4a763bde1..c1846f5e0d 100644 --- a/transformer_engine/pytorch/ops/basic/multiply_extra_input.py +++ b/transformer_engine/pytorch/ops/basic/multiply_extra_input.py @@ -114,8 +114,6 @@ def fuser_forward( ctx.input_shape = x1.size() ctx.extra_input_shape = extra_input.size() ctx.input_requires_grad = True - if isinstance(input_, torch.Tensor): - ctx.input_requires_grad = input_.requires_grad ctx.extra_input_requires_grad = extra_input.requires_grad ctx.save_for_backward( x1 if ctx.extra_input_requires_grad else None, From 3f388971a48ebfd5e361a1043d94f8ce68bdc518 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 14 Jan 2026 06:49:45 +0000 Subject: [PATCH 07/45] Expose option for custom op fusions Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order. 
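With these entry points, downstream code can register its own fusion passes.
A rough sketch (the fusion function below is hypothetical; the registration
functions and the list-in/list-out contract are the ones added here):

    import transformer_engine.pytorch.ops as te_ops

    def fuse_my_backward_ops(ops, **kwargs):
        # Receives the backward-pass ops as a list of FusibleOperation objects
        # (plus keyword options such as the quantization recipe) and returns an
        # updated list, e.g. with a matching window replaced by a fused op.
        return ops

    te_ops.register_backward_fusion(fuse_my_backward_ops)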
Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 8 +- transformer_engine/pytorch/ops/__init__.py | 9 +- .../pytorch/ops/fused/__init__.py | 59 +++---- .../ops/fused/backward_activation_bias.py | 120 ++++++------- .../pytorch/ops/fused/backward_add_rmsnorm.py | 103 +++++------ .../pytorch/ops/fused/backward_linear_add.py | 118 ++++++------- .../ops/fused/backward_linear_scale.py | 110 ++++++------ .../fused/forward_linear_bias_activation.py | 118 +++++++------ .../ops/fused/forward_linear_bias_add.py | 119 ++++++------- .../ops/fused/forward_linear_scale_add.py | 126 +++++++------ .../ops/fused/userbuffers_backward_linear.py | 163 ++++++++--------- .../ops/fused/userbuffers_forward_linear.py | 150 ++++++++-------- transformer_engine/pytorch/ops/fuser.py | 166 ++++++++++++------ 13 files changed, 689 insertions(+), 680 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index ce15dd1421..7eb5302fca 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2329,13 +2329,13 @@ def test_backward_activation_bias( backward_ops = model._module_groups[0]._backward_ops if with_quantization: assert len(backward_ops) == 2 - assert isinstance(backward_ops[0][0], BackwardActivationBias) - assert isinstance(backward_ops[1][0], te_ops.Quantize) + assert isinstance(backward_ops[0][0], te_ops.Quantize) + assert isinstance(backward_ops[1][0], BackwardActivationBias) else: assert len(backward_ops) == 3 - assert isinstance(backward_ops[0][0], act_type) + assert isinstance(backward_ops[0][0], te_ops.Quantize) assert isinstance(backward_ops[1][0], te_ops.Bias) - assert isinstance(backward_ops[2][0], te_ops.Quantize) + assert isinstance(backward_ops[2][0], act_type) # Expected numerical error tols = dtype_tols(dtype) diff --git a/transformer_engine/pytorch/ops/__init__.py b/transformer_engine/pytorch/ops/__init__.py index 2b270ea3de..4f1d64623a 100644 --- a/transformer_engine/pytorch/ops/__init__.py +++ b/transformer_engine/pytorch/ops/__init__.py @@ -8,7 +8,8 @@ """ -from transformer_engine.pytorch.ops.basic import * -from transformer_engine.pytorch.ops.linear import Linear -from transformer_engine.pytorch.ops.op import FusibleOperation -from transformer_engine.pytorch.ops.sequential import Sequential +from .basic import * +from .fuser import register_backward_fusion, register_forward_fusion +from .linear import Linear +from .op import FusibleOperation +from .sequential import Sequential diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index f4568ff25d..1ebfe23060 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -4,39 +4,26 @@ """Compound tensor operation supported by the operation fuser.""" -from .backward_activation_bias import ( - BackwardActivationBias, - fuse_backward_activation_bias, -) -from .backward_add_rmsnorm import ( - BackwardAddRMSNorm, - fuse_backward_add_rmsnorm, -) -from .backward_linear_add import ( - BackwardLinearAdd, - fuse_backward_linear_add, -) -from .backward_linear_scale import ( - BackwardLinearScale, - fuse_backward_linear_scale, -) -from .forward_linear_bias_activation import ( - ForwardLinearBiasActivation, - fuse_forward_linear_bias_activation, -) -from .forward_linear_bias_add import ( - ForwardLinearBiasAdd, - fuse_forward_linear_bias_add, -) -from .forward_linear_scale_add import ( - ForwardLinearScaleAdd, - fuse_forward_linear_scale_add, -) -from 
.userbuffers_backward_linear import ( - UserbuffersBackwardLinear, - fuse_userbuffers_backward_linear, -) -from .userbuffers_forward_linear import ( - UserbuffersForwardLinear, - fuse_userbuffers_forward_linear, -) +from ..fuser import register_backward_fusion, register_forward_fusion +from .backward_activation_bias import BackwardActivationBias +from .backward_add_rmsnorm import BackwardAddRMSNorm +from .backward_linear_add import BackwardLinearAdd +from .backward_linear_scale import BackwardLinearScale +from .forward_linear_bias_activation import ForwardLinearBiasActivation +from .forward_linear_bias_add import ForwardLinearBiasAdd +from .forward_linear_scale_add import ForwardLinearScaleAdd +from .userbuffers_backward_linear import UserbuffersBackwardLinear +from .userbuffers_forward_linear import UserbuffersForwardLinear + +# Register forward fusions +register_forward_fusion(UserbuffersForwardLinear.fuse_forward_ops) +register_forward_fusion(ForwardLinearBiasAdd.fuse_forward_ops) +register_forward_fusion(ForwardLinearBiasActivation.fuse_forward_ops) +register_forward_fusion(ForwardLinearScaleAdd.fuse_forward_ops) + +# Register backward fusions +register_backward_fusion(UserbuffersBackwardLinear.fuse_backward_ops) +register_backward_fusion(BackwardLinearAdd.fuse_backward_ops) +register_backward_fusion(BackwardLinearScale.fuse_backward_ops) +register_backward_fusion(BackwardActivationBias.fuse_backward_ops) +register_backward_fusion(BackwardAddRMSNorm.fuse_backward_ops) diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index d5b9ce0e96..8fd0ac6cde 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -53,8 +53,8 @@ def fuser_backward( ]: # Get basic operation contexts - activation_op_ctx = basic_op_ctxs[0] - bias_op_ctx = basic_op_ctxs[1] + bias_op_ctx = basic_op_ctxs[0] + activation_op_ctx = basic_op_ctxs[1] # Saved tensors from forward pass (act_input,) = activation_op_ctx.saved_tensors @@ -79,68 +79,58 @@ def fuser_backward( # Clear activation input tensor clear_tensor_data(act_input) - return dx, [(), (db,)], [(), ()] + return dx, [(db,), ()], [(), ()] - -def fuse_backward_activation_bias( - ops: list[tuple[FusibleOperation, list[int]]], - recipe: Optional[Recipe], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fused backward dact + dbias + quantize - - Parameters - ---------- - ops : list of tuples - Backward pass operations and the indices of the corresponding - basic operations. - recipe : Recipe, optional - Used quantization recipe - - Returns - ------- - ops : list of tuples - Updated backward pass operations - - """ - - # Check if recipe supports bias activation fusion - if recipe is None: - return ops - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 3: + @staticmethod + def fuse_backward_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, + ) -> list[FusibleOperation]: + """Apply operation fusion for backward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Backward pass operations. + recipe : Recipe, optional + Quantization recipe. 
+ + Returns + ------- + ops : list of FusibleOperation + Updated backward pass operations + + """ + + # Check if recipe supports bias activation fusion + if recipe is None: + return ops + + # Scan through ops, fusing if possible + out = [] + window = [] + while ops: + + # Shift window + while len(window) >= 3: + out.append(window[0]) + window = window[1:] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] + + # Construct fused op if window matches pattern + if ( + len(window) == 3 + and isinstance(window[2], _fusible_activations) + and isinstance(window[1], Bias) + and window[0].get_grad_output_quantizer() is not None + ): + op = BackwardActivationBias(bias=window[1], activation=window[2]) + window = [window[0], op] + + # Return list of ops out.extend(window) - - # Check if first op is a supported activation - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, _fusible_activations): - continue - - # Check if second op is bias - op, _ = ops[0] - if not isinstance(op, Bias): - continue - - # Check if third op has a grad input quantizer - op, _ = ops[1] - if not op.num_quantizers("backward") > 0: - continue - - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = BackwardActivationBias( - activation=window[0][0], - bias=window[1][0], - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py b/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py index 186619caae..2eaded064d 100644 --- a/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py +++ b/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py @@ -42,7 +42,7 @@ def fuser_backward( # Get basic operations rmsnorm_op = self.basic_ops[1] - rmsnorm_op_ctx = basic_op_ctxs[0] + rmsnorm_op_ctx = basic_op_ctxs[1] # Saved tensors from forward pass x, rstdevs = rmsnorm_op_ctx.saved_tensors @@ -53,7 +53,7 @@ def fuser_backward( # Check input tensors dtype = rmsnorm_op_ctx.dtype - extra_grad = basic_op_grad_extra_outputs[1][0] + extra_grad = basic_op_grad_extra_outputs[0][0] dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size()) w = maybe_dequantize(rmsnorm_op.weight, dtype).view((inner_dim,)) add = maybe_dequantize(extra_grad.contiguous(), dtype).view(x.size()) @@ -77,57 +77,50 @@ def fuser_backward( grad_input = dx.view(grad_output.size()) grad_weight = dw.view(weight_dims) - return grad_input, [(grad_weight,), ()], [(), ()] - - -def fuse_backward_add_rmsnorm( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fused backward RMNorm + add - - Parameters - ---------- - ops : list of tuples - Backward pass operations and the indices of the corresponding - basic operations. - - Returns - ------- - ops : list of tuples - Updated backward pass operations - - """ - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 2: + return grad_input, [(), (grad_weight,)], [(), ()] + + @staticmethod + def fuse_backward_ops( + ops: list[FusibleOperation], + **unused, + ) -> list[FusibleOperation]: + """Apply operation fusion for backward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Backward pass operations. 
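To make the scanning idiom used by these fuse_*_ops methods easier to follow, here is a minimal, self-contained sketch of the same sliding-window scan. The Bias, GELU, and FusedBiasGELU classes below are toy stand-ins invented for illustration (they are not Transformer Engine ops); only the shift-window / match / replace structure mirrors the code in this patch.

class Bias:
    """Toy stand-in for a bias op."""

class GELU:
    """Toy stand-in for an activation op."""

class FusedBiasGELU:
    """Toy fused op that replaces an adjacent (Bias, GELU) pair."""

    def __init__(self, bias, activation):
        self.bias, self.activation = bias, activation

def fuse_bias_gelu(ops, **unused):
    """Substitute adjacent (Bias, GELU) pairs with FusedBiasGELU."""
    out, window = [], []
    while ops:
        # Shift window: ops that can no longer start a match move to the output
        while len(window) >= 2:
            out.append(window[0])
            window = window[1:]
        while ops and len(window) < 2:
            window.append(ops[0])
            ops = ops[1:]
        # Replace window with a fused op if it matches the pattern
        if (
            len(window) == 2
            and isinstance(window[0], Bias)
            and isinstance(window[1], GELU)
        ):
            window = [FusedBiasGELU(bias=window[0], activation=window[1])]
    out.extend(window)
    return out

ops = [GELU(), Bias(), GELU(), Bias()]
print([type(op).__name__ for op in fuse_bias_gelu(ops)])
# ['GELU', 'FusedBiasGELU', 'Bias']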
+ + Returns + ------- + ops : list of FusibleOperation + Updated backward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while ops: + + # Shift window + while len(window) >= 2: + out.append(window[0]) + window = window[1:] + while ops and len(window) < 2: + window.append(ops[0]) + ops = ops[1:] + + # Construct fused op if window matches pattern + if ( + len(window) == 2 + and isinstance(window[0], MakeExtraOutput) + and isinstance(window[1], RMSNorm) + and not window[0]._in_place + ): + op = BackwardAddRMSNorm(add=window[0], rmsnorm=window[1]) + window = [op] + + # Return list of ops out.extend(window) - - # Check if first op is linear - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, RMSNorm): - continue - - # Check if second op is "make extra output" - op, _ = ops[0] - if not isinstance(op, MakeExtraOutput): - continue - if op._in_place: - continue - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = BackwardAddRMSNorm( - rmsnorm=window[0][0], - add=window[1][0], - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 5e7339db85..e6307c254c 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -45,7 +45,7 @@ def fuser_backward( # Get basic operations linear_op = self.basic_ops[1] - linear_op_ctx = basic_op_ctxs[0] + linear_op_ctx = basic_op_ctxs[1] # Saved tensors from forward pass (x_local, w) = linear_op_ctx.saved_tensors @@ -71,7 +71,7 @@ def fuser_backward( accumulate_into_main_grad = False # Linear backward pass - grad_input = basic_op_grad_extra_outputs[1][0] + grad_input = basic_op_grad_extra_outputs[0][0] grad_input, grad_weight = BasicLinear._functional_backward( grad_output=grad_output, input=x_local, @@ -109,61 +109,61 @@ def fuser_backward( zero=getattr(weight_param, "zero_out_wgrad", False), ) - return grad_input, [(grad_weight,), ()], [(), ()] - - -def fuse_backward_linear_add( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fused backward dgrad GEMM + add - - Parameters - ---------- - ops : list of tuples - Backward pass operations and the indices of the corresponding - basic operations. - - Returns - ------- - ops : list of tuples - Updated backward pass operations - - """ - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 2: + return grad_input, [(), (grad_weight,)], [(), ()] + + @staticmethod + def fuse_backward_ops( + ops: list[FusibleOperation], + **unused, + ) -> list[FusibleOperation]: + """Apply operation fusion for backward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Backward pass operations. 
+ + Returns + ------- + ops : list of FusibleOperation + Updated backward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while ops: + + # Shift window + while len(window) >= 2: + out.append(window[0]) + window = window[1:] + while ops and len(window) < 2: + window.append(ops[0]) + ops = ops[1:] + + # Check if window matches pattern + matches_pattern = True + if not ( + len(window) == 2 + and isinstance(window[0], MakeExtraOutput) + and isinstance(window[1], BasicLinear) + ): + matches_pattern = False + elif not window[0]._in_place: + # Fused op accumulates grad input in-place + matches_pattern = False + elif window[1].tensor_parallel_mode == "column": + # Column tensor-parallelism requires communication + # after the dgrad GEMM + matches_pattern = False + + # Construct fused op if window matches pattern + if matches_pattern: + op = BackwardLinearAdd(backward_add=window[0], linear=window[1]) + window = [op] + + # Return list of ops out.extend(window) - - # Check if first op is linear - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, BasicLinear): - continue - if op.tensor_parallel_mode == "column": - # Row tensor-parallelism requires communication after the - # GEMM - continue - - # Check if second op is "make extra output" - op, _ = ops[0] - if not isinstance(op, MakeExtraOutput): - continue - if not op._in_place: - continue - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = BackwardLinearAdd( - linear=window[0][0], - backward_add=window[1][0], - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py index f7f59e65c9..2fd72fb963 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py @@ -45,7 +45,7 @@ def fuser_backward( # Get basic operations linear_op = self.basic_ops[0] - linear_op_ctx = basic_op_ctxs[1] + linear_op_ctx = basic_op_ctxs[0] scale_op = self.basic_ops[1] # Saved tensors from forward pass @@ -109,58 +109,58 @@ def fuser_backward( zero=getattr(weight_param, "zero_out_wgrad", False), ) - return grad_input, [(), (grad_weight,)], [(), ()] - - -def fuse_backward_linear_scale( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fused backward dgrad GEMM + constant scale - - Parameters - ---------- - ops : list of tuples - Backward pass operations and the indices of the corresponding - basic operations. - - Returns - ------- - ops : list of tuples - Updated backward pass operations - - """ - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 2: + return grad_input, [(grad_weight,), ()], [(), ()] + + @staticmethod + def fuse_backward_ops( + ops: list[FusibleOperation], + **unused, + ) -> list[FusibleOperation]: + """Apply operation fusion for backward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Backward pass operations. 
+ + Returns + ------- + ops : list of FusibleOperation + Updated backward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while ops: + + # Shift window + while len(window) >= 2: + out.append(window[0]) + window = window[1:] + while ops and len(window) < 2: + window.append(ops[0]) + ops = ops[1:] + + # Check if window matches pattern + matches_pattern = True + if not ( + len(window) == 2 + and isinstance(window[0], BasicLinear) + and isinstance(window[1], ConstantScale) + ): + matches_pattern = False + elif window[0].tensor_parallel_mode == "column": + # Column tensor-parallelism requires communication + # after the dgrad GEMM + matches_pattern = False + + # Construct fused op if window matches pattern + if matches_pattern: + op = BackwardLinearScale(linear=window[0], scale=window[1]) + window = [op] + + # Return list of ops out.extend(window) - - # Check if first op is constant scale - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, ConstantScale): - continue - - # Check if second op is linear - op, _ = ops[0] - if not isinstance(op, BasicLinear): - continue - if op.tensor_parallel_mode == "column": - # Column tensor-parallelism requires communication after the dgrad GEMM - continue - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = BackwardLinearScale( - scale=window[0][0], - linear=window[1][0], - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 1c5edfcfcb..8e602e4cc2 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -134,62 +134,64 @@ def fuser_forward( return output, [() for _ in range(len(self.basic_ops))] - -def fuse_forward_linear_bias_activation( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fuse forward GEMM + bias + activation - - Parameters - ---------- - ops : list of tuples - Forward pass operations and the indices of the corresponding - basic operations. - - Returns - ------- - ops : list of tuples - Updated forward pass operations - - """ - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 2: + @staticmethod + def fuse_forward_ops( + ops: list[FusibleOperation], + **unused, + ) -> list[FusibleOperation]: + """Apply operation fusion for forward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. 
+ + Returns + ------- + ops : list of FusibleOperation + Updated forward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while ops: + + # Shift window + while len(window) >= 2: + out.append(window[0]) + window = window[1:] + while ops and len(window) < 2: + window.append(ops[0]) + ops = ops[1:] + + # Check if window matches pattern + matches_pattern = True + if not ( + len(window) == 2 + and isinstance(window[0], BasicLinear) + and isinstance(window[1], Bias) + ): + matches_pattern = False + elif window[0].tensor_parallel_mode == "row": + # Row tensor-parallelism requires communication after + # the GEMM + matches_pattern = False + elif window[0].weight.dtype not in (torch.float16, torch.bfloat16): + # cuBLAS only supports fused GEMM+bias+activation with + # FP16 and BF16 output + matches_pattern = False + + # Construct fused op if window matches pattern + if matches_pattern: + op = ForwardLinearBiasActivation( + linear=window[0], + bias=window[1], + activation=None, + ) + window = [op] + + # Return list of ops out.extend(window) - - # Check if first op is linear - window, ops = ops[:1], ops[1:] - op1, _ = window[0] - if not isinstance(op1, BasicLinear): - continue - if op1.tensor_parallel_mode == "row": - # Row tensor-parallelism requires communication after the - # GEMM - continue - if op1.weight.dtype not in (torch.float16, torch.bfloat16): - # cuBLAS only supports fused GEMM+bias+activation with - # FP16 and BF16 output - continue - - # Check if second op is bias - op2, _ = ops[0] - if not isinstance(op2, Bias): - continue - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = ForwardLinearBiasActivation( - linear=window[0][0], - bias=window[1][0], - activation=None, - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 4efb33e037..cc9554c876 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -131,72 +131,63 @@ def fuser_forward( return output, [() for _ in range(len(self.basic_ops))] + @staticmethod + def fuse_forward_ops( + ops: list[FusibleOperation], + **unused, + ) -> list[FusibleOperation]: + """Apply operation fusion for forward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. + + Returns + ------- + ops : list of FusibleOperation + Updated forward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while ops: + + # Shift window + out.extend(window) + window = [ops[0]] + ops = ops[1:] -def fuse_forward_linear_bias_add( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fuse forward GEMM + bias + add - - Parameters - ---------- - ops : list of tuples - Forward pass operations and the indices of the corresponding - basic operations. 
+ # Check if first op is linear + if not isinstance(window[0], BasicLinear): + continue + if window[0].tensor_parallel_mode == "row": + # Row tensor-parallelism requires communication after + # the GEMM + continue + linear = window[0] - Returns - ------- - ops : list of tuples - Updated forward pass operations + # Check if next op is bias + bias = None + if ops and isinstance(ops[0], Bias): + window.append(ops[0]) + ops = ops[1:] + bias = window[-1] + + # Check if next op is in-place add extra input + if ops and isinstance(ops[0], AddExtraInput) and ops[0]._in_place: + window.append(ops[0]) + ops = ops[1:] + add = window[-1] + else: + continue - """ + # Replace window with fused op + op = ForwardLinearBiasAdd(linear=linear, bias=bias, add=add) + window = [op] - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 2: + # Return list of ops out.extend(window) - - # Check if first op is linear - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, BasicLinear): - continue - if op.tensor_parallel_mode == "row": - # Row tensor-parallelism requires communication after the - # GEMM - continue - linear = op - op, _ = ops[0] - - # Check if next op is bias - bias = None - if isinstance(op, Bias): - bias = op - window.extend(ops[:1]) - ops = ops[1:] - if len(ops) == 0: - continue - op, _ = ops[0] - - # Check if next op is in-place add extra input - if not isinstance(op, AddExtraInput): - continue - if not op._in_place: - continue - add = op - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = ForwardLinearBiasAdd( - linear=linear, - bias=bias, - add=add, - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 25b40f76e3..3cf9f294c3 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -110,70 +110,64 @@ def fuser_forward( return output, [() for _ in range(len(self.basic_ops))] - -def fuse_forward_linear_scale_add( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fuse forward GEMM + scale + add - - Parameters - ---------- - ops : list of tuples - Forward pass operations and the indices of the corresponding - basic operations. - - Returns - ------- - ops : list of tuples - Updated forward pass operations - - """ - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 3: + @staticmethod + def fuse_forward_ops( + ops: list[FusibleOperation], + **unused, + ) -> list[FusibleOperation]: + """Apply operation fusion for forward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. 
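The ForwardLinearBiasAdd scan above is a greedy variant of the same idiom in which the middle op (the bias) is optional. A toy, self-contained version of that shape, with made-up Linear, Bias, Add, and FusedLinearBiasAdd classes standing in for the real ops, looks roughly like this:

class Linear:
    """Toy stand-in for a linear op."""

class Bias:
    """Toy stand-in for a bias op."""

class Add:
    """Toy stand-in for an in-place add op."""

class FusedLinearBiasAdd:
    """Toy fused op for linear + optional bias + add."""

    def __init__(self, linear, bias, add):
        self.linear, self.bias, self.add = linear, bias, add

def fuse_linear_bias_add(ops, **unused):
    out, window = [], []
    while ops:
        # Shift window: start a new match attempt at the next op
        out.extend(window)
        window, ops = [ops[0]], ops[1:]
        if not isinstance(window[0], Linear):
            continue
        linear = window[0]
        # Optionally consume a bias
        bias = None
        if ops and isinstance(ops[0], Bias):
            bias, ops = ops[0], ops[1:]
            window.append(bias)
        # The add is required; without it the window is left unfused
        if not (ops and isinstance(ops[0], Add)):
            continue
        add, ops = ops[0], ops[1:]
        window = [FusedLinearBiasAdd(linear=linear, bias=bias, add=add)]
    out.extend(window)
    return out

ops = [Linear(), Add(), Linear(), Bias(), Add(), Bias()]
print([type(op).__name__ for op in fuse_linear_bias_add(ops)])
# ['FusedLinearBiasAdd', 'FusedLinearBiasAdd', 'Bias']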
+ + Returns + ------- + ops : list of FusibleOperation + Updated forward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while ops: + + # Shift window + while len(window) >= 3: + out.append(window[0]) + window = window[1:] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] + + # Check if window matches pattern + matches_pattern = True + if not ( + len(window) == 3 + and isinstance(window[0], BasicLinear) + and isinstance(window[1], ConstantScale) + and isinstance(window[2], AddExtraInput) + ): + matches_pattern = False + elif window[0].tensor_parallel_mode == "row": + # Row tensor-parallelism requires communication after + # the GEMM + matches_pattern = False + elif not window[2]._in_place: + # Fused op accumulates output in-place + matches_pattern = False + + # Construct fused op if window matches pattern + if matches_pattern: + op = ForwardLinearScaleAdd( + linear=window[0], + scale=window[1], + add=window[2], + ) + window = [op] + + # Return list of ops out.extend(window) - - # Check if first op is linear - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, BasicLinear): - continue - if op.tensor_parallel_mode == "row": - # Row tensor-parallelism requires communication after the - # GEMM - continue - linear = op - op, _ = ops[0] - - # Check if next op is constant scale - if not isinstance(op, ConstantScale): - continue - scale = op - window.extend(ops[:1]) - ops = ops[1:] - op, _ = ops[0] - - # Check if next op is in-place add extra input - if not isinstance(op, AddExtraInput): - continue - if not op._in_place: - continue - add = op - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = ForwardLinearScaleAdd( - linear=linear, - scale=scale, - add=add, - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 4943ffb1bd..c4efd370c3 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -502,7 +502,7 @@ def fuser_backward( # Get basic operations idx = self._op_idxs["linear"] linear_op = self.basic_ops[idx] - linear_op_ctx = basic_op_ctxs[-1] + linear_op_ctx = basic_op_ctxs[0] bias_op = None if self._op_idxs["bias"] is not None: idx = self._op_idxs["bias"] @@ -577,99 +577,90 @@ def fuser_backward( grad_params[self._op_idxs["linear"]] = (grad_weight,) if bias_op is not None: grad_params[self._op_idxs["bias"]] = (grad_bias,) - grad_params.reverse() grad_extra_inputs = [() for _ in range(len(self.basic_ops))] return grad_input, grad_params, grad_extra_inputs + @staticmethod + def fuse_backward_ops( + ops: list[FusibleOperation], + **unused, + ) -> list[FusibleOperation]: + """Apply operation fusion for backward pass. -def fuse_userbuffers_backward_linear( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Substitute linear operations with Userbuffers implementation + Parameters + ---------- + ops : list of FusibleOperation + Backward pass operations. + recipe : Recipe, optional + Quantization recipe. - Parameters - ---------- - ops : list of tuples - Backward pass operations and the indices of the corresponding - basic operations. 
+ Returns + ------- + ops : list of FusibleOperation + Updated backward pass operations - Returns - ------- - ops : list of tuples - Updated backward pass operations + """ - """ + # Return immediately if environment is not distributed + if ( + not torch.distributed.is_initialized() + or torch.distributed.get_world_size() == 1 + ): + return ops - # Return immediately if environment is not distributed - if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: - return ops - - # Sliding window in list of ops - window = [] - - def peek_next_op() -> Optional[FusibleOperation]: - """Get next op in list of ops""" - nonlocal ops - if not ops: - return None - return ops[-1][0] - - def pop_next_op() -> FusibleOperation: - """Remove next op from list of ops and add to sliding window""" - nonlocal ops, window - window.insert(0, ops[-1]) - ops = ops[:-1] - return window[0][0] - - # Scan through ops in reverse order, fusing if possible - out_reversed = [] - while ops: - out_reversed.extend(reversed(window)) - window.clear() - - # Check if next op is linear - next_op = pop_next_op() - if not isinstance(next_op, BasicLinear): - continue - linear = next_op - if linear._userbuffers_options is None: - continue - - # Check if next op is bias - bias = None - if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias): - bias = pop_next_op() - - # Check if next op is reduce-scatter - reduce_scatter = None - if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter): - reduce_scatter = pop_next_op() - - # Check for invalid combinations - if reduce_scatter is None: - if linear.tensor_parallel_mode is None: - continue - if linear.tensor_parallel_size == 1: - continue - if linear.tensor_parallel_mode == "row" and bias is not None: - continue - else: - if linear.tensor_parallel_mode is not None: + # Scan through ops, fusing if possible + out = [] + window = [] + while ops: + + # Shift window + out.extend(window) + window, ops = ops[:1], ops[1:] + + # Check if first op is linear + if not isinstance(window[0], BasicLinear): continue - if reduce_scatter.process_group_size == 1: + linear = window[0] + if linear._userbuffers_options is None: continue - # Replace window with fused op - op = UserbuffersBackwardLinear( - linear=linear, - bias=bias, - reduce_scatter=reduce_scatter, - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out_reversed.extend(reversed(window)) - out = out_reversed - out.reverse() - return out + # Check if next op is bias + bias = None + if linear.tensor_parallel_mode != "row" and ops and isinstance(ops[0], Bias): + bias, ops = ops[0], ops[1:] + window.append(bias) + + # Check if next op is reduce-scatter + reduce_scatter = None + if ( + linear.tensor_parallel_mode is None + and ops and isinstance(ops[0], ReduceScatter) + ): + reduce_scatter, ops = ops[0], ops[1:] + window.append(reduce_scatter) + + # Check for invalid combinations + if reduce_scatter is None: + if linear.tensor_parallel_mode is None: + continue + if linear.tensor_parallel_size == 1: + continue + if linear.tensor_parallel_mode == "row" and bias is not None: + continue + else: + if linear.tensor_parallel_mode is not None: + continue + if reduce_scatter.process_group_size == 1: + continue + + # Replace window with fused op + op = UserbuffersBackwardLinear( + linear=linear, + bias=bias, + reduce_scatter=reduce_scatter, + ) + window = [op] + + # Return list of ops + out.extend(window) + 
return out diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index fe04aa1e0b..48957d3da6 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -369,93 +369,85 @@ def fuser_forward( return output, [() for _ in range(len(self.basic_ops))] + @staticmethod + def fuse_forward_ops( + ops: list[FusibleOperation], + **unused, + ) -> list[FusibleOperation]: + """Apply operation fusion for forward pass. -def fuse_userbuffers_forward_linear( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Substitute linear operations with Userbuffers implementation - - Parameters - ---------- - ops : list of tuples - Forward pass operations and the indices of the corresponding - basic operations. - - Returns - ------- - ops : list of tuples - Updated forward pass operations + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. - """ + Returns + ------- + ops : list of FusibleOperation + Updated forward pass operations - # Return immediately if environment is not distributed - if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: - return ops - - # Sliding window in list of ops - window = [] - - def peek_next_op() -> Optional[FusibleOperation]: - """Get next op in list of ops""" - nonlocal ops - if not ops: - return None - return ops[0][0] - - def pop_next_op() -> FusibleOperation: - """Remove next op from list of ops and add to sliding window""" - nonlocal ops, window - window.append(ops[0]) - ops = ops[1:] - return window[-1][0] - - # Scan through ops, fusing if possible - out = [] - while ops: - out.extend(window) - window.clear() + """ - # Check if next op is linear - next_op = pop_next_op() - if not isinstance(next_op, BasicLinear): - continue - linear = next_op - if linear._userbuffers_options is None: - continue + # Return immediately if environment is not distributed + if ( + not torch.distributed.is_initialized() + or torch.distributed.get_world_size() == 1 + ): + return ops - # Check if next op is bias - bias = None - if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias): - bias = pop_next_op() + # Scan through ops, fusing if possible + out = [] + window = [] + while ops: - # Check if next op is reduce-scatter - reduce_scatter = None - if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter): - reduce_scatter = pop_next_op() + # Shift window + out.extend(window) + window, ops = ops[:1], ops[1:] - # Check for invalid combinations - if reduce_scatter is None: - if linear.tensor_parallel_mode is None: - continue - if linear.tensor_parallel_size == 1: - continue - if linear.tensor_parallel_mode == "row" and bias is not None: - continue - else: - if linear.tensor_parallel_mode is not None: + # Check if first op is linear + if not isinstance(window[0], BasicLinear): continue - if reduce_scatter.process_group_size == 1: + linear = window[0] + if linear._userbuffers_options is None: continue - # Replace window with fused op - op = UserbuffersForwardLinear( - linear=linear, - bias=bias, - reduce_scatter=reduce_scatter, - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] + # Check if next op is bias + bias = None + if linear.tensor_parallel_mode != "row" and ops and isinstance(ops[0], Bias): + bias, 
ops = ops[0], ops[1:] + window.append(bias) + + # Check if next op is reduce-scatter + reduce_scatter = None + if ( + linear.tensor_parallel_mode is None + and ops and isinstance(ops[0], ReduceScatter) + ): + reduce_scatter, ops = ops[0], ops[1:] + window.append(reduce_scatter) + + # Check for invalid combinations + if reduce_scatter is None: + if linear.tensor_parallel_mode is None: + continue + if linear.tensor_parallel_size == 1: + continue + if linear.tensor_parallel_mode == "row" and bias is not None: + continue + else: + if linear.tensor_parallel_mode is not None: + continue + if reduce_scatter.process_group_size == 1: + continue + + # Replace window with fused op + op = UserbuffersForwardLinear( + linear=linear, + bias=bias, + reduce_scatter=reduce_scatter, + ) + window = [op] - # Return list of ops - out.extend(window) - return out + # Return list of ops + out.extend(window) + return out diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index bf7af48d03..e3eabfa575 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -5,33 +5,20 @@ """Manager class for a pipeline of fusible operations.""" from __future__ import annotations -from collections.abc import Callable, Iterable -from typing import Any, Optional +from collections.abc import Callable, Iterable, Sequence import itertools +from typing import Any, Optional import torch -from transformer_engine.pytorch.quantization import FP8GlobalStateManager, Recipe, DelayedScaling -from transformer_engine.pytorch.ops.op import ( +from ..quantization import FP8GlobalStateManager, Recipe, DelayedScaling +from ..quantized_tensor import prepare_for_saving, restore_from_saved +from .op import ( BasicOperation, FusibleOperation, + FusedOperation, OperationContext, ) -from transformer_engine.pytorch.ops.fused import ( - fuse_backward_activation_bias, - fuse_backward_add_rmsnorm, - fuse_backward_linear_add, - fuse_backward_linear_scale, - fuse_forward_linear_bias_activation, - fuse_forward_linear_bias_add, - fuse_forward_linear_scale_add, - fuse_userbuffers_backward_linear, - fuse_userbuffers_forward_linear, -) -from transformer_engine.pytorch.quantized_tensor import ( - prepare_for_saving, - restore_from_saved, -) def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]: @@ -57,6 +44,10 @@ def _is_graph_capturing() -> bool: return _is_graph_capturing_function() +# Type alias for a function that may perform operation fusion +type OperationFusionFunction = Callable[tuple[list[FusibleOperation], ...], list[FusibleOperation]] + + class _OperationFuserAutogradFunction(torch.autograd.Function): """Autograd function for a pipeline of operations @@ -241,7 +232,7 @@ def backward( dx = grad_output grad_params = [None for _ in range(len(basic_ops))] grad_extra_inputs = [None for _ in range(len(basic_ops))] - for op, basic_op_idxs in backward_ops: + for op, basic_op_idxs in reversed(backward_ops): # Stop if no more gradients are required if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs): @@ -315,6 +306,10 @@ class OperationFuser: """ + # Functions to perform operation fusion + forward_fusion_functions: list[OperationFusionFunction] = [] + backward_fusion_functions: list[OperationFusionFunction] = [] + def __init__( self, ops: list[FusibleOperation], @@ -334,7 +329,7 @@ def __init__( self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops) self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs) - # Ops 
for forward and backward pass, will be populated in fuse_ops + # Ops for forward and backward pass, will be populated in maybe_fuse_ops self._forward_ops: list[tuple[FusibleOperation, list[int]]] self._backward_ops: list[tuple[FusibleOperation, list[int]]] @@ -349,31 +344,48 @@ def __init__( self._flat_basic_op_params = sum(self._basic_op_params, []) @classmethod - def _fuse_forward_ops( - cls, - ops: list[tuple[FusibleOperation, list[int]]], - recipe: Optional[Recipe], # pylint: disable=unused-argument - ) -> list[tuple[FusibleOperation, list[int]]]: - """Attempt to fuse operations in forward pass""" - ops = fuse_userbuffers_forward_linear(ops) - ops = fuse_forward_linear_bias_add(ops) - ops = fuse_forward_linear_bias_activation(ops) - ops = fuse_forward_linear_scale_add(ops) - return ops - - @classmethod - def _fuse_backward_ops( + def _fuse_ops( cls, - ops: list[tuple[FusibleOperation, list[int]]], + basic_ops: Sequence[BasicOperation], + fusion_funcs: Iterable[OperationFusionFunction], recipe: Optional[Recipe], ) -> list[tuple[FusibleOperation, list[int]]]: - """Attempt to fuse operations in backward pass""" - ops = fuse_userbuffers_backward_linear(ops) - ops = fuse_backward_linear_add(ops) - ops = fuse_backward_linear_scale(ops) - ops = fuse_backward_activation_bias(ops, recipe) - ops = fuse_backward_add_rmsnorm(ops) - return ops + """Apply operation fusions""" + + # Apply op fusions + fused_ops = list(basic_ops) + for func in fusion_funcs: + fused_ops = func(fused_ops, recipe=recipe) + + def raise_mismatch_error() -> None: + """Throw error indicating invalid op fusion""" + raise RuntimeError( + "Found mismatch after fusing operations " + f"(basic_ops={[o.__class__.__name__ for o in basic_ops]}, " + f"fused_ops={[o.__class__.__name__ for o in fused_ops]})" + ) + + # Determine basic op indices corresponding to each op + out = [] + idx = 0 + for op in fused_ops: + if isinstance(op, FusedOperation): + idxs = [] + for basic_op in op.basic_ops: + if basic_op is not basic_ops[idx]: + raise_mismatch_error() + idxs.append(idx) + idx += 1 + out.append((op, idxs)) + else: + if op is not basic_ops[idx]: + raise_mismatch_error() + out.append((op, [idx])) + idx += 1 + if idx != len(basic_ops): + raise_mismatch_error() + + return out def maybe_fuse_ops( self, @@ -424,12 +436,16 @@ def maybe_fuse_ops( op.pre_first_fuser_forward() # Prepare basic op lists for fusions - forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)] - backward_ops = list(reversed(forward_ops[first_op_requiring_backward:])) - - # Fuse ops - self._forward_ops = self._fuse_forward_ops(forward_ops, recipe) - self._backward_ops = self._fuse_backward_ops(backward_ops, recipe) + self._forward_ops = OperationFuser._fuse_ops( + self._basic_ops, + OperationFuser.forward_fusion_functions, + recipe=recipe, + ) + self._backward_ops = OperationFuser._fuse_ops( + self._basic_ops, + OperationFuser.backward_fusion_functions, + recipe=recipe, + ) # Save current fusion params self.recipe_type, self.first_op_requiring_backward = fusion_params @@ -491,3 +507,55 @@ def __call__( *extra_inputs, ) return forward_func(*args) + + +def register_forward_fusion( + op_fusion_func: OperationFusionFunction, + prepend: bool = False, +) -> None: + """Register function to perform operation fusion for forward pass. 
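For downstream code, the practical entry points introduced here are register_forward_fusion and register_backward_fusion, which this patch also re-exports from transformer_engine.pytorch.ops. The following usage sketch assumes Transformer Engine with this change installed; fuse_nothing is a made-up placeholder, not part of the library:

import transformer_engine.pytorch as te
from transformer_engine.pytorch.ops import FusibleOperation

def fuse_nothing(ops, **unused):
    """Placeholder fusion function: inspects the op list, fuses nothing.

    A real implementation would substitute adjacent basic ops with a
    FusedOperation while preserving their order, since the fuser maps
    the returned list back onto (op, basic-op-index) pairs.
    """
    assert all(isinstance(op, FusibleOperation) for op in ops)
    return list(ops)

# Appended after previously registered fusions by default;
# prepend=True asks the fuser to apply it first.
te.ops.register_forward_fusion(fuse_nothing)
te.ops.register_backward_fusion(fuse_nothing, prepend=True)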
+ + The fusion function should have the following signature: + + func(ops, *, recipe) -> updated ops + + Parameters + ---------- + op_fusion_func: function + Function that takes a list of operations and may substitute + them with fused operations. + prepend: bool, default = ``False`` + Whether the operation fuser should apply this fusion function + first. The default is to apply it last. + + """ + if prepend: + OperationFuser.forward_fusion_functions.insert(0, op_fusion_func) + else: + OperationFuser.forward_fusion_functions.append(op_fusion_func) + + +def register_backward_fusion( + op_fusion_func: OperationFusionFunction, + prepend: bool = False, +) -> None: + """Register function to perform operation fusion for backward pass. + + The fusion function should have the following signature: + + func(ops, *, recipe) -> updated ops + + Parameters + ---------- + op_fusion_func: function + Function that takes a list of operations and may substitute + them with fused operations. + prepend: bool, default = ``False`` + Whether the operation fuser should apply this fusion function + first. The default is to apply it last. + + """ + if prepend: + OperationFuser.backward_fusion_functions.insert(0, op_fusion_func) + else: + OperationFuser.backward_fusion_functions.append(op_fusion_func) From a359b6766fa48d090114bf007247b1cf44e7b55e Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 14 Jan 2026 08:20:44 +0000 Subject: [PATCH 08/45] Add tests for custom ops Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 286 +++++++++++++++++++++ transformer_engine/pytorch/ops/__init__.py | 2 +- 2 files changed, 287 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 7eb5302fca..7e02731c94 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2849,3 +2849,289 @@ def test_layernorm_mlp( with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) + + +class TestCustomOps: + """Test with ops that are defined externally""" + + def test_custom_basic_op( + self, + *, + shape: Iterable[int] = (7, 5), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + ) -> None: + """Custom basic op""" + + class CustomScaleOp(te.ops.BasicOperation): + """Custom op that applies a learnable scale""" + + def __init__(self) -> None: + super().__init__() + self.scale: torch.nn.Parameter + scale = torch.ones((), dtype=dtype, device=device) + scale = torch.nn.Parameter(scale) + self.register_parameter("scale", scale) + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + ) -> torch.Tensor: + ctx.save_for_backward(self.scale, input_) + return self.scale * input_ + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> torch.Tensor: + scale, input_, = ctx.saved_tensors + grad_scale = torch.inner(input_.reshape(-1), grad_output.reshape(-1)) + grad_scale = grad_scale.reshape(()) + grad_input = scale * grad_output + return grad_input, (grad_scale,) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + (), + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain 
PyTorch implementation + y_ref = w_ref * x_ref + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = CustomScaleOp() + forward = te.ops.Sequential(te.ops.Identity(), op, te.ops.Identity()) + with torch.no_grad(): + op.scale.copy_(w_test) + del w_test + y_test = forward(x_test) + y_test.backward(dy_test) + + # Check results + tols = dtype_tols(dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = op.scale.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + + def test_custom_forward_fused_op( + self, + *, + shape: Iterable[int] = (7, 11), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + ): + """Custom fused op in forward pass""" + + class CustomForwardLinearSiLU(te.ops.FusedOperation): + """Custom fused op for GEMM + SiLU""" + + _enabled = True + + def __init__(self, *, linear, silu) -> None: + super().__init__((linear, silu)) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + **unused, + ) -> torch.Tensor: + # Do compute on CPU + x = input_.cpu() + w = self.basic_ops[0].weight.cpu() + y = torch.nn.functional.linear(x, w) + z = torch.nn.functional.silu(y) + + # Save state for linear backward + linear_op_ctx = basic_op_ctxs[0] + linear_op_ctx.save_for_backward(x.cuda(), w.cuda()) + linear_op_ctx.with_quantized_compute = False + linear_op_ctx.input_quantizer = None + linear_op_ctx.weight_quantizer = None + linear_op_ctx.grad_output_quantizer = None + linear_op_ctx.grad_input_quantizer = None + linear_op_ctx.dtype = w.dtype + linear_op_ctx.input_requires_grad = True + linear_op_ctx.weight_requires_grad = True + + # Save state for SiLU backward + silu_op_ctx = basic_op_ctxs[1] + silu_op_ctx.save_for_backward(y.cuda()) + silu_op_ctx.dtype = w.dtype + silu_op_ctx.prev_op_grad_output_quantizer = None + + return z.cuda(), [(), ()] + + @staticmethod + def fuse_ops( + ops: list[FusibleOperation], + **unused, + ) -> list[FusibleOperation]: + """Apply fusion the first time this function is called""" + if CustomForwardLinearSiLU._enabled: + CustomForwardLinearSiLU._enabled = False + op = CustomForwardLinearSiLU(linear=ops[0], silu=ops[1]) + return [op] + ops[2:] + return ops + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + (shape[-1], shape[-1]), + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w_ref) + y_ref = torch.nn.functional.silu(y_ref) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + te.ops.register_forward_fusion(CustomForwardLinearSiLU.fuse_ops) + model = te.ops.Sequential( + te.ops.Linear(shape[-1], shape[-1], bias=False), + te.ops.SiLU(), + ) + with torch.no_grad(): + model[0].weight.copy_(w_test) + del w_test + y_test = model(x_test) + y_test.backward(dy_test) + + # Check that forward operations have been fused + forward_ops = model._module_groups[0]._forward_ops + assert len(forward_ops) == 1 + assert isinstance(forward_ops[0][0], CustomForwardLinearSiLU) + + # Check results + tols = dtype_tols(dtype) + y_test 
= y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + + def test_custom_backward_fused_op( + self, + *, + shape: Iterable[int] = (13, 5), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + ): + """Custom fused op in backward pass""" + + class CustomBackwardLinearScale(te.ops.FusedOperation): + """Custom fused op for backward linear + scale""" + + _enabled: bool = True + + def __init__(self, *, scale, linear) -> None: + super().__init__((scale, linear)) + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + **unused, + ) -> torch.Tensor: + scale = self.basic_ops[0].scale + linear_op_ctx = basic_op_ctxs[1] + x, w = linear_op_ctx.saved_tensors + dy = grad_output + dx = torch.nn.functional.linear(dy, scale * w.T) + dw = torch.matmul(dy.T, x) + return dx, [(), (dw,)], [(), ()] + + @staticmethod + def fuse_ops( + ops: list[FusibleOperation], + **unused, + ) -> list[FusibleOperation]: + """Apply fusion the first time this function is called""" + if CustomBackwardLinearScale._enabled: + CustomBackwardLinearScale._enabled = False + op = CustomBackwardLinearScale(scale=ops[0], linear=ops[1]) + return [op] + ops[2:] + return ops + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + (shape[-1], shape[-1]), + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + scale = 1.234 + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(scale * x_ref, w_ref) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + te.ops.register_backward_fusion(CustomBackwardLinearScale.fuse_ops, prepend=True) + model = te.ops.Sequential( + te.ops.ConstantScale(scale), + te.ops.Linear(shape[-1], shape[-1], bias=False), + ) + with torch.no_grad(): + model[1].weight.copy_(w_test) + del w_test + y_test = model(x_test) + y_test.backward(dy_test) + + # Check that forward operations have been fused + backward_ops = model._module_groups[0]._backward_ops + assert len(backward_ops) == 1 + assert isinstance(backward_ops[0][0], CustomBackwardLinearScale) + + # Check results + tols = dtype_tols(dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) diff --git a/transformer_engine/pytorch/ops/__init__.py b/transformer_engine/pytorch/ops/__init__.py index 4f1d64623a..c61b50417d 100644 --- a/transformer_engine/pytorch/ops/__init__.py +++ b/transformer_engine/pytorch/ops/__init__.py @@ -11,5 +11,5 @@ from .basic import * from .fuser import register_backward_fusion, register_forward_fusion from .linear import Linear -from .op import FusibleOperation +from .op import BasicOperation, FusedOperation, FusibleOperation from .sequential import Sequential From 5f7204faf13228db2e51b69c487e59349a46c81a Mon Sep 17 
00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 08:29:45 +0000 Subject: [PATCH 09/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_fusible_ops.py | 5 ++++- .../pytorch/ops/fused/userbuffers_backward_linear.py | 10 ++-------- .../pytorch/ops/fused/userbuffers_forward_linear.py | 10 ++-------- 3 files changed, 8 insertions(+), 17 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 7e02731c94..3863614aa8 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2888,7 +2888,10 @@ def op_backward( ctx: OperationContext, grad_output: torch.Tensor, ) -> torch.Tensor: - scale, input_, = ctx.saved_tensors + ( + scale, + input_, + ) = ctx.saved_tensors grad_scale = torch.inner(input_.reshape(-1), grad_output.reshape(-1)) grad_scale = grad_scale.reshape(()) grad_input = scale * grad_output diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index c4efd370c3..a8b1c4df33 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -602,10 +602,7 @@ def fuse_backward_ops( """ # Return immediately if environment is not distributed - if ( - not torch.distributed.is_initialized() - or torch.distributed.get_world_size() == 1 - ): + if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: return ops # Scan through ops, fusing if possible @@ -632,10 +629,7 @@ def fuse_backward_ops( # Check if next op is reduce-scatter reduce_scatter = None - if ( - linear.tensor_parallel_mode is None - and ops and isinstance(ops[0], ReduceScatter) - ): + if linear.tensor_parallel_mode is None and ops and isinstance(ops[0], ReduceScatter): reduce_scatter, ops = ops[0], ops[1:] window.append(reduce_scatter) diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 48957d3da6..f3665ea10f 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -389,10 +389,7 @@ def fuse_forward_ops( """ # Return immediately if environment is not distributed - if ( - not torch.distributed.is_initialized() - or torch.distributed.get_world_size() == 1 - ): + if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: return ops # Scan through ops, fusing if possible @@ -419,10 +416,7 @@ def fuse_forward_ops( # Check if next op is reduce-scatter reduce_scatter = None - if ( - linear.tensor_parallel_mode is None - and ops and isinstance(ops[0], ReduceScatter) - ): + if linear.tensor_parallel_mode is None and ops and isinstance(ops[0], ReduceScatter): reduce_scatter, ops = ops[0], ops[1:] window.append(reduce_scatter) From 8ddb8ce1ff894187e3569ec2ae122ff2e4b31d1b Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 14 Jan 2026 20:01:34 +0000 Subject: [PATCH 10/45] Fix linter warnings and numerical test failures Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 35 ++++++++++++++----- .../ops/fused/backward_activation_bias.py | 2 +- .../pytorch/ops/fused/backward_add_rmsnorm.py | 2 +- .../pytorch/ops/fused/backward_linear_add.py | 2 +- .../ops/fused/backward_linear_scale.py 
| 2 +- .../fused/forward_linear_bias_activation.py | 2 +- .../ops/fused/forward_linear_bias_add.py | 2 +- .../ops/fused/forward_linear_scale_add.py | 2 +- .../ops/fused/userbuffers_backward_linear.py | 2 +- .../ops/fused/userbuffers_forward_linear.py | 2 +- transformer_engine/pytorch/ops/fuser.py | 6 ++-- 11 files changed, 39 insertions(+), 20 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 3863614aa8..a4b12223da 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2960,31 +2960,36 @@ def fuser_forward( input_: torch.Tensor, **unused, ) -> torch.Tensor: - # Do compute on CPU + weight = self.basic_ops[0].weight + dtype = weight.dtype + device = weight.device + + # Perform compute on CPU, because why not? x = input_.cpu() - w = self.basic_ops[0].weight.cpu() + w = weight.cpu() y = torch.nn.functional.linear(x, w) z = torch.nn.functional.silu(y) + out = z.to(device=device) # Save state for linear backward linear_op_ctx = basic_op_ctxs[0] - linear_op_ctx.save_for_backward(x.cuda(), w.cuda()) + linear_op_ctx.save_for_backward(input_, weight) linear_op_ctx.with_quantized_compute = False linear_op_ctx.input_quantizer = None linear_op_ctx.weight_quantizer = None linear_op_ctx.grad_output_quantizer = None linear_op_ctx.grad_input_quantizer = None - linear_op_ctx.dtype = w.dtype + linear_op_ctx.dtype = dtype linear_op_ctx.input_requires_grad = True linear_op_ctx.weight_requires_grad = True # Save state for SiLU backward silu_op_ctx = basic_op_ctxs[1] - silu_op_ctx.save_for_backward(y.cuda()) - silu_op_ctx.dtype = w.dtype + silu_op_ctx.save_for_backward(y.to(device=device)) + silu_op_ctx.dtype = dtype silu_op_ctx.prev_op_grad_output_quantizer = None - return z.cuda(), [(), ()] + return out, [(), ()] @staticmethod def fuse_ops( @@ -3070,12 +3075,24 @@ def fuser_backward( grad_output: torch.Tensor, **unused, ) -> torch.Tensor: - scale = self.basic_ops[0].scale + + # Load state from linear forward linear_op_ctx = basic_op_ctxs[1] x, w = linear_op_ctx.saved_tensors - dy = grad_output + dtype = linear_op_ctx.dtype + device = w.device + + # Perform compute in FP64 and apply scale before dgrad + # GEMM instead of after + scale = self.basic_ops[0].scale + dy = grad_output.double() + x = x.double() + w = w.double() dx = torch.nn.functional.linear(dy, scale * w.T) dw = torch.matmul(dy.T, x) + dx = dx.to(dtype=dtype) + dw = dw.to(dtype=dtype) + return dx, [(), (dw,)], [(), ()] @staticmethod diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 8fd0ac6cde..7cbc28e357 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -86,7 +86,7 @@ def fuse_backward_ops( ops: list[FusibleOperation], *, recipe: Optional[Recipe] = None, - **unused, + **unused, # pylint: disable=unused-argument ) -> list[FusibleOperation]: """Apply operation fusion for backward pass. 
diff --git a/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py b/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py index 2eaded064d..747edd6f98 100644 --- a/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py +++ b/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py @@ -82,7 +82,7 @@ def fuser_backward( @staticmethod def fuse_backward_ops( ops: list[FusibleOperation], - **unused, + **unused, # pylint: disable=unused-argument ) -> list[FusibleOperation]: """Apply operation fusion for backward pass. diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index e6307c254c..1e410b3999 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -114,7 +114,7 @@ def fuser_backward( @staticmethod def fuse_backward_ops( ops: list[FusibleOperation], - **unused, + **unused, # pylint: disable=unused-argument ) -> list[FusibleOperation]: """Apply operation fusion for backward pass. diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py index 2fd72fb963..4474b55cea 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py @@ -114,7 +114,7 @@ def fuser_backward( @staticmethod def fuse_backward_ops( ops: list[FusibleOperation], - **unused, + **unused, # pylint: disable=unused-argument ) -> list[FusibleOperation]: """Apply operation fusion for backward pass. diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 8e602e4cc2..63dc05b22a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -137,7 +137,7 @@ def fuser_forward( @staticmethod def fuse_forward_ops( ops: list[FusibleOperation], - **unused, + **unused, # pylint: disable=unused-argument ) -> list[FusibleOperation]: """Apply operation fusion for forward pass. diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index cc9554c876..2dfc0566b7 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -134,7 +134,7 @@ def fuser_forward( @staticmethod def fuse_forward_ops( ops: list[FusibleOperation], - **unused, + **unused, # pylint: disable=unused-argument ) -> list[FusibleOperation]: """Apply operation fusion for forward pass. diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 3cf9f294c3..8088c61ec2 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -113,7 +113,7 @@ def fuser_forward( @staticmethod def fuse_forward_ops( ops: list[FusibleOperation], - **unused, + **unused, # pylint: disable=unused-argument ) -> list[FusibleOperation]: """Apply operation fusion for forward pass. 
diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index a8b1c4df33..077f2758cd 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -583,7 +583,7 @@ def fuser_backward( @staticmethod def fuse_backward_ops( ops: list[FusibleOperation], - **unused, + **unused, # pylint: disable=unused-argument ) -> list[FusibleOperation]: """Apply operation fusion for backward pass. diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index f3665ea10f..6ef9bf083b 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -372,7 +372,7 @@ def fuser_forward( @staticmethod def fuse_forward_ops( ops: list[FusibleOperation], - **unused, + **unused, # pylint: disable=unused-argument ) -> list[FusibleOperation]: """Apply operation fusion for forward pass. diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index e3eabfa575..7fe6ea37ed 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -7,7 +7,7 @@ from __future__ import annotations from collections.abc import Callable, Iterable, Sequence import itertools -from typing import Any, Optional +from typing import Any, Optional, TypeAlias import torch @@ -45,7 +45,9 @@ def _is_graph_capturing() -> bool: # Type alias for a function that may perform operation fusion -type OperationFusionFunction = Callable[tuple[list[FusibleOperation], ...], list[FusibleOperation]] +OperationFusionFunction: TypeAlias = ( + "Callable[tuple[list[FusibleOperation], ...], list[FusibleOperation]]" +) class _OperationFuserAutogradFunction(torch.autograd.Function): From cfc2617ca06204d53ee553ef2c6778e3db968ba0 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 15 Jan 2026 00:17:46 +0000 Subject: [PATCH 11/45] Tweak pattern matching logic with fixed window sizes Signed-off-by: Tim Moon --- .../ops/fused/backward_activation_bias.py | 32 +++++++++++-------- .../pytorch/ops/fused/backward_add_rmsnorm.py | 29 +++++++++-------- .../pytorch/ops/fused/backward_linear_add.py | 28 ++++++++-------- .../ops/fused/backward_linear_scale.py | 28 ++++++++-------- .../fused/forward_linear_bias_activation.py | 28 ++++++++-------- .../ops/fused/forward_linear_scale_add.py | 28 ++++++++-------- 6 files changed, 93 insertions(+), 80 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 7cbc28e357..0b6dee7e73 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -110,26 +110,30 @@ def fuse_backward_ops( # Scan through ops, fusing if possible out = [] - window = [] - while ops: - - # Shift window - while len(window) >= 3: - out.append(window[0]) - window = window[1:] - while ops and len(window) < 3: - window.append(ops[0]) - ops = ops[1:] - - # Construct fused op if window matches pattern + window, ops = ops[:3], ops[3:] + while len(window) == 3: if ( - len(window) == 3 - and isinstance(window[2], _fusible_activations) + isinstance(window[2], _fusible_activations) and isinstance(window[1], Bias) and 
window[0].get_grad_output_quantizer() is not None ): + # Construct fused op if window matches pattern op = BackwardActivationBias(bias=window[1], activation=window[2]) window = [window[0], op] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-2]) + window = window[-2:] + + # Adjust window to expected size + out.extend(window[:-3]) + window = window[-3:] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] # Return list of ops out.extend(window) diff --git a/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py b/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py index 747edd6f98..a3c81e60c8 100644 --- a/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py +++ b/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py @@ -100,26 +100,27 @@ def fuse_backward_ops( # Scan through ops, fusing if possible out = [] - window = [] - while ops: - - # Shift window - while len(window) >= 2: - out.append(window[0]) - window = window[1:] - while ops and len(window) < 2: - window.append(ops[0]) - ops = ops[1:] - - # Construct fused op if window matches pattern + window, ops = ops[:2], ops[2:] + while len(window) == 2: if ( - len(window) == 2 - and isinstance(window[0], MakeExtraOutput) + isinstance(window[0], MakeExtraOutput) and isinstance(window[1], RMSNorm) and not window[0]._in_place ): + # Construct fused op if window matches pattern op = BackwardAddRMSNorm(add=window[0], rmsnorm=window[1]) window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-1]) + window = window[-1:] + + # Adjust window to expected size + out.extend(window[:-2]) + window = window[-2:] + while ops and len(window) < 2: + window.append(ops[0]) + ops = ops[1:] # Return list of ops out.extend(window) diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 1e410b3999..610e936f35 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -132,22 +132,13 @@ def fuse_backward_ops( # Scan through ops, fusing if possible out = [] - window = [] - while ops: - - # Shift window - while len(window) >= 2: - out.append(window[0]) - window = window[1:] - while ops and len(window) < 2: - window.append(ops[0]) - ops = ops[1:] + window, ops = ops[:2], ops[2:] + while len(window) == 2: # Check if window matches pattern matches_pattern = True if not ( - len(window) == 2 - and isinstance(window[0], MakeExtraOutput) + isinstance(window[0], MakeExtraOutput) and isinstance(window[1], BasicLinear) ): matches_pattern = False @@ -159,10 +150,21 @@ def fuse_backward_ops( # after the dgrad GEMM matches_pattern = False - # Construct fused op if window matches pattern if matches_pattern: + # Construct fused op if window matches pattern op = BackwardLinearAdd(backward_add=window[0], linear=window[1]) window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-1]) + window = window[-1:] + + # Adjust window to expected size + out.extend(window[:-2]) + window = window[-2:] + while ops and len(window) < 2: + window.append(ops[0]) + ops = ops[1:] # Return list of ops out.extend(window) diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py index 4474b55cea..2b417d1ccf 100644 --- 
a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py @@ -132,22 +132,13 @@ def fuse_backward_ops( # Scan through ops, fusing if possible out = [] - window = [] - while ops: - - # Shift window - while len(window) >= 2: - out.append(window[0]) - window = window[1:] - while ops and len(window) < 2: - window.append(ops[0]) - ops = ops[1:] + window, ops = ops[:2], ops[2:] + while len(window) == 2: # Check if window matches pattern matches_pattern = True if not ( - len(window) == 2 - and isinstance(window[0], BasicLinear) + isinstance(window[0], BasicLinear) and isinstance(window[1], ConstantScale) ): matches_pattern = False @@ -156,10 +147,21 @@ def fuse_backward_ops( # after the dgrad GEMM matches_pattern = False - # Construct fused op if window matches pattern if matches_pattern: + # Construct fused op if window matches pattern op = BackwardLinearScale(linear=window[0], scale=window[1]) window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-1]) + window = window[-1:] + + # Adjust window to expected size + out.extend(window[:-2]) + window = window[-2:] + while ops and len(window) < 2: + window.append(ops[0]) + ops = ops[1:] # Return list of ops out.extend(window) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 63dc05b22a..d383a0739e 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -155,22 +155,13 @@ def fuse_forward_ops( # Scan through ops, fusing if possible out = [] - window = [] - while ops: - - # Shift window - while len(window) >= 2: - out.append(window[0]) - window = window[1:] - while ops and len(window) < 2: - window.append(ops[0]) - ops = ops[1:] + window, ops = ops[:2], ops[2:] + while len(window) == 2: # Check if window matches pattern matches_pattern = True if not ( - len(window) == 2 - and isinstance(window[0], BasicLinear) + isinstance(window[0], BasicLinear) and isinstance(window[1], Bias) ): matches_pattern = False @@ -183,14 +174,25 @@ def fuse_forward_ops( # FP16 and BF16 output matches_pattern = False - # Construct fused op if window matches pattern if matches_pattern: + # Construct fused op if window matches pattern op = ForwardLinearBiasActivation( linear=window[0], bias=window[1], activation=None, ) window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-1]) + window = window[-1:] + + # Adjust window to expected size + out.extend(window[:-2]) + window = window[-2:] + while ops and len(window) < 2: + window.append(ops[0]) + ops = ops[1:] # Return list of ops out.extend(window) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 8088c61ec2..ae4bdd4b19 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -131,22 +131,13 @@ def fuse_forward_ops( # Scan through ops, fusing if possible out = [] - window = [] - while ops: - - # Shift window - while len(window) >= 3: - out.append(window[0]) - window = window[1:] - while ops and len(window) < 3: - window.append(ops[0]) - ops = ops[1:] + window, ops = ops[:3], ops[3:] + while len(window) == 3: # Check if window matches pattern matches_pattern = True if not ( - 
len(window) == 3 - and isinstance(window[0], BasicLinear) + isinstance(window[0], BasicLinear) and isinstance(window[1], ConstantScale) and isinstance(window[2], AddExtraInput) ): @@ -159,14 +150,25 @@ def fuse_forward_ops( # Fused op accumulates output in-place matches_pattern = False - # Construct fused op if window matches pattern if matches_pattern: + # Construct fused op if window matches pattern op = ForwardLinearScaleAdd( linear=window[0], scale=window[1], add=window[2], ) window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-2]) + window = window[-2:] + + # Adjust window to expected size + out.extend(window[:-3]) + window = window[-3:] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] # Return list of ops out.extend(window) From 0ce5dfbc816adb6f36bcafb3bbf57001772e3568 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Jan 2026 00:18:38 +0000 Subject: [PATCH 12/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ops/fused/backward_linear_add.py | 5 +---- .../pytorch/ops/fused/backward_linear_scale.py | 5 +---- .../pytorch/ops/fused/forward_linear_bias_activation.py | 5 +---- 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 610e936f35..c06e212e87 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -137,10 +137,7 @@ def fuse_backward_ops( # Check if window matches pattern matches_pattern = True - if not ( - isinstance(window[0], MakeExtraOutput) - and isinstance(window[1], BasicLinear) - ): + if not (isinstance(window[0], MakeExtraOutput) and isinstance(window[1], BasicLinear)): matches_pattern = False elif not window[0]._in_place: # Fused op accumulates grad input in-place diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py index 2b417d1ccf..709073e6f8 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py @@ -137,10 +137,7 @@ def fuse_backward_ops( # Check if window matches pattern matches_pattern = True - if not ( - isinstance(window[0], BasicLinear) - and isinstance(window[1], ConstantScale) - ): + if not (isinstance(window[0], BasicLinear) and isinstance(window[1], ConstantScale)): matches_pattern = False elif window[0].tensor_parallel_mode == "column": # Column tensor-parallelism requires communication diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index d383a0739e..dfc11a19e7 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -160,10 +160,7 @@ def fuse_forward_ops( # Check if window matches pattern matches_pattern = True - if not ( - isinstance(window[0], BasicLinear) - and isinstance(window[1], Bias) - ): + if not (isinstance(window[0], BasicLinear) and isinstance(window[1], Bias)): matches_pattern = False elif window[0].tensor_parallel_mode == "row": # Row tensor-parallelism requires communication after From 
49929039f4e5afea8a4c944d970b2c6b30133e56 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 15 Jan 2026 21:48:43 +0000 Subject: [PATCH 13/45] Use TF32 tols in fused op tests Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index a4b12223da..b97e4cc3b7 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2967,7 +2967,7 @@ def fuser_forward( # Perform compute on CPU, because why not? x = input_.cpu() w = weight.cpu() - y = torch.nn.functional.linear(x, w) + y = torch.matmul(x, w.T) z = torch.nn.functional.silu(y) out = z.to(device=device) @@ -3043,8 +3043,12 @@ def fuse_ops( assert len(forward_ops) == 1 assert isinstance(forward_ops[0][0], CustomForwardLinearSiLU) - # Check results + # Expected numerical error tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + + # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu") @@ -3088,7 +3092,7 @@ def fuser_backward( dy = grad_output.double() x = x.double() w = w.double() - dx = torch.nn.functional.linear(dy, scale * w.T) + dx = torch.matmul(dy, scale * w) dw = torch.matmul(dy.T, x) dx = dx.to(dtype=dtype) dw = dw.to(dtype=dtype) @@ -3147,8 +3151,12 @@ def fuse_ops( assert len(backward_ops) == 1 assert isinstance(backward_ops[0][0], CustomBackwardLinearScale) - # Check results + # Expected numerical error tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + + # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu") From 9ab77518e68fa6996a6fe4949fda0247af860938 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 15 Jan 2026 21:52:13 +0000 Subject: [PATCH 14/45] Review suggestion from @greptile-apps Signed-off-by: Tim Moon --- .../pytorch/ops/fused/backward_activation_bias.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 0b6dee7e73..4ab082d32b 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -120,9 +120,6 @@ def fuse_backward_ops( # Construct fused op if window matches pattern op = BackwardActivationBias(bias=window[1], activation=window[2]) window = [window[0], op] - while ops and len(window) < 3: - window.append(ops[0]) - ops = ops[1:] else: # Shift window if window doesn't match pattern out.extend(window[:-2]) From 9348138111c03202b82857ba944b7ce24306180e Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 15 Jan 2026 22:56:16 +0000 Subject: [PATCH 15/45] Fix linter warnings Signed-off-by: Tim Moon --- .../pytorch/ops/basic/grouped_linear.py | 44 ++++++++++++++++--- .../pytorch/ops/basic/multiply_extra_input.py | 3 +- 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index d2a4b379e5..e90482b399 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ 
b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -5,7 +5,7 @@ """Fusible operation for bias.""" from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Callable, Iterable import contextlib import math from typing import Any, Optional @@ -14,13 +14,14 @@ import transformer_engine_torch as tex from ...cpp_extensions import general_grouped_gemm +from ...distributed import CudaRNGStatesTracker from ...module.base import ( _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD, get_dummy_wgrad, ) -from ...quantization import FP8GlobalStateManager +from ...quantization import FP8GlobalStateManager, Recipe from ...tensor import Quantizer from ...utils import ( canonicalize_device, @@ -33,6 +34,40 @@ class GroupedLinear(BasicOperation): + """Apply multiple linear transformations: :math:`y_i = x_i W_i^T + b_i` + + This is equivalent to splitting the input tensor along its first + dimension, applying a separate ``torch.nn.Linear`` to each split, + and concatenating along the first dimension. + + Parameters + ---------- + group_size : int + Number of linear transformations. + in_features : int + Inner dimension of input tensor. + out_features : int + Inner dimension of output tensor. + bias : bool, default = ``True`` + Apply additive bias. + device : torch.device, default = default CUDA device + Tensor device. + dtype : torch.dtype, default = default dtype + Tensor datatype. + rng_state_tracker_function : callable + Function that returns ``CudaRNGStatesTracker``, which is used + for model-parallel weight initialization. + accumulate_into_main_grad : bool, default = ``False`` + Whether to directly accumulate weight gradients into the + weight's ``main_grad`` attribute instead of relying on PyTorch + autograd. The weight's ``main_grad`` must be set externally + and there is no guarantee that ``grad`` will be set or be + meaningful. This is primarily intended to integrate with + Megatron-LM. If this argument is set and the weight tensor has + the ``overwrite_main_grad`` attribute set to ``True``, + ``main_grad`` is overwritten instead of accumulated. + + """ # Operation expects input split sizes num_extra_inputs: int = 1 @@ -120,6 +155,7 @@ def num_quantizers(self, mode: str) -> int: @property def has_bias(self) -> bool: + """Whether an additive bias is being applied""" return self.bias0 is not None def reset_parameters(self) -> None: @@ -216,7 +252,7 @@ def pre_first_fuser_forward(self) -> None: f"Weight {group_idx} has requires_grad={weight.requires_grad}, " f"but expected requires_grad={weight_requires_grad}."
) - if type(weight.data) != weight_tensor_type: + if type(weight.data) != weight_tensor_type: # pylint: disable=unidiomatic-typecheck raise RuntimeError( f"Weight {group_idx} has invalid tensor type " f"(expected {weight_tensor_type.__name__}, " @@ -364,7 +400,6 @@ def fuser_forward( dtype = self.weight0.dtype # Extract split sizes from extra input - # TODO Support splits on GPU split_sizes = basic_op_extra_inputs[0][0] split_sizes_int = [int(s) for s in split_sizes.tolist()] if len(split_sizes_int) != group_size: @@ -472,7 +507,6 @@ def fuser_backward( ws, saved_tensors = saved_tensors[:group_size], saved_tensors[group_size:] # Split grad output tensor and convert dtypes if needed - # TODO Support splits on GPU split_sizes_int = [int(s) for s in split_sizes.tolist()] dy = maybe_dequantize(grad_output, ctx.dtype) dys = None diff --git a/transformer_engine/pytorch/ops/basic/multiply_extra_input.py b/transformer_engine/pytorch/ops/basic/multiply_extra_input.py index c1846f5e0d..1209963872 100644 --- a/transformer_engine/pytorch/ops/basic/multiply_extra_input.py +++ b/transformer_engine/pytorch/ops/basic/multiply_extra_input.py @@ -10,6 +10,7 @@ import torch +from ...tensor import Quantizer from ..op import BasicOperation, OperationContext from .._common import maybe_dequantize @@ -39,7 +40,7 @@ def _reduce_broadcast_dims( "Invalid target shape " f"(shape={shape} cannot be broadcast to shape={target_shape})." ) - elif len(shape) > len(target_shape): + if len(shape) > len(target_shape): reduce_dims.extend(range(len(shape) - len(target_shape))) for idx in range(-len(target_shape), 0): if shape[idx] == target_shape[idx]: From 536672976b0994c6d2e6689936cbfcb9c1bea8aa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Jan 2026 23:02:25 +0000 Subject: [PATCH 16/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_fusible_ops.py | 12 +++--- .../pytorch/ops/basic/grouped_linear.py | 39 ++++++++----------- .../pytorch/ops/basic/multiply_extra_input.py | 6 +-- 3 files changed, 24 insertions(+), 33 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index dfb92c6863..a39d4e521a 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2075,12 +2075,12 @@ def test_grouped_linear( @pytest.mark.parametrize( "input_shape,extra_input_shape", ( - ((3,4,5), (3,4,5)), - ((6,7), ()), - ((), (8,9)), - ((10,11,12), (11,1)), - ((1,15), (13,14,15)), - ) + ((3, 4, 5), (3, 4, 5)), + ((6, 7), ()), + ((), (8, 9)), + ((10, 11, 12), (11, 1)), + ((1, 15), (13, 14, 15)), + ), ) @pytest.mark.parametrize("input_requires_grad", (False, True)) @pytest.mark.parametrize("extra_input_requires_grad", (False, True)) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index e90482b399..b7a6b843e4 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -239,8 +239,7 @@ def pre_first_fuser_forward(self) -> None: weight = getattr(self, f"weight{group_idx}") if weight.dtype != dtype: raise RuntimeError( - f"Weight {group_idx} has invalid dtype " - f"(expected {dtype}, got {weight.dtype})." + f"Weight {group_idx} has invalid dtype (expected {dtype}, got {weight.dtype})." 
) if not devices_match(weight.device, device): raise RuntimeError( @@ -264,13 +263,10 @@ def pre_first_fuser_forward(self) -> None: bias = getattr(self, f"bias{group_idx}") if self.has_bias: if bias is None: - raise RuntimeError( - f"Expected biases, but bias {group_idx} is uninitialized" - ) + raise RuntimeError(f"Expected biases, but bias {group_idx} is uninitialized") if bias.dtype != dtype: raise RuntimeError( - f"Bias {group_idx} has invalid dtype " - f"(expected {dtype}, got {bias.dtype})." + f"Bias {group_idx} has invalid dtype (expected {dtype}, got {bias.dtype})." ) if not devices_match(bias.device, device): raise RuntimeError( @@ -284,9 +280,7 @@ def pre_first_fuser_forward(self) -> None: ) else: if bias is not None: - raise RuntimeError( - f"Expected no biases, but bias {group_idx} is initialized" - ) + raise RuntimeError(f"Expected no biases, but bias {group_idx} is initialized") def pre_fuser_forward(self, *, requires_grad: bool) -> None: super().pre_fuser_forward(requires_grad=requires_grad) @@ -345,8 +339,12 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon - grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale - grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon + grad_output_quantizer.force_pow_2_scales = ( + recipe.fp8_quant_bwd_grad.power_2_scale + ) + grad_output_quantizer.amax_epsilon_scales = ( + recipe.fp8_quant_bwd_grad.amax_epsilon + ) def op_forward(self, *args, **kwargs): raise RuntimeError( @@ -403,18 +401,13 @@ def fuser_forward( split_sizes = basic_op_extra_inputs[0][0] split_sizes_int = [int(s) for s in split_sizes.tolist()] if len(split_sizes_int) != group_size: - raise ValueError( - f"Expected {group_size} splits, but got {len(split_sizes_int)}." 
- ) + raise ValueError(f"Expected {group_size} splits, but got {len(split_sizes_int)}.") # Extract params weights = [getattr(self, f"weight{idx}") for idx in range(group_size)] bs = None if has_bias: - bs = [ - maybe_dequantize(getattr(self, f"bias{idx}"), dtype) - for idx in range(group_size) - ] + bs = [maybe_dequantize(getattr(self, f"bias{idx}"), dtype) for idx in range(group_size)] # Convert weight dtype if needed ws = [] @@ -526,9 +519,7 @@ def fuser_backward( else: dys = torch.split(dy, split_sizes_int) if has_bias: - grad_biases = [ - dy.reshape(-1, dy.size(-1)).sum(dim=0) for dy in dys - ] + grad_biases = [dy.reshape(-1, dy.size(-1)).sum(dim=0) for dy in dys] # Initialize grad weight grads accumulate_into_main_grad = self._accumulate_into_main_grad @@ -542,7 +533,9 @@ def fuser_backward( weight_param = getattr(self, f"weight{group_idx}") if hasattr(weight_param, "__fsdp_param__"): weight_param.main_grad = weight_param.get_main_grad() - accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) + accumulate_into_main_grad = not getattr( + weight_param, "overwrite_main_grad", False + ) grad_weights[group_idx] = weight_param.main_grad else: weight_shape = ws[0].size() diff --git a/transformer_engine/pytorch/ops/basic/multiply_extra_input.py b/transformer_engine/pytorch/ops/basic/multiply_extra_input.py index 1209963872..f9dfef4d81 100644 --- a/transformer_engine/pytorch/ops/basic/multiply_extra_input.py +++ b/transformer_engine/pytorch/ops/basic/multiply_extra_input.py @@ -37,8 +37,7 @@ def _reduce_broadcast_dims( reduce_dims = [] if len(shape) < len(target_shape): raise ValueError( - "Invalid target shape " - f"(shape={shape} cannot be broadcast to shape={target_shape})." + f"Invalid target shape (shape={shape} cannot be broadcast to shape={target_shape})." ) if len(shape) > len(target_shape): reduce_dims.extend(range(len(shape) - len(target_shape))) @@ -47,8 +46,7 @@ def _reduce_broadcast_dims( pass elif target_shape[idx] != 1: raise ValueError( - "Invalid target shape " - f"(shape={shape} cannot be broadcast to shape={target_shape})." + f"Invalid target shape (shape={shape} cannot be broadcast to shape={target_shape})." 
) else: reduce_dims.append(idx) From 321646ec25db774eb97450ef926d6bdf1d0a7468 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 16 Jan 2026 03:17:33 +0000 Subject: [PATCH 17/45] Initial impl of fused op for grouped MLP Signed-off-by: Tim Moon --- transformer_engine/pytorch/ops/__init__.py | 2 + .../pytorch/ops/fused/__init__.py | 2 + .../pytorch/ops/fused/forward_grouped_mlp.py | 296 ++++++++++++++++++ 3 files changed, 300 insertions(+) create mode 100644 transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py diff --git a/transformer_engine/pytorch/ops/__init__.py b/transformer_engine/pytorch/ops/__init__.py index c61b50417d..cd37144dfa 100644 --- a/transformer_engine/pytorch/ops/__init__.py +++ b/transformer_engine/pytorch/ops/__init__.py @@ -13,3 +13,5 @@ from .linear import Linear from .op import BasicOperation, FusedOperation, FusibleOperation from .sequential import Sequential + +import transformer_engine.pytorch.ops.fused diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 1ebfe23060..691b23cf77 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -14,8 +14,10 @@ from .forward_linear_scale_add import ForwardLinearScaleAdd from .userbuffers_backward_linear import UserbuffersBackwardLinear from .userbuffers_forward_linear import UserbuffersForwardLinear +from .forward_grouped_mlp import ForwardGroupedMLP_CuTeGEMMSwiGLU # Register forward fusions +register_forward_fusion(ForwardGroupedMLP_CuTeGEMMSwiGLU.fuse_forward_ops) register_forward_fusion(UserbuffersForwardLinear.fuse_forward_ops) register_forward_fusion(ForwardLinearBiasAdd.fuse_forward_ops) register_forward_fusion(ForwardLinearBiasActivation.fuse_forward_ops) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py new file mode 100644 index 0000000000..078554ea01 --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -0,0 +1,296 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Fused operation for forward GEMM + scale + add.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +import transformer_engine_torch as tex +from ...cpp_extensions import general_grouped_gemm +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...quantization import FP8GlobalStateManager +from ...tensor import Quantizer +from ..basic import GroupedLinear, MultiplyExtraInput, SwiGLU +from ..fuser import register_forward_fusion +from ..op import FusedOperation, FusibleOperation, OperationContext +from .._common import is_quantized_tensor, maybe_dequantize + + +class ForwardGroupedMLP_CuTeGEMMSwiGLU(FusedOperation): + + def __init__( + self, + *, + fc1: GroupedLinear, + swiglu: SwiGLU, + fc2: GroupedLinear, + scale: MultiplyExtraInput, + ) -> None: + super().__init__((fc1, swiglu, fc2, scale)) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + + # Get basic operations + fc1_op, swiglu_op, fc2_op, scale_op = self.basic_ops + fc1_ctx, swiglu_ctx, fc2_ctx, scale_ctx = basic_op_ctxs + + group_size = fc1_op.group_size + device = fc1_op.weight0.device + in_shape = list(input_.size()) + + # Check which grads are required + requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs) + input_requires_grad = requires_grad + weight_requires_grad = ( + requires_grad + and (fc1_op.weight0.requires_grad or fc2_op.weight0.requires_grad) + ) + + # Quantizers + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + fc1_input_quantizers = [None] * group_size + fc1_weight_quantizers = [None] * group_size + fc1_grad_output_quantizers = [None] * group_size + fc2_input_quantizers = [None] * group_size + fc2_weight_quantizers = [None] * group_size + fc2_grad_output_quantizers = [None] * group_size + if with_quantized_compute: + for idx in range(group_size): + fc1_input_quantizers[idx] = fc1_op.get_quantizer("forward", 2 * group_idx) + fc1_weight_quantizers[idx] = fc1_op.get_quantizer("forward", 2 * group_idx + 1) + fc1_grad_output_quantizers[idx] = fc1_op.get_quantizer("backward", group_idx) + fc2_input_quantizers[idx] = fc2_op.get_quantizer("forward", 2 * group_idx) + fc2_weight_quantizers[idx] = fc2_op.get_quantizer("forward", 2 * group_idx + 1) + fc2_grad_output_quantizers[idx] = fc2_op.get_quantizer("backward", group_idx) + + # Get autocast dtype if needed + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = fc1_op.weight0.dtype + + # Extract split sizes from extra input + fc1_split_sizes = basic_op_extra_inputs[0][0] + fc2_split_sizes = basic_op_extra_inputs[2][0] + if ( + fc1_split_sizes.size() != fc2_split_sizes.size() + or fc1_split_sizes.data_ptr() != fc2_split_sizes.data_ptr() + ): + raise RuntimeError( + f"{self.__class__.__name__} got different split points for FC1 and FC2." 
+ ) + split_sizes = fc1_split_sizes + split_sizes_int = [int(s) for s in split_sizes.tolist()] + if len(split_sizes_int) != group_size: + raise ValueError(f"Expected {group_size} splits, but got {len(split_sizes_int)}.") + + # Extract params + fc1_weights = [getattr(fc1_op, f"weight{idx}") for idx in range(group_size)] + fc2_weights = [getattr(fc2_op, f"weight{idx}") for idx in range(group_size)] + + # Convert weight dtype if needed + fc1_ws = [] + fc2_ws = [] + for w, quantizer in zip(fc1_weights, fc1_weight_quantizers): + if not with_quantized_compute: + w = maybe_dequantize(w, dtype) + elif with_quantized_compute and not is_quantized_tensor(w): + quantizer = weight_quantizers[group_idx] + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + w = quantizer(w) + fc1_ws.append(w) + for w, quantizer in zip(fc2_weights, fc2_weight_quantizers): + if not with_quantized_compute: + w = maybe_dequantize(w, dtype) + elif with_quantized_compute and not is_quantized_tensor(w): + quantizer = weight_quantizers[group_idx] + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + w = quantizer(w) + fc2_ws.append(w) + + # Split input tensor and convert dtypes if needed + fc1_x = maybe_dequantize(input_, dtype) + fc1_xs = None + if with_quantized_compute: + for quantizer in fc1_input_quantizers: + quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + fc1_xs = tex.split_quantize(fc1_x, split_sizes_int, fc1_input_quantizers) + else: + fc1_xs = torch.split(fc1_x, split_sizes_int) + + # FC1 GEMM + fc1_out_shape = in_shape[:-1] + [fc1_op.out_features] + fc1_out = torch.empty(fc1_out_shape, dtype=dtype, device=device) + general_grouped_gemm( + fc1_ws, + fc1_xs, + [fc1_out], + [None] * group_size, # quantization_params + dtype, + m_splits=split_sizes_int, + bias=[None] * group_size, + use_bias=False, + single_output=True, + ) + + # SwiGLU + swiglu_in = fc1_out + swiglu_out = tex.swiglu(swiglu_in, None) + + # Split input tensor and convert dtypes if needed + fc2_x = swiglu_out + fc2_xs = None + if with_quantized_compute: + for quantizer in fc2_input_quantizers: + quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + fc2_xs = tex.split_quantize(fc2_x, split_sizes_int, fc2_input_quantizers) + else: + fc2_xs = torch.split(fc2_x, split_sizes_int) + + # FC2 GEMM + fc2_out_shape = in_shape[:-1] + [fc2_op.out_features] + fc2_out = torch.empty(fc2_out_shape, dtype=dtype, device=device) + general_grouped_gemm( + fc2_ws, + fc2_xs, + [fc2_out], + [None] * group_size, # quantization_params + dtype, + m_splits=split_sizes_int, + bias=[None] * group_size, + use_bias=False, + single_output=True, + ) + + # Post-scale + scales = basic_op_extra_inputs[3][0] + scales_shape = tuple(scales.size()) + if scales.numel() != scales_shape[0]: + raise RuntimeError( + f"{self.__class__.__name__} assumes scales are over leading dim, " + f"but got shape={scales_shape}." 
+ ) + out = fc2_out * scales + + # Save state for backward pass + if requires_grad: + # FC1 state + fc1_ctx.save_for_backward(split_sizes, *fc1_xs, *fc1_ws) + fc1_ctx.with_quantized_compute = with_quantized_compute + fc1_ctx.input_quantizers = fc1_input_quantizers + fc1_ctx.weight_quantizers = fc1_weight_quantizers + fc1_ctx.grad_output_quantizers = fc1_grad_output_quantizers + fc1_ctx.grad_input_quantizers = None + fc1_ctx.dtype = dtype + fc1_ctx.input_requires_grad = input_requires_grad + fc1_ctx.weight_requires_grad = weight_requires_grad + + # SwiGLU + swiglu_ctx.save_for_backward(swiglu_in) + swiglu_ctx.dtype = dtype + swiglu_ctx.prev_op_grad_output_quantizer = None + + # FC2 state + fc2_ctx.save_for_backward(split_sizes, *fc2_xs, *fc2_ws) + fc2_ctx.with_quantized_compute = with_quantized_compute + fc2_ctx.input_quantizers = fc2_input_quantizers + fc2_ctx.weight_quantizers = fc2_weight_quantizers + fc2_ctx.grad_output_quantizers = fc2_grad_output_quantizers + fc2_ctx.grad_input_quantizers = None + fc2_ctx.dtype = dtype + fc2_ctx.input_requires_grad = input_requires_grad + fc2_ctx.weight_requires_grad = weight_requires_grad + + # Scale + scale_ctx.save_for_backward(fc2_out, scales) + scale_ctx.input_shape = fc2_out.size() + scale_ctx.extra_input_shape = scales_shape + scale_ctx.input_requires_grad = True + scale_ctx.extra_input_requires_grad = scales.requires_grad + + return out, [(), (), (), ()] + + @staticmethod + def fuse_forward_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for forward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. + recipe : Recipe, optional + Quantization recipe. 
+ + Returns + ------- + ops : list of FusibleOperation + Updated forward pass operations + + """ + + # Check if recipe is supported + if recipe is not None: + return ops + + # Scan through ops, fusing if possible + out = [] + window, ops = ops[:4], ops[4:] + while len(window) == 4: + + # Check if window matches pattern + matches_pattern = True + if not ( + isinstance(window[0], GroupedLinear) + and isinstance(window[1], SwiGLU) + and isinstance(window[2], GroupedLinear) + and isinstance(window[3], MultiplyExtraInput) + ): + matches_pattern = False + elif window[0].has_bias or window[2].has_bias: + matches_pattern = False + elif window[0].group_size != window[2].group_size: + matches_pattern = False + + if matches_pattern: + # Construct fused op if window matches pattern + op = ForwardGroupedMLP_CuTeGEMMSwiGLU( + fc1=window[0], + swiglu=window[1], + fc2=window[2], + scale=window[3], + ) + window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-3]) + window = window[-3:] + + # Adjust window to expected size + out.extend(window[:-4]) + window = window[-4:] + while ops and len(window) < 4: + window.append(ops[0]) + ops = ops[1:] + + # Return list of ops + out.extend(window) + return out From e137451d002702b3e519c565769936c2169e8f7f Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 17 Jan 2026 02:45:26 +0000 Subject: [PATCH 18/45] Import group GEMM+SwiGLU kernel Signed-off-by: Tim Moon --- .../pytorch/ops/fused/__init__.py | 5 +- .../pytorch/ops/fused/forward_grouped_mlp.py | 156 +++++++++--------- 2 files changed, 82 insertions(+), 79 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 691b23cf77..cabd86442a 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -14,10 +14,8 @@ from .forward_linear_scale_add import ForwardLinearScaleAdd from .userbuffers_backward_linear import UserbuffersBackwardLinear from .userbuffers_forward_linear import UserbuffersForwardLinear -from .forward_grouped_mlp import ForwardGroupedMLP_CuTeGEMMSwiGLU # Register forward fusions -register_forward_fusion(ForwardGroupedMLP_CuTeGEMMSwiGLU.fuse_forward_ops) register_forward_fusion(UserbuffersForwardLinear.fuse_forward_ops) register_forward_fusion(ForwardLinearBiasAdd.fuse_forward_ops) register_forward_fusion(ForwardLinearBiasActivation.fuse_forward_ops) @@ -29,3 +27,6 @@ register_backward_fusion(BackwardLinearScale.fuse_backward_ops) register_backward_fusion(BackwardActivationBias.fuse_backward_ops) register_backward_fusion(BackwardAddRMSNorm.fuse_backward_ops) + +from .forward_grouped_mlp import fuse_forward_ops as forward_grouped_mlp_fuse_ops +register_forward_fusion(forward_grouped_mlp_fuse_ops, prepend=True) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 078554ea01..52dac035f7 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -9,6 +9,7 @@ from typing import Any, Optional import torch +from cudnn import grouped_gemm_swiglu_wrapper_sm100 ### TODO Check if available import transformer_engine_torch as tex from ...cpp_extensions import general_grouped_gemm @@ -21,7 +22,7 @@ from .._common import is_quantized_tensor, maybe_dequantize -class ForwardGroupedMLP_CuTeGEMMSwiGLU(FusedOperation): +class ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8(FusedOperation): def 
__init__( self, @@ -70,12 +71,12 @@ def fuser_forward( fc2_grad_output_quantizers = [None] * group_size if with_quantized_compute: for idx in range(group_size): - fc1_input_quantizers[idx] = fc1_op.get_quantizer("forward", 2 * group_idx) - fc1_weight_quantizers[idx] = fc1_op.get_quantizer("forward", 2 * group_idx + 1) - fc1_grad_output_quantizers[idx] = fc1_op.get_quantizer("backward", group_idx) - fc2_input_quantizers[idx] = fc2_op.get_quantizer("forward", 2 * group_idx) - fc2_weight_quantizers[idx] = fc2_op.get_quantizer("forward", 2 * group_idx + 1) - fc2_grad_output_quantizers[idx] = fc2_op.get_quantizer("backward", group_idx) + fc1_input_quantizers[idx] = fc1_op.get_quantizer("forward", 2 * idx) + fc1_weight_quantizers[idx] = fc1_op.get_quantizer("forward", 2 * idx + 1) + fc1_grad_output_quantizers[idx] = fc1_op.get_quantizer("backward", idx) + fc2_input_quantizers[idx] = fc2_op.get_quantizer("forward", 2 * idx) + fc2_weight_quantizers[idx] = fc2_op.get_quantizer("forward", 2 * idx + 1) + fc2_grad_output_quantizers[idx] = fc2_op.get_quantizer("backward", idx) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -224,73 +225,74 @@ def fuser_forward( return out, [(), (), (), ()] - @staticmethod - def fuse_forward_ops( - ops: list[FusibleOperation], - *, - recipe: Optional[Recipe] = None, - **unused, # pylint: disable=unused-argument - ) -> list[FusibleOperation]: - """Apply operation fusion for forward pass. - - Parameters - ---------- - ops : list of FusibleOperation - Forward pass operations. - recipe : Recipe, optional - Quantization recipe. - - Returns - ------- - ops : list of FusibleOperation - Updated forward pass operations - - """ - - # Check if recipe is supported - if recipe is not None: - return ops - - # Scan through ops, fusing if possible - out = [] - window, ops = ops[:4], ops[4:] - while len(window) == 4: - - # Check if window matches pattern - matches_pattern = True - if not ( - isinstance(window[0], GroupedLinear) - and isinstance(window[1], SwiGLU) - and isinstance(window[2], GroupedLinear) - and isinstance(window[3], MultiplyExtraInput) - ): - matches_pattern = False - elif window[0].has_bias or window[2].has_bias: - matches_pattern = False - elif window[0].group_size != window[2].group_size: - matches_pattern = False - - if matches_pattern: - # Construct fused op if window matches pattern - op = ForwardGroupedMLP_CuTeGEMMSwiGLU( - fc1=window[0], - swiglu=window[1], - fc2=window[2], - scale=window[3], - ) - window = [op] - else: - # Shift window if window doesn't match pattern - out.extend(window[:-3]) - window = window[-3:] - - # Adjust window to expected size - out.extend(window[:-4]) - window = window[-4:] - while ops and len(window) < 4: - window.append(ops[0]) - ops = ops[1:] - - # Return list of ops - out.extend(window) - return out +def fuse_forward_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument +) -> list[FusibleOperation]: + """Apply operation fusion for forward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. + recipe : Recipe, optional + Quantization recipe. 
+ + Returns + ------- + ops : list of FusibleOperation + Updated forward pass operations + + """ + + # Check if recipe is supported + if recipe is None: + return ops + if not recipe.mxfp8(): + return ops + + # Scan through ops, fusing if possible + out = [] + window, ops = ops[:4], ops[4:] + while len(window) == 4: + + # Check if window matches pattern + matches_pattern = True + if not ( + isinstance(window[0], GroupedLinear) + and isinstance(window[1], SwiGLU) + and isinstance(window[2], GroupedLinear) + and isinstance(window[3], MultiplyExtraInput) + ): + matches_pattern = False + elif window[0].has_bias or window[2].has_bias: + matches_pattern = False + elif window[0].group_size != window[2].group_size: + matches_pattern = False + + if matches_pattern: + # Construct fused op if window matches pattern + op = ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8( + fc1=window[0], + swiglu=window[1], + fc2=window[2], + scale=window[3], + ) + window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-3]) + window = window[-3:] + + # Adjust window to expected size + out.extend(window[:-4]) + window = window[-4:] + while ops and len(window) < 4: + window.append(ops[0]) + ops = ops[1:] + + # Return list of ops + out.extend(window) + return out From cb728bb545c3fbbfd7194d6db2b3cba20ccff035 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 20 Jan 2026 22:09:36 +0000 Subject: [PATCH 19/45] Add unit test for grouped MLP op Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 147 ++++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 9363c9d89a..e3300ddc4e 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -3143,6 +3143,153 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols) torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols) + def test_grouped_mlp( + self, + *, + dtype: torch.dtype = torch.bfloat16, + quantization: Optional[str] = "mxfp8", + device: torch.device = "cuda", + group_size: int = 4, + hidden_size: int = 256, + split_alignment: int = 256, + ) -> None: + """GroupedLinear + SwiGLU + GroupedLinear""" + + # Split sizes + split_sizes = [split_alignment * i for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device="cpu") + + # Make input shape + in_shape = (split_sizes.sum().item(), hidden_size) + out_shape = in_shape + + # Skip invalid configurations + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + if quantization != "mxfp8": + pytest.skip("Quantization scheme is not supported") + if dtype != torch.bfloat16: + pytest.skip("Non-quantized dtype must be BF16") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + probs_ref, probs_test = make_reference_and_test_tensors( + (in_shape[0], 1), + test_dtype=dtype, + test_device=device, + ) + fc1_ws_ref, fc1_ws_test = [], [] + fc2_ws_ref, fc2_ws_test = [], [] + for _ in range(group_size): + fc1_w_ref, fc1_w_test = make_reference_and_test_tensors( + (2 * hidden_size, hidden_size), + quantization=quantization, + 
test_dtype=dtype, + test_device=device, + ) + fc2_w_ref, fc2_w_test = make_reference_and_test_tensors( + (hidden_size, hidden_size), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + fc1_ws_ref.append(fc1_w_ref) + fc1_ws_test.append(fc1_w_test) + fc2_ws_ref.append(fc2_w_ref) + fc2_ws_test.append(fc2_w_test) + with torch.no_grad(): + for t in fc1_ws_ref + fc1_ws_test + fc2_ws_ref + fc2_ws_test: + t *= 1 / 64 + for t in (x_ref, x_test, dy_ref, dy_test): + t -= 0.5 + + # Reference implementation + xs = torch.split(x_ref, split_sizes.tolist()) + ys = [] + for x, fc1_w, fc2_w, prob in zip(xs, fc1_ws_ref, fc2_ws_ref, probs_ref): + x = torch.nn.functional.linear(x, fc1_w) + x1, x2 = x.chunk(2, dim=-1) + x = torch.nn.functional.silu(x1) * x2 + x = torch.nn.functional.linear(x, fc2_w) + x = x * prob + ys.append(x) + y_ref = torch.cat(ys) + y_ref.backward(dy_ref) + + # Construct operations + recipe = make_recipe(quantization) + with te.quantized_model_init(recipe=recipe): + fc1 = te_ops.GroupedLinear( + group_size, + hidden_size, + 2 * hidden_size, + bias=False, + device=device, + dtype=dtype, + ) + fc2 = te_ops.GroupedLinear( + group_size, + hidden_size, + hidden_size, + bias=False, + device=device, + dtype=dtype, + ) + module = te_ops.Sequential( + fc1, + te_ops.SwiGLU(), + fc2, + te_ops.MultiplyExtraInput() + ) + + # Copy weights + with torch.no_grad(): + for group_idx in range(group_size): + getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) + getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) + del fc1_ws_test, fc2_ws_test + + # Fuse ops and perform forward and backward pass + with te.autocast(recipe=recipe): + y_test = module(x_test, split_sizes, split_sizes, probs_test) + y_test.backward(dy_test) + + # Check that forward operations have been fused + forward_ops = module._module_groups[0]._forward_ops + assert len(forward_ops) == 1 + + def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + """Convert to FP64 CPU tensor""" + if tensor is None: + return None + out = tensor.detach().to(dtype=torch.float64, device="cpu") + out = out.requires_grad_(requires_grad=tensor.requires_grad) + return out + + # Check values + tols = {"rtol": 0.25, "atol": 0.5} # Loose tols for sanity checking + torch.testing.assert_close(to_cpu(y_test), y_ref, **tols) + torch.testing.assert_close(to_cpu(x_test.grad), x_ref.grad, **tols) + for group_idx in range(group_size): + fc1_dw_test = to_cpu(getattr(fc1, f"weight{group_idx}").grad) + fc1_dw_ref = fc1_ws_ref[group_idx].grad + fc2_dw_test = to_cpu(getattr(fc2, f"weight{group_idx}").grad) + fc2_dw_ref = fc2_ws_ref[group_idx].grad + torch.testing.assert_close(fc2_dw_test, fc2_dw_ref, **tols) + torch.testing.assert_close(fc1_dw_test, fc1_dw_ref, **tols) + class TestCustomOps: """Test with ops that are defined externally""" From e7459cc2f487ce2b8083f27fd7e020338149c6c5 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 21 Jan 2026 04:14:07 +0000 Subject: [PATCH 20/45] Call fused group GEMM + SwiGLU kernel Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 6 +- .../pytorch/ops/fused/forward_grouped_mlp.py | 238 +++++++++++------- 2 files changed, 153 insertions(+), 91 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index e3300ddc4e..cd03f8564a 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -3222,8 +3222,8 @@ def test_grouped_mlp( x = torch.nn.functional.linear(x, fc1_w) x1, x2 = x.chunk(2, dim=-1) x = 
torch.nn.functional.silu(x1) * x2 - x = torch.nn.functional.linear(x, fc2_w) x = x * prob + x = torch.nn.functional.linear(x, fc2_w) ys.append(x) y_ref = torch.cat(ys) y_ref.backward(dy_ref) @@ -3250,8 +3250,8 @@ def test_grouped_mlp( module = te_ops.Sequential( fc1, te_ops.SwiGLU(), + te_ops.MultiplyExtraInput(), fc2, - te_ops.MultiplyExtraInput() ) # Copy weights @@ -3263,7 +3263,7 @@ def test_grouped_mlp( # Fuse ops and perform forward and backward pass with te.autocast(recipe=recipe): - y_test = module(x_test, split_sizes, split_sizes, probs_test) + y_test = module(x_test, split_sizes, probs_test, split_sizes) y_test.backward(dy_test) # Check that forward operations have been fused diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 52dac035f7..bcfb0bb158 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -15,7 +15,7 @@ from ...cpp_extensions import general_grouped_gemm from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...quantization import FP8GlobalStateManager -from ...tensor import Quantizer +from ...tensor import MXFP8Tensor, Quantizer from ..basic import GroupedLinear, MultiplyExtraInput, SwiGLU from ..fuser import register_forward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext @@ -29,10 +29,10 @@ def __init__( *, fc1: GroupedLinear, swiglu: SwiGLU, - fc2: GroupedLinear, scale: MultiplyExtraInput, + fc2: GroupedLinear, ) -> None: - super().__init__((fc1, swiglu, fc2, scale)) + super().__init__((fc1, swiglu, scale, fc2)) def fuser_forward( self, @@ -46,12 +46,20 @@ def fuser_forward( ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: # Get basic operations - fc1_op, swiglu_op, fc2_op, scale_op = self.basic_ops - fc1_ctx, swiglu_ctx, fc2_ctx, scale_ctx = basic_op_ctxs + fc1_op, swiglu_op, scale_op, fc2_op = self.basic_ops + fc1_ctx, swiglu_ctx, scale_ctx, fc2_ctx = basic_op_ctxs + # Tensor properties + in_shape = list(input_.size()) + assert len(in_shape) == 2, f"Expected 2D input tensor, got shape={in_shape}." 
+ fc1_weight_shape = (fc1_op.out_features, fc1_op.in_features) + fc2_weight_shape = (fc2_op.out_features, fc2_op.in_features) group_size = fc1_op.group_size device = fc1_op.weight0.device - in_shape = list(input_.size()) + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = fc1_op.weight0.dtype # Check which grads are required requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs) @@ -62,31 +70,23 @@ def fuser_forward( ) # Quantizers - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() fc1_input_quantizers = [None] * group_size fc1_weight_quantizers = [None] * group_size fc1_grad_output_quantizers = [None] * group_size fc2_input_quantizers = [None] * group_size fc2_weight_quantizers = [None] * group_size fc2_grad_output_quantizers = [None] * group_size - if with_quantized_compute: - for idx in range(group_size): - fc1_input_quantizers[idx] = fc1_op.get_quantizer("forward", 2 * idx) - fc1_weight_quantizers[idx] = fc1_op.get_quantizer("forward", 2 * idx + 1) - fc1_grad_output_quantizers[idx] = fc1_op.get_quantizer("backward", idx) - fc2_input_quantizers[idx] = fc2_op.get_quantizer("forward", 2 * idx) - fc2_weight_quantizers[idx] = fc2_op.get_quantizer("forward", 2 * idx + 1) - fc2_grad_output_quantizers[idx] = fc2_op.get_quantizer("backward", idx) - - # Get autocast dtype if needed - if torch.is_autocast_enabled(): - dtype = torch.get_autocast_dtype("cuda") - else: - dtype = fc1_op.weight0.dtype + for idx in range(group_size): + fc1_input_quantizers[idx] = fc1_op.get_quantizer("forward", 2 * idx) + fc1_weight_quantizers[idx] = fc1_op.get_quantizer("forward", 2 * idx + 1) + fc1_grad_output_quantizers[idx] = fc1_op.get_quantizer("backward", idx) + fc2_input_quantizers[idx] = fc2_op.get_quantizer("forward", 2 * idx) + fc2_weight_quantizers[idx] = fc2_op.get_quantizer("forward", 2 * idx + 1) + fc2_grad_output_quantizers[idx] = fc2_op.get_quantizer("backward", idx) # Extract split sizes from extra input fc1_split_sizes = basic_op_extra_inputs[0][0] - fc2_split_sizes = basic_op_extra_inputs[2][0] + fc2_split_sizes = basic_op_extra_inputs[3][0] if ( fc1_split_sizes.size() != fc2_split_sizes.size() or fc1_split_sizes.data_ptr() != fc2_split_sizes.data_ptr() @@ -99,6 +99,15 @@ def fuser_forward( if len(split_sizes_int) != group_size: raise ValueError(f"Expected {group_size} splits, but got {len(split_sizes_int)}.") + # Extract post-scales from extra input + scales = basic_op_extra_inputs[2][0] + scales_shape = tuple(scales.size()) + if scales.numel() != scales_shape[0]: + raise RuntimeError( + f"{self.__class__.__name__} assumes scales are over leading dim, " + f"but got shape={scales_shape}." 
+ ) + # Extract params fc1_weights = [getattr(fc1_op, f"weight{idx}") for idx in range(group_size)] fc2_weights = [getattr(fc2_op, f"weight{idx}") for idx in range(group_size)] @@ -107,17 +116,14 @@ def fuser_forward( fc1_ws = [] fc2_ws = [] for w, quantizer in zip(fc1_weights, fc1_weight_quantizers): - if not with_quantized_compute: - w = maybe_dequantize(w, dtype) - elif with_quantized_compute and not is_quantized_tensor(w): + if not is_quantized_tensor(w): quantizer = weight_quantizers[group_idx] quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) w = quantizer(w) + tex.swizzle_scales_for_gemm_(w) fc1_ws.append(w) for w, quantizer in zip(fc2_weights, fc2_weight_quantizers): - if not with_quantized_compute: - w = maybe_dequantize(w, dtype) - elif with_quantized_compute and not is_quantized_tensor(w): + if not is_quantized_tensor(w): quantizer = weight_quantizers[group_idx] quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) w = quantizer(w) @@ -126,44 +132,110 @@ def fuser_forward( # Split input tensor and convert dtypes if needed fc1_x = maybe_dequantize(input_, dtype) fc1_xs = None - if with_quantized_compute: - for quantizer in fc1_input_quantizers: - quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) - fc1_xs = tex.split_quantize(fc1_x, split_sizes_int, fc1_input_quantizers) - else: - fc1_xs = torch.split(fc1_x, split_sizes_int) - - # FC1 GEMM - fc1_out_shape = in_shape[:-1] + [fc1_op.out_features] - fc1_out = torch.empty(fc1_out_shape, dtype=dtype, device=device) - general_grouped_gemm( - fc1_ws, - fc1_xs, - [fc1_out], - [None] * group_size, # quantization_params - dtype, - m_splits=split_sizes_int, - bias=[None] * group_size, - use_bias=False, - single_output=True, + for quantizer in fc1_input_quantizers: + quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + quantizer.optimize_for_gemm = True + fc1_xs = tex.split_quantize(fc1_x, split_sizes_int, fc1_input_quantizers) + + # Pack tensors + fc1_x_data = torch.cat([x._rowwise_data for x in fc1_xs]) + fc1_x_scales = torch.cat([x._rowwise_scale_inv for x in fc1_xs]) + fc1_w_data = torch.cat([w._rowwise_data for w in fc1_weights]) + fc1_w_scales = torch.cat([w._rowwise_scale_inv for w in fc1_weights]) + + # Reorder and reshape tensors + fc1_x_data = fc1_x_data.view(dtype=torch.float8_e4m3fn) + fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0) + fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e8m0fnu) + fc1_x_scales = fc1_x_scales.reshape( + 1, + in_shape[0] // 128, + in_shape[1] // 128, + 32, + 4, + 4, + ) + fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0) + fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) + fc1_w_data = fc1_w_data.permute(1, 2, 0) + fc1_w_scales = fc1_w_scales.view(dtype=torch.float8_e8m0fnu) + fc1_w_scales = fc1_w_scales.reshape( + group_size, + fc1_weight_shape[0] // 128, + fc1_weight_shape[1] // 128, + 32, + 4, + 4, + ) + fc1_w_scales = fc1_w_scales.permute(3, 4, 1, 5, 2, 0) + + # Kernel tile logic + tile_idx_to_expert_idx = [] + cta_tile_m = 256 ### TODO ? 
+ for group_idx in range(group_size): + num_tiles = split_sizes_int[group_idx] // cta_tile_m + tile_idx_to_expert_idx.extend([group_idx] * num_tiles) + num_non_exiting_tiles = torch.tensor([len(tile_idx_to_expert_idx)], device=device, dtype=torch.int32) + tile_idx_to_expert_idx = torch.tensor(tile_idx_to_expert_idx, device=device, dtype=torch.int32) + + # Fused kernel for FC1 + SwiGLU + post-scale + fc1_kernel_out = grouped_gemm_swiglu_wrapper_sm100( + fc1_x_data, + fc1_w_data, + fc1_x_scales, + fc1_w_scales, + tile_idx_to_expert_idx, + num_non_exiting_tiles, + torch.ones(group_size, dtype=dtype, device=device), # alpha_tensor + split_sizes_int, + torch.ones(1, dtype=dtype, device=device), # norm_const_tensor + scales.detach().reshape(-1, 1, 1), + acc_dtype=torch.float32, + c_dtype=torch.bfloat16, + d_dtype=torch.float8_e4m3fn, + cd_major="n", + cluster_shape_mn=(2, 1), + sf_vec_size=32, + sf_dtype=torch.float8_e8m0fnu, ) - # SwiGLU - swiglu_in = fc1_out - swiglu_out = tex.swiglu(swiglu_in, None) - - # Split input tensor and convert dtypes if needed - fc2_x = swiglu_out - fc2_xs = None - if with_quantized_compute: - for quantizer in fc2_input_quantizers: - quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) - fc2_xs = tex.split_quantize(fc2_x, split_sizes_int, fc2_input_quantizers) - else: - fc2_xs = torch.split(fc2_x, split_sizes_int) + # Extract kernel outputs and construct MXFP8 tensors + swiglu_in = fc1_kernel_out["c_tensor"] + swiglu_in = swiglu_in.reshape(in_shape[0], fc1_weight_shape[0]).contiguous() + fc2_in_row_data = fc1_kernel_out["d_tensor"] + fc2_in_row_data = fc2_in_row_data.reshape(in_shape[0], fc2_weight_shape[1]).contiguous() + fc2_in_row_data = torch.split(fc2_in_row_data, split_sizes_int) + fc2_in_row_scale = fc1_kernel_out["sfd_row_tensor"] + fc2_in_row_scale = fc2_in_row_scale.permute(5, 2, 0, 1, 4, 3).contiguous() ### TODO Preserve swizzling + fc2_in_row_scale = fc2_in_row_scale.reshape(in_shape[0], fc2_weight_shape[1] // 32) + fc2_in_row_scale = torch.split(fc2_in_row_scale, split_sizes_int) + fc2_in_col_data = fc1_kernel_out["d_col_tensor"] + fc2_in_col_data = fc2_in_col_data.reshape(in_shape[0], fc2_weight_shape[1]).contiguous() + fc2_in_col_data = torch.split(fc2_in_col_data, split_sizes_int) + fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"] + fc2_in_col_scale = fc2_in_col_scale.permute(5, 4, 3, 2, 0, 1).contiguous() ### TODO Preserve swizzling + fc2_in_col_scale = fc2_in_col_scale.reshape(in_shape[0] // 32, fc2_weight_shape[1]) + fc2_in_col_scale = torch.split(fc2_in_col_scale, [s // 32 for s in split_sizes_int]) + + # Construct MXFP8 tensors for FC2 + fc2_xs = [] + for group_idx in range(group_size): + x = MXFP8Tensor( + shape=(split_sizes_int[group_idx], fc2_weight_shape[1]), + dtype=dtype, + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise_data=fc2_in_row_data[group_idx], + rowwise_scale_inv=fc2_in_row_scale[group_idx], + columnwise_data=fc2_in_col_data[group_idx], + columnwise_scale_inv=fc2_in_col_scale[group_idx], + quantizer=fc2_input_quantizers[group_idx], + requires_grad=False, + with_gemm_swizzled_scales=False, + ) + fc2_xs.append(x) # FC2 GEMM - fc2_out_shape = in_shape[:-1] + [fc2_op.out_features] + fc2_out_shape = in_shape[:-1] + [fc2_weight_shape[0]] fc2_out = torch.empty(fc2_out_shape, dtype=dtype, device=device) general_grouped_gemm( fc2_ws, @@ -177,21 +249,11 @@ def fuser_forward( single_output=True, ) - # Post-scale - scales = basic_op_extra_inputs[3][0] - scales_shape = tuple(scales.size()) - if scales.numel() != scales_shape[0]: - 
raise RuntimeError( - f"{self.__class__.__name__} assumes scales are over leading dim, " - f"but got shape={scales_shape}." - ) - out = fc2_out * scales - # Save state for backward pass if requires_grad: # FC1 state fc1_ctx.save_for_backward(split_sizes, *fc1_xs, *fc1_ws) - fc1_ctx.with_quantized_compute = with_quantized_compute + fc1_ctx.with_quantized_compute = True fc1_ctx.input_quantizers = fc1_input_quantizers fc1_ctx.weight_quantizers = fc1_weight_quantizers fc1_ctx.grad_output_quantizers = fc1_grad_output_quantizers @@ -205,9 +267,16 @@ def fuser_forward( swiglu_ctx.dtype = dtype swiglu_ctx.prev_op_grad_output_quantizer = None + # Scale + scale_ctx.save_for_backward(fc2_out, scales) + scale_ctx.input_shape = fc2_out.size() + scale_ctx.extra_input_shape = scales_shape + scale_ctx.input_requires_grad = True + scale_ctx.extra_input_requires_grad = scales.requires_grad + # FC2 state fc2_ctx.save_for_backward(split_sizes, *fc2_xs, *fc2_ws) - fc2_ctx.with_quantized_compute = with_quantized_compute + fc2_ctx.with_quantized_compute = True fc2_ctx.input_quantizers = fc2_input_quantizers fc2_ctx.weight_quantizers = fc2_weight_quantizers fc2_ctx.grad_output_quantizers = fc2_grad_output_quantizers @@ -216,14 +285,7 @@ def fuser_forward( fc2_ctx.input_requires_grad = input_requires_grad fc2_ctx.weight_requires_grad = weight_requires_grad - # Scale - scale_ctx.save_for_backward(fc2_out, scales) - scale_ctx.input_shape = fc2_out.size() - scale_ctx.extra_input_shape = scales_shape - scale_ctx.input_requires_grad = True - scale_ctx.extra_input_requires_grad = scales.requires_grad - - return out, [(), (), (), ()] + return fc2_out, [(), (), (), ()] def fuse_forward_ops( ops: list[FusibleOperation], @@ -263,13 +325,13 @@ def fuse_forward_ops( if not ( isinstance(window[0], GroupedLinear) and isinstance(window[1], SwiGLU) - and isinstance(window[2], GroupedLinear) - and isinstance(window[3], MultiplyExtraInput) + and isinstance(window[2], MultiplyExtraInput) + and isinstance(window[3], GroupedLinear) ): matches_pattern = False - elif window[0].has_bias or window[2].has_bias: + elif window[0].has_bias or window[3].has_bias: matches_pattern = False - elif window[0].group_size != window[2].group_size: + elif window[0].group_size != window[3].group_size: matches_pattern = False if matches_pattern: @@ -277,8 +339,8 @@ def fuse_forward_ops( op = ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8( fc1=window[0], swiglu=window[1], - fc2=window[2], - scale=window[3], + scale=window[2], + fc2=window[3], ) window = [op] else: From b15ca0da63cf3b2e8ba2b047b6cac97550254170 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 21 Jan 2026 15:58:36 -0800 Subject: [PATCH 21/45] Debug test failures Test is too permissive since the test should still be failing. The weights are not properly interleaved yet. 
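For context, the layout the fused grouped GEMM + SwiGLU kernel expects interleaves the two GLU halves of each FC1 weight in 32-row blocks along the gated output dimension; the matching reshape/transpose on the activation side appears in the updated reference implementation in tests/pytorch/test_fusible_ops.py. A rough per-expert sketch of that interleaving is below (the helper name and standalone form are illustrative only; the gate/activation swap and the scale-factor tensors are handled in later patches):

    import torch

    def interleave_fc1_weight(w: torch.Tensor, block: int = 32) -> torch.Tensor:
        # w has shape (2 * hidden, in_features); rows [0:hidden] and
        # [hidden:2*hidden] are the two GLU halves.
        out_features, in_features = w.shape
        w = w.reshape(2, out_features // (2 * block), block, in_features)
        # Alternate 32-row blocks of the two halves along the output dim.
        w = w.transpose(0, 1).contiguous()
        return w.reshape(out_features, in_features)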
Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 4 +++- .../pytorch/ops/basic/grouped_linear.py | 7 ++++--- .../pytorch/ops/fused/forward_grouped_mlp.py | 19 +++++++++++++------ 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index cd03f8564a..a621103e2a 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -3211,9 +3211,11 @@ def test_grouped_mlp( fc2_ws_test.append(fc2_w_test) with torch.no_grad(): for t in fc1_ws_ref + fc1_ws_test + fc2_ws_ref + fc2_ws_test: - t *= 1 / 64 + t -= 0.5 + t *= 1 / 2 for t in (x_ref, x_test, dy_ref, dy_test): t -= 0.5 + t *= 1 / 2 # Reference implementation xs = torch.split(x_ref, split_sizes.tolist()) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index b7a6b843e4..ed8d5c6012 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -374,6 +374,7 @@ def fuser_forward( ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: group_size = self.group_size has_bias = self.has_bias + device = self.weight0.device # Check which grads are required ctx = basic_op_ctxs[0] @@ -433,7 +434,7 @@ def fuser_forward( # Allocate output tensor in_shape = list(input_.size()) out_shape = in_shape[:-1] + [self.out_features] - out = torch.empty(out_shape, dtype=dtype, device=input_.device) + out = torch.empty(out_shape, dtype=dtype, device=device) # Perform GEMMs general_grouped_gemm( @@ -491,6 +492,7 @@ def fuser_backward( ]: group_size = self.group_size has_bias = self.has_bias + device = self.weight0.device # Saved tensors from forward pass ctx = basic_op_ctxs[0] @@ -539,7 +541,6 @@ def fuser_backward( grad_weights[group_idx] = weight_param.main_grad else: weight_shape = ws[0].size() - device = grad_output.device for group_idx in range(group_size): grad_weights[group_idx] = torch.empty( weight_shape, @@ -557,7 +558,7 @@ def fuser_backward( grad_input = torch.empty( in_shape, dtype=ctx.dtype, - device=grad_output.device, + device=device, ) general_grouped_gemm( ws, diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index bcfb0bb158..fd9eb66c1f 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -137,15 +137,11 @@ def fuser_forward( quantizer.optimize_for_gemm = True fc1_xs = tex.split_quantize(fc1_x, split_sizes_int, fc1_input_quantizers) - # Pack tensors + # Pack data tensors fc1_x_data = torch.cat([x._rowwise_data for x in fc1_xs]) - fc1_x_scales = torch.cat([x._rowwise_scale_inv for x in fc1_xs]) - fc1_w_data = torch.cat([w._rowwise_data for w in fc1_weights]) - fc1_w_scales = torch.cat([w._rowwise_scale_inv for w in fc1_weights]) - - # Reorder and reshape tensors fc1_x_data = fc1_x_data.view(dtype=torch.float8_e4m3fn) fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0) + fc1_x_scales = torch.cat([x._rowwise_scale_inv for x in fc1_xs]) fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e8m0fnu) fc1_x_scales = fc1_x_scales.reshape( 1, @@ -156,8 +152,12 @@ def fuser_forward( 4, ) fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0) + + # Pack weight tensors + fc1_w_data = torch.stack([w._rowwise_data for w in fc1_weights]) fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) fc1_w_data = fc1_w_data.permute(1, 2, 
0) + fc1_w_scales = torch.stack([w._rowwise_scale_inv for w in fc1_weights]) fc1_w_scales = fc1_w_scales.view(dtype=torch.float8_e8m0fnu) fc1_w_scales = fc1_w_scales.reshape( group_size, @@ -333,6 +333,13 @@ def fuse_forward_ops( matches_pattern = False elif window[0].group_size != window[3].group_size: matches_pattern = False + elif ( + window[0].in_features % 256 != 0 + or window[0].out_features % 256 != 0 + or window[3].in_features % 256 != 0 + or window[3].out_features % 256 != 0 + ): + matches_pattern = False if matches_pattern: # Construct fused op if window matches pattern From 3da2c176bb82636b82bb8ac754c57f63b19d8bf6 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 21 Jan 2026 16:58:07 -0800 Subject: [PATCH 22/45] Get test to not pass trivially Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index a621103e2a..8fe756d5fc 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -3165,6 +3165,7 @@ def test_grouped_mlp( out_shape = in_shape # Skip invalid configurations + with_quantization = quantization is not None maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) if quantization != "mxfp8": pytest.skip("Quantization scheme is not supported") @@ -3215,12 +3216,12 @@ def test_grouped_mlp( t *= 1 / 2 for t in (x_ref, x_test, dy_ref, dy_test): t -= 0.5 - t *= 1 / 2 # Reference implementation xs = torch.split(x_ref, split_sizes.tolist()) + probs = torch.split(probs_ref, split_sizes.tolist()) ys = [] - for x, fc1_w, fc2_w, prob in zip(xs, fc1_ws_ref, fc2_ws_ref, probs_ref): + for x, fc1_w, fc2_w, prob in zip(xs, fc1_ws_ref, fc2_ws_ref, probs): x = torch.nn.functional.linear(x, fc1_w) x1, x2 = x.chunk(2, dim=-1) x = torch.nn.functional.silu(x1) * x2 @@ -3232,7 +3233,7 @@ def test_grouped_mlp( # Construct operations recipe = make_recipe(quantization) - with te.quantized_model_init(recipe=recipe): + with te.quantized_model_init(enabled=with_quantization, recipe=recipe): fc1 = te_ops.GroupedLinear( group_size, hidden_size, @@ -3264,7 +3265,7 @@ def test_grouped_mlp( del fc1_ws_test, fc2_ws_test # Fuse ops and perform forward and backward pass - with te.autocast(recipe=recipe): + with te.autocast(enabled=with_quantization, recipe=recipe): y_test = module(x_test, split_sizes, probs_test, split_sizes) y_test.backward(dy_test) @@ -3284,6 +3285,7 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: tols = {"rtol": 0.25, "atol": 0.5} # Loose tols for sanity checking torch.testing.assert_close(to_cpu(y_test), y_ref, **tols) torch.testing.assert_close(to_cpu(x_test.grad), x_ref.grad, **tols) + torch.testing.assert_close(to_cpu(probs_test.grad), probs_ref.grad, **tols) for group_idx in range(group_size): fc1_dw_test = to_cpu(getattr(fc1, f"weight{group_idx}").grad) fc1_dw_ref = fc1_ws_ref[group_idx].grad From 0270eb1b6bf445f3824c1a10445cd6816f7bb5d0 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 21 Jan 2026 18:03:22 -0800 Subject: [PATCH 23/45] Handle interleaving for SwiGLU Signed-off-by: Tim Moon --- .../pytorch/ops/fused/forward_grouped_mlp.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index fd9eb66c1f..d3d37626b3 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ 
b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -120,7 +120,6 @@ def fuser_forward( quantizer = weight_quantizers[group_idx] quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) w = quantizer(w) - tex.swizzle_scales_for_gemm_(w) fc1_ws.append(w) for w, quantizer in zip(fc2_weights, fc2_weight_quantizers): if not is_quantized_tensor(w): @@ -156,17 +155,16 @@ def fuser_forward( # Pack weight tensors fc1_w_data = torch.stack([w._rowwise_data for w in fc1_weights]) fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) + fc1_w_data = fc1_w_data.reshape(group_size, 2, fc1_weight_shape[0] // 64, 32, fc1_weight_shape[1]) + fc1_w_data = fc1_w_data.transpose(1, 2).contiguous() # Interleave for SwiGLU + fc1_w_data = fc1_w_data.reshape(group_size, fc1_weight_shape[0], fc1_weight_shape[1]) fc1_w_data = fc1_w_data.permute(1, 2, 0) fc1_w_scales = torch.stack([w._rowwise_scale_inv for w in fc1_weights]) fc1_w_scales = fc1_w_scales.view(dtype=torch.float8_e8m0fnu) - fc1_w_scales = fc1_w_scales.reshape( - group_size, - fc1_weight_shape[0] // 128, - fc1_weight_shape[1] // 128, - 32, - 4, - 4, - ) + fc1_w_scales = fc1_w_scales.reshape(group_size, 2, fc1_weight_shape[0] // 64, 32, fc1_weight_shape[1] // 32) + fc1_w_scales = fc1_w_scales.transpose(1, 2).contiguous() # Interleave for SwiGLU + fc1_w_scales = fc1_w_scales.reshape(group_size, fc1_weight_shape[0] // 128, 4, 32, fc1_weight_shape[1] // 128, 4) + fc1_w_scales = fc1_w_scales.permute(0, 1, 4, 3, 2, 5).contiguous() # Convert to swizzled layout fc1_w_scales = fc1_w_scales.permute(3, 4, 1, 5, 2, 0) # Kernel tile logic @@ -201,19 +199,21 @@ def fuser_forward( # Extract kernel outputs and construct MXFP8 tensors swiglu_in = fc1_kernel_out["c_tensor"] - swiglu_in = swiglu_in.reshape(in_shape[0], fc1_weight_shape[0]).contiguous() + swiglu_in = swiglu_in.reshape(in_shape[0], fc1_weight_shape[0] // 64, 2, 32) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() # Remove SwiGLU interleaving + swiglu_in = swiglu_in.reshape(in_shape[0], fc1_weight_shape[0]) fc2_in_row_data = fc1_kernel_out["d_tensor"] fc2_in_row_data = fc2_in_row_data.reshape(in_shape[0], fc2_weight_shape[1]).contiguous() fc2_in_row_data = torch.split(fc2_in_row_data, split_sizes_int) fc2_in_row_scale = fc1_kernel_out["sfd_row_tensor"] - fc2_in_row_scale = fc2_in_row_scale.permute(5, 2, 0, 1, 4, 3).contiguous() ### TODO Preserve swizzling + fc2_in_row_scale = fc2_in_row_scale.permute(5, 2, 0, 1, 4, 3).contiguous() # Convert to compact layout fc2_in_row_scale = fc2_in_row_scale.reshape(in_shape[0], fc2_weight_shape[1] // 32) fc2_in_row_scale = torch.split(fc2_in_row_scale, split_sizes_int) fc2_in_col_data = fc1_kernel_out["d_col_tensor"] fc2_in_col_data = fc2_in_col_data.reshape(in_shape[0], fc2_weight_shape[1]).contiguous() fc2_in_col_data = torch.split(fc2_in_col_data, split_sizes_int) fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"] - fc2_in_col_scale = fc2_in_col_scale.permute(5, 4, 3, 2, 0, 1).contiguous() ### TODO Preserve swizzling + fc2_in_col_scale = fc2_in_col_scale.permute(5, 4, 3, 2, 0, 1).contiguous() # Convert to compact layout fc2_in_col_scale = fc2_in_col_scale.reshape(in_shape[0] // 32, fc2_weight_shape[1]) fc2_in_col_scale = torch.split(fc2_in_col_scale, [s // 32 for s in split_sizes_int]) From 0b097900dc83c2a3c0c2da92f503e83a1e410136 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 21 Jan 2026 19:41:42 -0800 Subject: [PATCH 24/45] Fix numeric tests, except for probs grad Signed-off-by: Tim Moon --- .../pytorch/ops/fused/forward_grouped_mlp.py | 
16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index d3d37626b3..c35cbe8248 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -156,13 +156,15 @@ def fuser_forward( fc1_w_data = torch.stack([w._rowwise_data for w in fc1_weights]) fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) fc1_w_data = fc1_w_data.reshape(group_size, 2, fc1_weight_shape[0] // 64, 32, fc1_weight_shape[1]) - fc1_w_data = fc1_w_data.transpose(1, 2).contiguous() # Interleave for SwiGLU + fc1_w_data = fc1_w_data.transpose(1, 2) # Interleave SwiGLU gate/activation + fc1_w_data = fc1_w_data.flip(2).contiguous() # Swap SwiGLU gate/activation fc1_w_data = fc1_w_data.reshape(group_size, fc1_weight_shape[0], fc1_weight_shape[1]) fc1_w_data = fc1_w_data.permute(1, 2, 0) fc1_w_scales = torch.stack([w._rowwise_scale_inv for w in fc1_weights]) fc1_w_scales = fc1_w_scales.view(dtype=torch.float8_e8m0fnu) fc1_w_scales = fc1_w_scales.reshape(group_size, 2, fc1_weight_shape[0] // 64, 32, fc1_weight_shape[1] // 32) - fc1_w_scales = fc1_w_scales.transpose(1, 2).contiguous() # Interleave for SwiGLU + fc1_w_scales = fc1_w_scales.transpose(1, 2) # Interleave SwiGLU gate/activation + fc1_w_scales = fc1_w_scales.flip(2).contiguous() # Swap SwiGLU gate/activation fc1_w_scales = fc1_w_scales.reshape(group_size, fc1_weight_shape[0] // 128, 4, 32, fc1_weight_shape[1] // 128, 4) fc1_w_scales = fc1_w_scales.permute(0, 1, 4, 3, 2, 5).contiguous() # Convert to swizzled layout fc1_w_scales = fc1_w_scales.permute(3, 4, 1, 5, 2, 0) @@ -199,21 +201,25 @@ def fuser_forward( # Extract kernel outputs and construct MXFP8 tensors swiglu_in = fc1_kernel_out["c_tensor"] + swiglu_in = swiglu_in.permute(2, 0, 1).contiguous() swiglu_in = swiglu_in.reshape(in_shape[0], fc1_weight_shape[0] // 64, 2, 32) - swiglu_in = swiglu_in.transpose(1, 2).contiguous() # Remove SwiGLU interleaving + swiglu_in = swiglu_in.transpose(1, 2) # Undo interleaved SwiGLU gate/activation + swiglu_in = swiglu_in.flip(1).contiguous() # Undo swapped SwiGLU gate/activation swiglu_in = swiglu_in.reshape(in_shape[0], fc1_weight_shape[0]) fc2_in_row_data = fc1_kernel_out["d_tensor"] + fc2_in_row_data = fc2_in_row_data.permute(2, 0, 1).contiguous() fc2_in_row_data = fc2_in_row_data.reshape(in_shape[0], fc2_weight_shape[1]).contiguous() fc2_in_row_data = torch.split(fc2_in_row_data, split_sizes_int) fc2_in_row_scale = fc1_kernel_out["sfd_row_tensor"] - fc2_in_row_scale = fc2_in_row_scale.permute(5, 2, 0, 1, 4, 3).contiguous() # Convert to compact layout + fc2_in_row_scale = fc2_in_row_scale.permute(5, 2, 1, 0, 4, 3).contiguous() # Convert to compact layout fc2_in_row_scale = fc2_in_row_scale.reshape(in_shape[0], fc2_weight_shape[1] // 32) fc2_in_row_scale = torch.split(fc2_in_row_scale, split_sizes_int) fc2_in_col_data = fc1_kernel_out["d_col_tensor"] + fc2_in_col_data = fc2_in_col_data.permute(2, 0, 1).contiguous() fc2_in_col_data = fc2_in_col_data.reshape(in_shape[0], fc2_weight_shape[1]).contiguous() fc2_in_col_data = torch.split(fc2_in_col_data, split_sizes_int) fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"] - fc2_in_col_scale = fc2_in_col_scale.permute(5, 4, 3, 2, 0, 1).contiguous() # Convert to compact layout + fc2_in_col_scale = fc2_in_col_scale.permute(5, 4, 3, 2, 1, 0).contiguous() # Convert to compact layout fc2_in_col_scale = 
fc2_in_col_scale.reshape(in_shape[0] // 32, fc2_weight_shape[1]) fc2_in_col_scale = torch.split(fc2_in_col_scale, [s // 32 for s in split_sizes_int]) From 7c4029082a460944417ceb42ea5094a1b1a09633 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 22 Jan 2026 05:38:50 +0000 Subject: [PATCH 25/45] Use pre-swizzled scales from GEMM+SwiGLU output Signed-off-by: Tim Moon --- .../pytorch/ops/fused/forward_grouped_mlp.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index c35cbe8248..8989c5fc95 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -142,7 +142,7 @@ def fuser_forward( fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0) fc1_x_scales = torch.cat([x._rowwise_scale_inv for x in fc1_xs]) fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e8m0fnu) - fc1_x_scales = fc1_x_scales.reshape( + fc1_x_scales = fc1_x_scales.view( 1, in_shape[0] // 128, in_shape[1] // 128, @@ -155,17 +155,17 @@ def fuser_forward( # Pack weight tensors fc1_w_data = torch.stack([w._rowwise_data for w in fc1_weights]) fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) - fc1_w_data = fc1_w_data.reshape(group_size, 2, fc1_weight_shape[0] // 64, 32, fc1_weight_shape[1]) + fc1_w_data = fc1_w_data.view(group_size, 2, fc1_weight_shape[0] // 64, 32, fc1_weight_shape[1]) fc1_w_data = fc1_w_data.transpose(1, 2) # Interleave SwiGLU gate/activation fc1_w_data = fc1_w_data.flip(2).contiguous() # Swap SwiGLU gate/activation - fc1_w_data = fc1_w_data.reshape(group_size, fc1_weight_shape[0], fc1_weight_shape[1]) + fc1_w_data = fc1_w_data.view(group_size, fc1_weight_shape[0], fc1_weight_shape[1]) fc1_w_data = fc1_w_data.permute(1, 2, 0) fc1_w_scales = torch.stack([w._rowwise_scale_inv for w in fc1_weights]) fc1_w_scales = fc1_w_scales.view(dtype=torch.float8_e8m0fnu) - fc1_w_scales = fc1_w_scales.reshape(group_size, 2, fc1_weight_shape[0] // 64, 32, fc1_weight_shape[1] // 32) + fc1_w_scales = fc1_w_scales.view(group_size, 2, fc1_weight_shape[0] // 64, 32, fc1_weight_shape[1] // 32) fc1_w_scales = fc1_w_scales.transpose(1, 2) # Interleave SwiGLU gate/activation fc1_w_scales = fc1_w_scales.flip(2).contiguous() # Swap SwiGLU gate/activation - fc1_w_scales = fc1_w_scales.reshape(group_size, fc1_weight_shape[0] // 128, 4, 32, fc1_weight_shape[1] // 128, 4) + fc1_w_scales = fc1_w_scales.view(group_size, fc1_weight_shape[0] // 128, 4, 32, fc1_weight_shape[1] // 128, 4) fc1_w_scales = fc1_w_scales.permute(0, 1, 4, 3, 2, 5).contiguous() # Convert to swizzled layout fc1_w_scales = fc1_w_scales.permute(3, 4, 1, 5, 2, 0) @@ -201,27 +201,27 @@ def fuser_forward( # Extract kernel outputs and construct MXFP8 tensors swiglu_in = fc1_kernel_out["c_tensor"] - swiglu_in = swiglu_in.permute(2, 0, 1).contiguous() - swiglu_in = swiglu_in.reshape(in_shape[0], fc1_weight_shape[0] // 64, 2, 32) + swiglu_in = swiglu_in.permute(2, 0, 1) + swiglu_in = swiglu_in.view(in_shape[0], fc1_weight_shape[0] // 64, 2, 32) swiglu_in = swiglu_in.transpose(1, 2) # Undo interleaved SwiGLU gate/activation - swiglu_in = swiglu_in.flip(1).contiguous() # Undo swapped SwiGLU gate/activation - swiglu_in = swiglu_in.reshape(in_shape[0], fc1_weight_shape[0]) + swiglu_in = swiglu_in.flip(1) # Undo swapped SwiGLU gate/activation + swiglu_in = swiglu_in.contiguous().view(in_shape[0], fc1_weight_shape[0]) fc2_in_row_data = 
fc1_kernel_out["d_tensor"] - fc2_in_row_data = fc2_in_row_data.permute(2, 0, 1).contiguous() - fc2_in_row_data = fc2_in_row_data.reshape(in_shape[0], fc2_weight_shape[1]).contiguous() - fc2_in_row_data = torch.split(fc2_in_row_data, split_sizes_int) + fc2_in_row_data = fc2_in_row_data.permute(2, 0, 1) + fc2_in_row_data = fc2_in_row_data.view(in_shape[0], fc2_weight_shape[1]) + fc2_in_row_data = torch.split(fc2_in_row_data.contiguous(), split_sizes_int) fc2_in_row_scale = fc1_kernel_out["sfd_row_tensor"] - fc2_in_row_scale = fc2_in_row_scale.permute(5, 2, 1, 0, 4, 3).contiguous() # Convert to compact layout - fc2_in_row_scale = fc2_in_row_scale.reshape(in_shape[0], fc2_weight_shape[1] // 32) - fc2_in_row_scale = torch.split(fc2_in_row_scale, split_sizes_int) + fc2_in_row_scale = fc2_in_row_scale.permute(5, 2, 4, 0, 1, 3) + fc2_in_row_scale = fc2_in_row_scale.view(in_shape[0], fc2_weight_shape[1] // 32) + fc2_in_row_scale = torch.split(fc2_in_row_scale.contiguous(), split_sizes_int) fc2_in_col_data = fc1_kernel_out["d_col_tensor"] - fc2_in_col_data = fc2_in_col_data.permute(2, 0, 1).contiguous() - fc2_in_col_data = fc2_in_col_data.reshape(in_shape[0], fc2_weight_shape[1]).contiguous() - fc2_in_col_data = torch.split(fc2_in_col_data, split_sizes_int) + fc2_in_col_data = fc2_in_col_data.permute(2, 0, 1) + fc2_in_col_data = fc2_in_col_data.view(in_shape[0], fc2_weight_shape[1]) + fc2_in_col_data = torch.split(fc2_in_col_data.contiguous(), split_sizes_int) fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"] - fc2_in_col_scale = fc2_in_col_scale.permute(5, 4, 3, 2, 1, 0).contiguous() # Convert to compact layout - fc2_in_col_scale = fc2_in_col_scale.reshape(in_shape[0] // 32, fc2_weight_shape[1]) - fc2_in_col_scale = torch.split(fc2_in_col_scale, [s // 32 for s in split_sizes_int]) + fc2_in_col_scale = fc2_in_col_scale.permute(5, 2, 4, 0, 1, 3) + fc2_in_col_scale = torch.split(fc2_in_col_scale, [s // 128 for s in split_sizes_int], dim=2) + fc2_in_col_scale = [s.contiguous().view(-1, fc2_weight_shape[1]) for s in fc2_in_col_scale] # Construct MXFP8 tensors for FC2 fc2_xs = [] @@ -236,7 +236,7 @@ def fuser_forward( columnwise_scale_inv=fc2_in_col_scale[group_idx], quantizer=fc2_input_quantizers[group_idx], requires_grad=False, - with_gemm_swizzled_scales=False, + with_gemm_swizzled_scales=True, ) fc2_xs.append(x) From a098cc07a3d509f0283aa2effa825efb64913f7e Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 23 Jan 2026 01:24:21 +0000 Subject: [PATCH 26/45] Add scaled SwiGLU op Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 88 +++++++++- .../pytorch/ops/basic/__init__.py | 1 + .../pytorch/ops/basic/scaled_swiglu.py | 151 ++++++++++++++++++ .../pytorch/ops/fused/forward_grouped_mlp.py | 80 ++++------ 4 files changed, 268 insertions(+), 52 deletions(-) create mode 100644 transformer_engine/pytorch/ops/basic/scaled_swiglu.py diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 8fe756d5fc..285a27b662 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2136,6 +2136,84 @@ def test_multiply_extra_input( else: assert x2_test.grad is None + @pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128))) + @pytest.mark.parametrize("gate_interleave_size", (None, 32)) + @pytest.mark.parametrize("input_requires_grad", (False, True)) + @pytest.mark.parametrize("scales_requires_grad", (False, True)) + def test_scaled_swiglu( + self, + *, + in_shape: Iterable[int], + gate_interleave_size: Optional[int], + dtype: torch.dtype = 
torch.float32, + device: torch.device = "cuda", + input_requires_grad: bool, + scales_requires_grad: bool, + ) -> None: + """Multiply two tensors""" + + # Tensor dims + out_shape = list(in_shape) + out_shape[-1] //= 2 + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=input_requires_grad, + ) + scales_ref, scales_test = make_reference_and_test_tensors( + in_shape[:-1], + test_dtype=dtype, + test_device=device, + requires_grad=scales_requires_grad, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + x = x_ref + if gate_interleave_size is not None: + x = x.reshape( + -1, + in_shape[-1] // (2 * gate_interleave_size), + 2, + gate_interleave_size, + ) + x = x.transpose(1, 2) + x = x.reshape(in_shape) + x1, x2 = x.chunk(2, dim=-1) + y = torch.nn.functional.silu(x1) * x2 + y_ref = scales_ref.unsqueeze(-1) * y + if input_requires_grad or scales_requires_grad: + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.ScaledSwiGLU(gate_interleave_size=gate_interleave_size) + y_test = op(x_test, scales_test) + if input_requires_grad or scales_requires_grad: + y_test.backward(dy_test) + + # Check results + tols = dtype_tols(dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + if input_requires_grad: + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + else: + assert x_test.grad is None + if scales_requires_grad: + ds_test = scales_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(ds_test, scales_ref.grad, **tols) + else: + assert scales_test.grad is None + class TestFusedOps: """Tests for fused operations""" @@ -3187,7 +3265,7 @@ def test_grouped_mlp( requires_grad=False, ) probs_ref, probs_test = make_reference_and_test_tensors( - (in_shape[0], 1), + (in_shape[0],), test_dtype=dtype, test_device=device, ) @@ -3223,9 +3301,12 @@ def test_grouped_mlp( ys = [] for x, fc1_w, fc2_w, prob in zip(xs, fc1_ws_ref, fc2_ws_ref, probs): x = torch.nn.functional.linear(x, fc1_w) + x = x.reshape(-1, 2 * hidden_size // 64, 2, 32) + x = x.transpose(1, 2) + x = x.reshape(-1, 2 * hidden_size) x1, x2 = x.chunk(2, dim=-1) x = torch.nn.functional.silu(x1) * x2 - x = x * prob + x = x * prob.unsqueeze(-1) x = torch.nn.functional.linear(x, fc2_w) ys.append(x) y_ref = torch.cat(ys) @@ -3252,8 +3333,7 @@ def test_grouped_mlp( ) module = te_ops.Sequential( fc1, - te_ops.SwiGLU(), - te_ops.MultiplyExtraInput(), + te_ops.ScaledSwiGLU(gate_interleave_size=32), fc2, ) diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index c119682151..c02340817a 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -34,3 +34,4 @@ from .reduce_scatter import ReduceScatter from .reshape import Reshape from .rmsnorm import RMSNorm +from .scaled_swiglu import ScaledSwiGLU diff --git a/transformer_engine/pytorch/ops/basic/scaled_swiglu.py b/transformer_engine/pytorch/ops/basic/scaled_swiglu.py new file mode 100644 index 0000000000..9749b29edd --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/scaled_swiglu.py @@ -0,0 +1,151 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# See LICENSE for license information. + +"""Fusible operation for multiplying with extra input tensor.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +import transformer_engine_torch as tex +from ...tensor import Quantizer +from ..op import BasicOperation, OperationContext +from .._common import maybe_dequantize + + +class ScaledSwiGLU(BasicOperation): + """SwiGLU with post-scaling + """ + + # Operation expects scales + num_extra_inputs: int = 1 + + def __init__(self, gate_interleave_size: Optional[int] = None): + super().__init__() + self.gate_interleave_size: Optional[int] = gate_interleave_size + + def op_forward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_backward` instead of `op_backward`." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + extra_input = basic_op_extra_inputs[0][0] + + # Determine compute dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + elif isinstance(input_, torch.Tensor): + dtype = input_.dtype + else: + dtype = extra_input.dtype + + # Make sure inputs are in correct dtype + input_ = maybe_dequantize(input_, dtype) + scales = maybe_dequantize(extra_input, dtype) + + # Remove gate interleaving if needed + swiglu_in = input_ + if self.gate_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.gate_interleave_size), + 2, + self.gate_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Compute scaled SwiGLU + swiglu_out = tex.swiglu(swiglu_in, None) + out = swiglu_out * scales.unsqueeze(-1) + + # Save state for backward pass + ctx = basic_op_ctxs[0] + if ctx.requires_grad: + ctx.input_requires_grad = True + ctx.extra_input_requires_grad = extra_input.requires_grad + ctx.dtype = dtype + ctx.save_for_backward( + input_, + scales if ctx.input_requires_grad else None, + ) + + return out, [()] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + ctx = basic_op_ctxs[0] + input_, scales = ctx.saved_tensors + grad_output = maybe_dequantize(grad_output, ctx.dtype) + + # Remove gate interleaving if needed + swiglu_in = input_ + if self.gate_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.gate_interleave_size), + 2, + self.gate_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # 
Compute input grad + grad_input = None + if ctx.input_requires_grad: + grad_swiglu_out = grad_output * scales.unsqueeze(-1) + grad_swiglu_in = tex.dswiglu(grad_swiglu_out, swiglu_in, None) + grad_input = grad_swiglu_in + if self.gate_interleave_size is not None: + shape = grad_input.size() + grad_input = grad_input.reshape( + -1, + 2, + shape[-1] // (2 * self.gate_interleave_size), + self.gate_interleave_size, + ) + grad_input = grad_input.transpose(1, 2).contiguous() + grad_input = grad_input.view(shape) + + # Compute scales grad by recomputing SwiGLU + grad_extra_input = None + if ctx.extra_input_requires_grad: + swiglu_out = tex.swiglu(swiglu_in, None) + grad_extra_input = torch.linalg.vecdot(swiglu_out, grad_output) + + return grad_input, [()], [(grad_extra_input,)] diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 8989c5fc95..ff00c92f02 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -16,7 +16,7 @@ from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...quantization import FP8GlobalStateManager from ...tensor import MXFP8Tensor, Quantizer -from ..basic import GroupedLinear, MultiplyExtraInput, SwiGLU +from ..basic import GroupedLinear, ScaledSwiGLU from ..fuser import register_forward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext from .._common import is_quantized_tensor, maybe_dequantize @@ -28,11 +28,10 @@ def __init__( self, *, fc1: GroupedLinear, - swiglu: SwiGLU, - scale: MultiplyExtraInput, + swiglu: ScaledSwiGLU, fc2: GroupedLinear, ) -> None: - super().__init__((fc1, swiglu, scale, fc2)) + super().__init__((fc1, swiglu, fc2)) def fuser_forward( self, @@ -46,8 +45,8 @@ def fuser_forward( ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: # Get basic operations - fc1_op, swiglu_op, scale_op, fc2_op = self.basic_ops - fc1_ctx, swiglu_ctx, scale_ctx, fc2_ctx = basic_op_ctxs + fc1_op, swiglu_op, fc2_op = self.basic_ops + fc1_ctx, swiglu_ctx, fc2_ctx = basic_op_ctxs # Tensor properties in_shape = list(input_.size()) @@ -86,7 +85,7 @@ def fuser_forward( # Extract split sizes from extra input fc1_split_sizes = basic_op_extra_inputs[0][0] - fc2_split_sizes = basic_op_extra_inputs[3][0] + fc2_split_sizes = basic_op_extra_inputs[2][0] if ( fc1_split_sizes.size() != fc2_split_sizes.size() or fc1_split_sizes.data_ptr() != fc2_split_sizes.data_ptr() @@ -100,13 +99,7 @@ def fuser_forward( raise ValueError(f"Expected {group_size} splits, but got {len(split_sizes_int)}.") # Extract post-scales from extra input - scales = basic_op_extra_inputs[2][0] - scales_shape = tuple(scales.size()) - if scales.numel() != scales_shape[0]: - raise RuntimeError( - f"{self.__class__.__name__} assumes scales are over leading dim, " - f"but got shape={scales_shape}." 
- ) + scales = basic_op_extra_inputs[1][0] # Extract params fc1_weights = [getattr(fc1_op, f"weight{idx}") for idx in range(group_size)] @@ -155,15 +148,13 @@ def fuser_forward( # Pack weight tensors fc1_w_data = torch.stack([w._rowwise_data for w in fc1_weights]) fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) - fc1_w_data = fc1_w_data.view(group_size, 2, fc1_weight_shape[0] // 64, 32, fc1_weight_shape[1]) - fc1_w_data = fc1_w_data.transpose(1, 2) # Interleave SwiGLU gate/activation + fc1_w_data = fc1_w_data.view(group_size, fc1_weight_shape[0] // 64, 2, 32, fc1_weight_shape[1]) fc1_w_data = fc1_w_data.flip(2).contiguous() # Swap SwiGLU gate/activation fc1_w_data = fc1_w_data.view(group_size, fc1_weight_shape[0], fc1_weight_shape[1]) fc1_w_data = fc1_w_data.permute(1, 2, 0) fc1_w_scales = torch.stack([w._rowwise_scale_inv for w in fc1_weights]) fc1_w_scales = fc1_w_scales.view(dtype=torch.float8_e8m0fnu) - fc1_w_scales = fc1_w_scales.view(group_size, 2, fc1_weight_shape[0] // 64, 32, fc1_weight_shape[1] // 32) - fc1_w_scales = fc1_w_scales.transpose(1, 2) # Interleave SwiGLU gate/activation + fc1_w_scales = fc1_w_scales.view(group_size, fc1_weight_shape[0] // 64, 2, 32, fc1_weight_shape[1] // 32) fc1_w_scales = fc1_w_scales.flip(2).contiguous() # Swap SwiGLU gate/activation fc1_w_scales = fc1_w_scales.view(group_size, fc1_weight_shape[0] // 128, 4, 32, fc1_weight_shape[1] // 128, 4) fc1_w_scales = fc1_w_scales.permute(0, 1, 4, 3, 2, 5).contiguous() # Convert to swizzled layout @@ -203,8 +194,7 @@ def fuser_forward( swiglu_in = fc1_kernel_out["c_tensor"] swiglu_in = swiglu_in.permute(2, 0, 1) swiglu_in = swiglu_in.view(in_shape[0], fc1_weight_shape[0] // 64, 2, 32) - swiglu_in = swiglu_in.transpose(1, 2) # Undo interleaved SwiGLU gate/activation - swiglu_in = swiglu_in.flip(1) # Undo swapped SwiGLU gate/activation + swiglu_in = swiglu_in.flip(2) # Undo swapped SwiGLU gate/activation swiglu_in = swiglu_in.contiguous().view(in_shape[0], fc1_weight_shape[0]) fc2_in_row_data = fc1_kernel_out["d_tensor"] fc2_in_row_data = fc2_in_row_data.permute(2, 0, 1) @@ -257,7 +247,7 @@ def fuser_forward( # Save state for backward pass if requires_grad: - # FC1 state + # FC1 fc1_ctx.save_for_backward(split_sizes, *fc1_xs, *fc1_ws) fc1_ctx.with_quantized_compute = True fc1_ctx.input_quantizers = fc1_input_quantizers @@ -268,17 +258,11 @@ def fuser_forward( fc1_ctx.input_requires_grad = input_requires_grad fc1_ctx.weight_requires_grad = weight_requires_grad - # SwiGLU - swiglu_ctx.save_for_backward(swiglu_in) + # Scaled SwiGLU + swiglu_ctx.save_for_backward(swiglu_in, scales) + swiglu_ctx.input_requires_grad = True + swiglu_ctx.extra_input_requires_grad = True swiglu_ctx.dtype = dtype - swiglu_ctx.prev_op_grad_output_quantizer = None - - # Scale - scale_ctx.save_for_backward(fc2_out, scales) - scale_ctx.input_shape = fc2_out.size() - scale_ctx.extra_input_shape = scales_shape - scale_ctx.input_requires_grad = True - scale_ctx.extra_input_requires_grad = scales.requires_grad # FC2 state fc2_ctx.save_for_backward(split_sizes, *fc2_xs, *fc2_ws) @@ -291,7 +275,7 @@ def fuser_forward( fc2_ctx.input_requires_grad = input_requires_grad fc2_ctx.weight_requires_grad = weight_requires_grad - return fc2_out, [(), (), (), ()] + return fc2_out, [(), (), ()] def fuse_forward_ops( ops: list[FusibleOperation], @@ -323,48 +307,48 @@ def fuse_forward_ops( # Scan through ops, fusing if possible out = [] - window, ops = ops[:4], ops[4:] - while len(window) == 4: + window, ops = ops[:3], ops[3:] + while len(window) == 3: 
# Check if window matches pattern matches_pattern = True if not ( isinstance(window[0], GroupedLinear) - and isinstance(window[1], SwiGLU) - and isinstance(window[2], MultiplyExtraInput) - and isinstance(window[3], GroupedLinear) + and isinstance(window[1], ScaledSwiGLU) + and isinstance(window[2], GroupedLinear) ): matches_pattern = False - elif window[0].has_bias or window[3].has_bias: + elif window[0].has_bias or window[2].has_bias: matches_pattern = False - elif window[0].group_size != window[3].group_size: + elif window[0].group_size != window[2].group_size: matches_pattern = False elif ( window[0].in_features % 256 != 0 or window[0].out_features % 256 != 0 - or window[3].in_features % 256 != 0 - or window[3].out_features % 256 != 0 + or window[2].in_features % 256 != 0 + or window[2].out_features % 256 != 0 ): matches_pattern = False + elif window[1].gate_interleave_size != 32: + matches_pattern = False if matches_pattern: # Construct fused op if window matches pattern op = ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8( fc1=window[0], swiglu=window[1], - scale=window[2], - fc2=window[3], + fc2=window[2], ) window = [op] else: # Shift window if window doesn't match pattern - out.extend(window[:-3]) - window = window[-3:] + out.extend(window[:-2]) + window = window[-2:] # Adjust window to expected size - out.extend(window[:-4]) - window = window[-4:] - while ops and len(window) < 4: + out.extend(window[:-3]) + window = window[-3:] + while ops and len(window) < 3: window.append(ops[0]) ops = ops[1:] From e4f51d3e01c68249eaf5e3a126e873b7262bec97 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 23 Jan 2026 03:15:02 +0000 Subject: [PATCH 27/45] Avoid CPU splits in group GEMM+SwiGLU kernel Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 4 +- .../pytorch/ops/fused/forward_grouped_mlp.py | 60 ++++++++++++------- 2 files changed, 42 insertions(+), 22 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 285a27b662..57bf63bdaa 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1952,7 +1952,7 @@ def test_grouped_linear( # Split sizes split_sizes = [split_alignment * i for i in range(group_size)] random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int, device="cpu") + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) # Make input and weight shapes consistent out_features, in_features = weight_shape @@ -3236,7 +3236,7 @@ def test_grouped_mlp( # Split sizes split_sizes = [split_alignment * i for i in range(group_size)] random.shuffle(split_sizes) - split_sizes = torch.tensor(split_sizes, dtype=torch.int, device="cpu") + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) # Make input shape in_shape = (split_sizes.sum().item(), hidden_size) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index ff00c92f02..8eb8efd961 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -94,9 +94,16 @@ def fuser_forward( f"{self.__class__.__name__} got different split points for FC1 and FC2." 
) split_sizes = fc1_split_sizes - split_sizes_int = [int(s) for s in split_sizes.tolist()] - if len(split_sizes_int) != group_size: - raise ValueError(f"Expected {group_size} splits, but got {len(split_sizes_int)}.") + split_sizes_cpu = [int(s) for s in split_sizes.tolist()] + if len(split_sizes_cpu) != group_size: + raise ValueError(f"Expected {group_size} splits, but got {len(split_sizes_cpu)}.") + split_sizes = split_sizes.to(dtype=torch.int, device=device) + split_points = torch.zeros( + split_sizes.numel() + 1, + dtype=torch.int, + device=device, + ) + torch.cumsum(split_sizes, 0, out=split_points[1:]) # Extract post-scales from extra input scales = basic_op_extra_inputs[1][0] @@ -127,7 +134,7 @@ def fuser_forward( for quantizer in fc1_input_quantizers: quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) quantizer.optimize_for_gemm = True - fc1_xs = tex.split_quantize(fc1_x, split_sizes_int, fc1_input_quantizers) + fc1_xs = tex.split_quantize(fc1_x, split_sizes_cpu, fc1_input_quantizers) # Pack data tensors fc1_x_data = torch.cat([x._rowwise_data for x in fc1_xs]) @@ -161,13 +168,26 @@ def fuser_forward( fc1_w_scales = fc1_w_scales.permute(3, 4, 1, 5, 2, 0) # Kernel tile logic - tile_idx_to_expert_idx = [] - cta_tile_m = 256 ### TODO ? - for group_idx in range(group_size): - num_tiles = split_sizes_int[group_idx] // cta_tile_m - tile_idx_to_expert_idx.extend([group_idx] * num_tiles) - num_non_exiting_tiles = torch.tensor([len(tile_idx_to_expert_idx)], device=device, dtype=torch.int32) - tile_idx_to_expert_idx = torch.tensor(tile_idx_to_expert_idx, device=device, dtype=torch.int32) + mma_tiler_mn = (256, 256) + tile_points = torch.arange( + 0, + in_shape[0], + mma_tiler_mn[0], + dtype=torch.int, + device=device, + ) + tile_idx_to_expert_idx = torch.searchsorted( + split_points[1:], + tile_points, + out_int32=True, + side="right", + ) + num_non_exiting_tiles = torch.full( + (1,), + in_shape[0] // mma_tiler_mn[0], + dtype=torch.int, + device=device, + ) # Fused kernel for FC1 + SwiGLU + post-scale fc1_kernel_out = grouped_gemm_swiglu_wrapper_sm100( @@ -178,19 +198,19 @@ def fuser_forward( tile_idx_to_expert_idx, num_non_exiting_tiles, torch.ones(group_size, dtype=dtype, device=device), # alpha_tensor - split_sizes_int, torch.ones(1, dtype=dtype, device=device), # norm_const_tensor scales.detach().reshape(-1, 1, 1), + split_points, acc_dtype=torch.float32, c_dtype=torch.bfloat16, d_dtype=torch.float8_e4m3fn, cd_major="n", + mma_tiler_mn=mma_tiler_mn, cluster_shape_mn=(2, 1), sf_vec_size=32, - sf_dtype=torch.float8_e8m0fnu, ) - # Extract kernel outputs and construct MXFP8 tensors + # Unpack kernel outputs swiglu_in = fc1_kernel_out["c_tensor"] swiglu_in = swiglu_in.permute(2, 0, 1) swiglu_in = swiglu_in.view(in_shape[0], fc1_weight_shape[0] // 64, 2, 32) @@ -199,25 +219,25 @@ def fuser_forward( fc2_in_row_data = fc1_kernel_out["d_tensor"] fc2_in_row_data = fc2_in_row_data.permute(2, 0, 1) fc2_in_row_data = fc2_in_row_data.view(in_shape[0], fc2_weight_shape[1]) - fc2_in_row_data = torch.split(fc2_in_row_data.contiguous(), split_sizes_int) + fc2_in_row_data = torch.split(fc2_in_row_data.contiguous(), split_sizes_cpu) fc2_in_row_scale = fc1_kernel_out["sfd_row_tensor"] fc2_in_row_scale = fc2_in_row_scale.permute(5, 2, 4, 0, 1, 3) fc2_in_row_scale = fc2_in_row_scale.view(in_shape[0], fc2_weight_shape[1] // 32) - fc2_in_row_scale = torch.split(fc2_in_row_scale.contiguous(), split_sizes_int) + fc2_in_row_scale = torch.split(fc2_in_row_scale.contiguous(), split_sizes_cpu) 
fc2_in_col_data = fc1_kernel_out["d_col_tensor"] fc2_in_col_data = fc2_in_col_data.permute(2, 0, 1) fc2_in_col_data = fc2_in_col_data.view(in_shape[0], fc2_weight_shape[1]) - fc2_in_col_data = torch.split(fc2_in_col_data.contiguous(), split_sizes_int) + fc2_in_col_data = torch.split(fc2_in_col_data.contiguous(), split_sizes_cpu) fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"] fc2_in_col_scale = fc2_in_col_scale.permute(5, 2, 4, 0, 1, 3) - fc2_in_col_scale = torch.split(fc2_in_col_scale, [s // 128 for s in split_sizes_int], dim=2) + fc2_in_col_scale = torch.split(fc2_in_col_scale, [s // 128 for s in split_sizes_cpu], dim=2) fc2_in_col_scale = [s.contiguous().view(-1, fc2_weight_shape[1]) for s in fc2_in_col_scale] # Construct MXFP8 tensors for FC2 fc2_xs = [] for group_idx in range(group_size): x = MXFP8Tensor( - shape=(split_sizes_int[group_idx], fc2_weight_shape[1]), + shape=(split_sizes_cpu[group_idx], fc2_weight_shape[1]), dtype=dtype, fp8_dtype=tex.DType.kFloat8E4M3, rowwise_data=fc2_in_row_data[group_idx], @@ -239,7 +259,7 @@ def fuser_forward( [fc2_out], [None] * group_size, # quantization_params dtype, - m_splits=split_sizes_int, + m_splits=split_sizes_cpu, bias=[None] * group_size, use_bias=False, single_output=True, From fb28b6e0356388b11e0d7cba96988aed9ccc2a00 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 23 Jan 2026 06:00:21 +0000 Subject: [PATCH 28/45] Debug scaled SwiGLU Signed-off-by: Tim Moon --- transformer_engine/pytorch/ops/basic/scaled_swiglu.py | 3 +++ transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/transformer_engine/pytorch/ops/basic/scaled_swiglu.py b/transformer_engine/pytorch/ops/basic/scaled_swiglu.py index 9749b29edd..9b398fa217 100644 --- a/transformer_engine/pytorch/ops/basic/scaled_swiglu.py +++ b/transformer_engine/pytorch/ops/basic/scaled_swiglu.py @@ -110,6 +110,9 @@ def fuser_backward( ]: ctx = basic_op_ctxs[0] input_, scales = ctx.saved_tensors + input_ = maybe_dequantize(input_, ctx.dtype) + if scales is not None: + scales = maybe_dequantize(scales, ctx.dtype) grad_output = maybe_dequantize(grad_output, ctx.dtype) # Remove gate interleaving if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 8eb8efd961..22817eaa0e 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -6,6 +6,7 @@ from __future__ import annotations from collections.abc import Iterable +import itertools from typing import Any, Optional import torch @@ -265,6 +266,10 @@ def fuser_forward( single_output=True, ) + # Prepare input tensors for backward pass + for x in itertools.chain(fc1_xs, fc2_xs): + x.update_usage(rowwise_usage=False, columnwise_usage=True) + # Save state for backward pass if requires_grad: # FC1 From b0bf34d85f15526cf66aa627dc9309fa5cfe5f3c Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 24 Jan 2026 02:06:00 +0000 Subject: [PATCH 29/45] Handle case where fused kernel is not available Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 34 ++++++++----- transformer_engine/pytorch/ops/__init__.py | 3 +- .../pytorch/ops/fused/__init__.py | 7 ++- .../pytorch/ops/fused/forward_grouped_mlp.py | 49 ++++++++++++++++--- 4 files changed, 72 insertions(+), 21 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 57bf63bdaa..09f024326c 100644 --- 
a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -3221,17 +3221,19 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols) torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) def test_grouped_mlp( self, *, - dtype: torch.dtype = torch.bfloat16, - quantization: Optional[str] = "mxfp8", + dtype: torch.dtype, + quantization: Optional[str], device: torch.device = "cuda", group_size: int = 4, hidden_size: int = 256, split_alignment: int = 256, ) -> None: - """GroupedLinear + SwiGLU + GroupedLinear""" + """GroupedLinear + ScaledSwiGLU + GroupedLinear""" # Split sizes split_sizes = [split_alignment * i for i in range(group_size)] @@ -3245,10 +3247,6 @@ def test_grouped_mlp( # Skip invalid configurations with_quantization = quantization is not None maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) - if quantization != "mxfp8": - pytest.skip("Quantization scheme is not supported") - if dtype != torch.bfloat16: - pytest.skip("Non-quantized dtype must be BF16") # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -3294,6 +3292,7 @@ def test_grouped_mlp( t *= 1 / 2 for t in (x_ref, x_test, dy_ref, dy_test): t -= 0.5 + t *= 1 / 2 # Reference implementation xs = torch.split(x_ref, split_sizes.tolist()) @@ -3349,9 +3348,18 @@ def test_grouped_mlp( y_test = module(x_test, split_sizes, probs_test, split_sizes) y_test.backward(dy_test) - # Check that forward operations have been fused - forward_ops = module._module_groups[0]._forward_ops - assert len(forward_ops) == 1 + # Check for expected fusions + if ( + te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported() + and quantization == "mxfp8" + and dtype == torch.bfloat16 + ): + forward_ops = module._module_groups[0]._forward_ops + assert len(forward_ops) == 1 + assert isinstance( + forward_ops[0][0], + te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, + ) def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: """Convert to FP64 CPU tensor""" @@ -3361,8 +3369,12 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: out = out.requires_grad_(requires_grad=tensor.requires_grad) return out + # Loose tols for sanity checking + tols = {"rtol": 0.25, "atol": 0.5} + if quantization == "nvfp4": + tols = {"rtol": 0.5, "atol": 1} + # Check values - tols = {"rtol": 0.25, "atol": 0.5} # Loose tols for sanity checking torch.testing.assert_close(to_cpu(y_test), y_ref, **tols) torch.testing.assert_close(to_cpu(x_test.grad), x_ref.grad, **tols) torch.testing.assert_close(to_cpu(probs_test.grad), probs_ref.grad, **tols) diff --git a/transformer_engine/pytorch/ops/__init__.py b/transformer_engine/pytorch/ops/__init__.py index cd37144dfa..99f51a9c7a 100644 --- a/transformer_engine/pytorch/ops/__init__.py +++ b/transformer_engine/pytorch/ops/__init__.py @@ -13,5 +13,4 @@ from .linear import Linear from .op import BasicOperation, FusedOperation, FusibleOperation from .sequential import Sequential - -import transformer_engine.pytorch.ops.fused +from . 
import fused diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index cabd86442a..0020ad5976 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -9,12 +9,14 @@ from .backward_add_rmsnorm import BackwardAddRMSNorm from .backward_linear_add import BackwardLinearAdd from .backward_linear_scale import BackwardLinearScale +from .forward_grouped_mlp import ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 from .forward_linear_bias_activation import ForwardLinearBiasActivation from .forward_linear_bias_add import ForwardLinearBiasAdd from .forward_linear_scale_add import ForwardLinearScaleAdd from .userbuffers_backward_linear import UserbuffersBackwardLinear from .userbuffers_forward_linear import UserbuffersForwardLinear + # Register forward fusions register_forward_fusion(UserbuffersForwardLinear.fuse_forward_ops) register_forward_fusion(ForwardLinearBiasAdd.fuse_forward_ops) @@ -28,5 +30,6 @@ register_backward_fusion(BackwardActivationBias.fuse_backward_ops) register_backward_fusion(BackwardAddRMSNorm.fuse_backward_ops) -from .forward_grouped_mlp import fuse_forward_ops as forward_grouped_mlp_fuse_ops -register_forward_fusion(forward_grouped_mlp_fuse_ops, prepend=True) + +# Import specialized fusions +from .forward_grouped_mlp import ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 22817eaa0e..1937dab7d2 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -5,18 +5,19 @@ """Fused operation for forward GEMM + scale + add.""" from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Callable, Iterable +import functools import itertools from typing import Any, Optional import torch -from cudnn import grouped_gemm_swiglu_wrapper_sm100 ### TODO Check if available import transformer_engine_torch as tex from ...cpp_extensions import general_grouped_gemm -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...module._common import noop_cat from ...quantization import FP8GlobalStateManager from ...tensor import MXFP8Tensor, Quantizer +from ...utils import get_device_compute_capability from ..basic import GroupedLinear, ScaledSwiGLU from ..fuser import register_forward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext @@ -24,6 +25,32 @@ class ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8(FusedOperation): + """Fused op for MXFP8 GroupedLinear + ScaledSwiGLU + GroupedLinear + + Uses experimental CuTe DSL kernel. 
+ + """ + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_swiglu_kernel(cls) -> Callable: + """Fused kernel for grouped GEMM, SwiGLU, and post-multiplication.""" + from cudnn import grouped_gemm_swiglu_wrapper_sm100 + return grouped_gemm_swiglu_wrapper_sm100 + + @classmethod + @functools.lru_cache(maxsize=None) + def is_supported(cls) -> bool: + """Whether this fused operation is supported on the current system.""" + if get_device_compute_capability() < (10, 0): + # Kernel requires SM100+ + return False + try: + # Make sure kernel is available + cls.grouped_gemm_swiglu_kernel() + except ImportError: + return False + return True def __init__( self, @@ -138,10 +165,10 @@ def fuser_forward( fc1_xs = tex.split_quantize(fc1_x, split_sizes_cpu, fc1_input_quantizers) # Pack data tensors - fc1_x_data = torch.cat([x._rowwise_data for x in fc1_xs]) + fc1_x_data = noop_cat([x._rowwise_data for x in fc1_xs]) fc1_x_data = fc1_x_data.view(dtype=torch.float8_e4m3fn) fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0) - fc1_x_scales = torch.cat([x._rowwise_scale_inv for x in fc1_xs]) + fc1_x_scales = noop_cat([x._rowwise_scale_inv for x in fc1_xs]) fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e8m0fnu) fc1_x_scales = fc1_x_scales.view( 1, @@ -191,7 +218,7 @@ def fuser_forward( ) # Fused kernel for FC1 + SwiGLU + post-scale - fc1_kernel_out = grouped_gemm_swiglu_wrapper_sm100( + fc1_kernel_out = self.grouped_gemm_swiglu_kernel()( fc1_x_data, fc1_w_data, fc1_x_scales, @@ -302,6 +329,7 @@ def fuser_forward( return fc2_out, [(), (), ()] + def fuse_forward_ops( ops: list[FusibleOperation], *, @@ -324,6 +352,10 @@ def fuse_forward_ops( """ + # Return immediately if fused kernel is not supported + if not ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): + return ops + # Check if recipe is supported if recipe is None: return ops @@ -380,3 +412,8 @@ def fuse_forward_ops( # Return list of ops out.extend(window) return out + + +# Register fusion if available +if ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): + register_forward_fusion(fuse_forward_ops, prepend=True) From 000c2736de297055fa8464b776ca62a2766dd6b2 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 24 Jan 2026 03:30:05 +0000 Subject: [PATCH 30/45] Revert to plain tensor concat Signed-off-by: Tim Moon --- transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 1937dab7d2..78e87e2938 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -14,7 +14,6 @@ import transformer_engine_torch as tex from ...cpp_extensions import general_grouped_gemm -from ...module._common import noop_cat from ...quantization import FP8GlobalStateManager from ...tensor import MXFP8Tensor, Quantizer from ...utils import get_device_compute_capability @@ -165,10 +164,10 @@ def fuser_forward( fc1_xs = tex.split_quantize(fc1_x, split_sizes_cpu, fc1_input_quantizers) # Pack data tensors - fc1_x_data = noop_cat([x._rowwise_data for x in fc1_xs]) + fc1_x_data = torch.cat([x._rowwise_data for x in fc1_xs]) fc1_x_data = fc1_x_data.view(dtype=torch.float8_e4m3fn) fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0) - fc1_x_scales = noop_cat([x._rowwise_scale_inv for x in fc1_xs]) + fc1_x_scales = torch.cat([x._rowwise_scale_inv for x in fc1_xs]) fc1_x_scales = 
fc1_x_scales.view(dtype=torch.float8_e8m0fnu) fc1_x_scales = fc1_x_scales.view( 1, From e2ea4d2c7726de736f453801a677b58633b5a105 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 24 Jan 2026 05:56:51 +0000 Subject: [PATCH 31/45] Support GLU interleaving in plain SwiGLU op Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 154 +++++-- .../pytorch/ops/basic/__init__.py | 4 +- .../pytorch/ops/basic/activation.py | 75 ---- .../pytorch/ops/basic/scaled_swiglu.py | 154 ------- .../pytorch/ops/basic/swiglu.py | 418 ++++++++++++++++++ .../pytorch/ops/fused/forward_grouped_mlp.py | 35 +- 6 files changed, 572 insertions(+), 268 deletions(-) delete mode 100644 transformer_engine/pytorch/ops/basic/scaled_swiglu.py create mode 100644 transformer_engine/pytorch/ops/basic/swiglu.py diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 09f024326c..4d5ef0b226 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -173,6 +173,29 @@ def make_reference_and_test_tensors( return ref, test +def assert_close( + a: Optional[torch.Tensor], + b: Optional[torch.Tensor], + *, + rtol: float, + atol: float, +) -> None: + """Assert that two tensors are close.""" + if a is None and b is None: + return + assert a is not None + assert b is not None + a = a.detach() + b = b.detach() + if isinstance(a, QuantizedTensor): + a = a.dequantize() + if isinstance(b, QuantizedTensor): + b = b.dequantize() + a = a.to(dtype=torch.float64, device="cpu") + b = b.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + class TestSequentialContainer: """Tests for sequential container""" @@ -1681,6 +1704,7 @@ def test_swiglu( quantization: Optional[str], quantize_forward: bool, quantize_backward: bool, + glu_interleave_size: Optional[int] = None, ): # Tensor dimensions @@ -1707,7 +1731,17 @@ def test_swiglu( ) # Plain PyTorch implementation - x1, x2 = x_ref.chunk(2, dim=-1) + x = x_ref + if glu_interleave_size is not None: + x = x.reshape( + *in_shape[:-1], + in_shape[-1] // (2 * glu_interleave_size), + 2, + glu_interleave_size, + ) + x = x.transpose(-3, -2) + x = x.reshape(in_shape) + x1, x2 = x.chunk(2, dim=-1) y_ref = torch.nn.functional.silu(x1) * x2 y_ref.backward(dy_ref) @@ -1715,7 +1749,7 @@ def test_swiglu( recipe = make_recipe(quantization) forward = te_ops.Sequential( te_ops.Quantize(forward=False, backward=quantize_backward), - te_ops.SwiGLU(), + te_ops.SwiGLU(glu_interleave_size=glu_interleave_size), te_ops.Quantize(forward=quantize_forward, backward=False), ) with te.autocast(enabled=quantized_compute, recipe=recipe): @@ -1728,10 +1762,18 @@ def test_swiglu( tols = quantization_tols(quantization) # Check results - y_test = y_test.to(dtype=torch.float64, device="cpu") - dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(y_test, y_ref, **tols) - torch.testing.assert_close(dx_test, x_ref.grad, **tols) + assert_close(y_test, y_ref, **tols) + assert_close(x_test.grad, x_ref.grad, **tols) + + def test_interleaved_swiglu(self): + self.test_swiglu( + out_shape=(32, 192), + dtype=torch.float32, + quantization=None, + quantize_forward=False, + quantize_backward=False, + glu_interleave_size=32, + ) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("quantization", _quantization_list) @@ -2137,14 +2179,14 @@ def test_multiply_extra_input( assert x2_test.grad is None @pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128))) - 
@pytest.mark.parametrize("gate_interleave_size", (None, 32)) + @pytest.mark.parametrize("glu_interleave_size", (None, 32)) @pytest.mark.parametrize("input_requires_grad", (False, True)) @pytest.mark.parametrize("scales_requires_grad", (False, True)) def test_scaled_swiglu( self, *, in_shape: Iterable[int], - gate_interleave_size: Optional[int], + glu_interleave_size: Optional[int], dtype: torch.dtype = torch.float32, device: torch.device = "cuda", input_requires_grad: bool, @@ -2178,12 +2220,12 @@ def test_scaled_swiglu( # Plain PyTorch implementation x = x_ref - if gate_interleave_size is not None: + if glu_interleave_size is not None: x = x.reshape( -1, - in_shape[-1] // (2 * gate_interleave_size), + in_shape[-1] // (2 * glu_interleave_size), 2, - gate_interleave_size, + glu_interleave_size, ) x = x.transpose(1, 2) x = x.reshape(in_shape) @@ -2194,7 +2236,7 @@ def test_scaled_swiglu( y_ref.backward(dy_ref) # Implementation with fusible operation - op = te_ops.ScaledSwiGLU(gate_interleave_size=gate_interleave_size) + op = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) y_test = op(x_test, scales_test) if input_requires_grad or scales_requires_grad: y_test.backward(dy_test) @@ -3221,17 +3263,21 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols) torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols) + @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("glu_interleave_size", (None, 32)) def test_grouped_mlp( self, *, + group_size: int = 4, + bias: bool, + hidden_size: int = 256, dtype: torch.dtype, quantization: Optional[str], device: torch.device = "cuda", - group_size: int = 4, - hidden_size: int = 256, split_alignment: int = 256, + glu_interleave_size: Optional[int], ) -> None: """GroupedLinear + ScaledSwiGLU + GroupedLinear""" @@ -3247,6 +3293,8 @@ def test_grouped_mlp( # Skip invalid configurations with_quantization = quantization is not None maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + if with_quantization and dtype not in (torch.bfloat16, torch.float16): + pytest.skip("Quantized group GEMM is only supported with BF16/FP16") # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -3268,7 +3316,9 @@ def test_grouped_mlp( test_device=device, ) fc1_ws_ref, fc1_ws_test = [], [] + fc1_bs_ref, fc1_bs_test = [], [] fc2_ws_ref, fc2_ws_test = [], [] + fc2_bs_ref, fc2_bs_test = [], [] for _ in range(group_size): fc1_w_ref, fc1_w_test = make_reference_and_test_tensors( (2 * hidden_size, hidden_size), @@ -3282,10 +3332,27 @@ def test_grouped_mlp( test_dtype=dtype, test_device=device, ) + fc1_b_ref, fc1_b_test = None, None + fc2_b_ref, fc2_b_test = None, None + if bias: + fc1_b_ref, fc1_b_test = make_reference_and_test_tensors( + (2 * hidden_size,), + test_dtype=dtype, + test_device=device, + ) + fc2_b_ref, fc2_b_test = make_reference_and_test_tensors( + (hidden_size,), + test_dtype=dtype, + test_device=device, + ) fc1_ws_ref.append(fc1_w_ref) + fc1_bs_ref.append(fc1_b_ref) fc1_ws_test.append(fc1_w_test) + fc1_bs_test.append(fc1_b_test) fc2_ws_ref.append(fc2_w_ref) + fc2_bs_ref.append(fc2_b_ref) fc2_ws_test.append(fc2_w_test) + fc2_bs_test.append(fc2_b_test) with torch.no_grad(): for t in fc1_ws_ref + fc1_ws_test + fc2_ws_ref + fc2_ws_test: t -= 0.5 @@ -3293,20 +3360,30 @@ def test_grouped_mlp( for t in 
(x_ref, x_test, dy_ref, dy_test): t -= 0.5 t *= 1 / 2 + if bias: + for t in fc1_bs_ref + fc1_bs_test + fc2_bs_ref + fc2_bs_test: + t -= 0.5 # Reference implementation xs = torch.split(x_ref, split_sizes.tolist()) probs = torch.split(probs_ref, split_sizes.tolist()) ys = [] - for x, fc1_w, fc2_w, prob in zip(xs, fc1_ws_ref, fc2_ws_ref, probs): - x = torch.nn.functional.linear(x, fc1_w) - x = x.reshape(-1, 2 * hidden_size // 64, 2, 32) - x = x.transpose(1, 2) - x = x.reshape(-1, 2 * hidden_size) + for group_idx in range(group_size): + x = xs[group_idx] + x = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx]) + if glu_interleave_size is not None: + x = x.reshape( + -1, + 2 * hidden_size // (2 * glu_interleave_size), + 2, + glu_interleave_size, + ) + x = x.transpose(1, 2) + x = x.reshape(-1, 2 * hidden_size) x1, x2 = x.chunk(2, dim=-1) x = torch.nn.functional.silu(x1) * x2 - x = x * prob.unsqueeze(-1) - x = torch.nn.functional.linear(x, fc2_w) + x = x * probs[group_idx].unsqueeze(-1) + x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx], bias=fc2_bs_ref[group_idx]) ys.append(x) y_ref = torch.cat(ys) y_ref.backward(dy_ref) @@ -3318,7 +3395,7 @@ def test_grouped_mlp( group_size, hidden_size, 2 * hidden_size, - bias=False, + bias=bias, device=device, dtype=dtype, ) @@ -3326,13 +3403,13 @@ def test_grouped_mlp( group_size, hidden_size, hidden_size, - bias=False, + bias=bias, device=device, dtype=dtype, ) module = te_ops.Sequential( fc1, - te_ops.ScaledSwiGLU(gate_interleave_size=32), + te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size), fc2, ) @@ -3341,7 +3418,10 @@ def test_grouped_mlp( for group_idx in range(group_size): getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) - del fc1_ws_test, fc2_ws_test + if bias: + getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) + getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) + del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test # Fuse ops and perform forward and backward pass with te.autocast(enabled=with_quantization, recipe=recipe): @@ -3353,6 +3433,8 @@ def test_grouped_mlp( te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported() and quantization == "mxfp8" and dtype == torch.bfloat16 + and not bias + and glu_interleave_size == 32 ): forward_ops = module._module_groups[0]._forward_ops assert len(forward_ops) == 1 @@ -3375,16 +3457,20 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: tols = {"rtol": 0.5, "atol": 1} # Check values - torch.testing.assert_close(to_cpu(y_test), y_ref, **tols) - torch.testing.assert_close(to_cpu(x_test.grad), x_ref.grad, **tols) - torch.testing.assert_close(to_cpu(probs_test.grad), probs_ref.grad, **tols) + assert_close(y_test, y_ref, **tols) + assert_close(x_test.grad, x_ref.grad, **tols) + assert_close(probs_test.grad, probs_ref.grad, **tols) for group_idx in range(group_size): - fc1_dw_test = to_cpu(getattr(fc1, f"weight{group_idx}").grad) - fc1_dw_ref = fc1_ws_ref[group_idx].grad - fc2_dw_test = to_cpu(getattr(fc2, f"weight{group_idx}").grad) - fc2_dw_ref = fc2_ws_ref[group_idx].grad - torch.testing.assert_close(fc2_dw_test, fc2_dw_ref, **tols) - torch.testing.assert_close(fc1_dw_test, fc1_dw_ref, **tols) + assert_close( + getattr(fc2, f"weight{group_idx}").grad, + fc2_ws_ref[group_idx].grad, + **tols, + ) + assert_close( + getattr(fc1, f"weight{group_idx}").grad, + fc1_ws_ref[group_idx].grad, + **tols, + ) class TestCustomOps: diff --git 
a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index c02340817a..f317deca12 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -14,8 +14,6 @@ SReLU, SReGLU, SiLU, - SwiGLU, - ClampedSwiGLU, ) from .add_extra_input import AddExtraInput from .all_gather import AllGather @@ -34,4 +32,4 @@ from .reduce_scatter import ReduceScatter from .reshape import Reshape from .rmsnorm import RMSNorm -from .scaled_swiglu import ScaledSwiGLU +from .swiglu import ClampedSwiGLU, ScaledSwiGLU, SwiGLU diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 9d54e12dba..2f1debdf5e 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -27,8 +27,6 @@ "SReLU", "SReGLU", "SiLU", - "SwiGLU", - "ClampedSwiGLU", ] @@ -355,76 +353,3 @@ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dsilu(*args, **kwargs) - - -class SwiGLU(_ActivationOperation): - r"""Swish gated linear unit - - The input tensor is split into chunks :math:`a` and :math:`b` - along the last dimension and the following is computed: - - .. math:: - - \text{GEGLU}(a,b) = \text{SiLU}(a) * b - - where - - .. math:: - - \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)} - - .. warning:: - - Transformer Engine's gated activations and PyTorch's GLU - activation follow opposite conventions for :math:`a` and - :math:`b`. Transformer Engine applies the gating function to - the first half of the input tensor, while PyTorch applies it to - the second half. - - The Sigmoid Linear Unit (SiLU) gating function is also known as - the swish function. See - `GLU Variants Improve Transformer`__ - and `Gaussian Error Linear Units (GELUs)`__. - - """ - - def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.swiglu(*args, **kwargs) - - def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.dswiglu(*args, **kwargs) - - -class ClampedSwiGLU(_ActivationOperation): - r"""GPT-OSS - Implementation based on `GPT-OSS`__. - - This activation has two differences compared to the original SwiGLU - 1. Both gate and pre-activations are clipped based on parameter limit. - 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. - - .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is differnt - from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor. - - Parameters - ---------- - limit : float - The clamp limit. - alpha : float - The scaling factor for the sigmoid function used in the activation. - cache_quantized_input : bool, default = False - Quantize input tensor when caching for use in the backward pass. 
- """ - - def __init__( - self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False - ): - super().__init__(cache_quantized_input=cache_quantized_input) - self.limit = limit - self.alpha = alpha - - def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.clamped_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) - - def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.clamped_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) diff --git a/transformer_engine/pytorch/ops/basic/scaled_swiglu.py b/transformer_engine/pytorch/ops/basic/scaled_swiglu.py deleted file mode 100644 index 9b398fa217..0000000000 --- a/transformer_engine/pytorch/ops/basic/scaled_swiglu.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Fusible operation for multiplying with extra input tensor.""" - -from __future__ import annotations -from collections.abc import Iterable -from typing import Any, Optional - -import torch - -import transformer_engine_torch as tex -from ...tensor import Quantizer -from ..op import BasicOperation, OperationContext -from .._common import maybe_dequantize - - -class ScaledSwiGLU(BasicOperation): - """SwiGLU with post-scaling - """ - - # Operation expects scales - num_extra_inputs: int = 1 - - def __init__(self, gate_interleave_size: Optional[int] = None): - super().__init__() - self.gate_interleave_size: Optional[int] = gate_interleave_size - - def op_forward(self, *args, **kwargs) -> None: - raise RuntimeError( - "{self.__class__.__name__} operation has " - f"{self.num_extra_inputs} extra tensor inputs " - f"and {self.num_extra_outputs} extra tensor outputs. " - "It overrides `fuser_forward` instead of `op_forward`." - ) - - def op_backward(self, *args, **kwargs) -> None: - raise RuntimeError( - "{self.__class__.__name__} operation has " - f"{self.num_extra_inputs} extra tensor inputs " - f"and {self.num_extra_outputs} extra tensor outputs. " - "It overrides `fuser_backward` instead of `op_backward`." 
- ) - - def fuser_forward( - self, - basic_op_ctxs: list[OperationContext], - input_: torch.Tensor, - *, - basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - prev_op_grad_output_quantizer: Optional[Quantizer], - next_op_input_quantizer: Optional[Quantizer], - basic_op_kwargs: list[dict[str, Any]], - ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: - extra_input = basic_op_extra_inputs[0][0] - - # Determine compute dtype - if torch.is_autocast_enabled(): - dtype = torch.get_autocast_dtype("cuda") - elif isinstance(input_, torch.Tensor): - dtype = input_.dtype - else: - dtype = extra_input.dtype - - # Make sure inputs are in correct dtype - input_ = maybe_dequantize(input_, dtype) - scales = maybe_dequantize(extra_input, dtype) - - # Remove gate interleaving if needed - swiglu_in = input_ - if self.gate_interleave_size is not None: - shape = swiglu_in.size() - swiglu_in = swiglu_in.reshape( - -1, - shape[-1] // (2 * self.gate_interleave_size), - 2, - self.gate_interleave_size, - ) - swiglu_in = swiglu_in.transpose(1, 2).contiguous() - swiglu_in = swiglu_in.view(shape) - - # Compute scaled SwiGLU - swiglu_out = tex.swiglu(swiglu_in, None) - out = swiglu_out * scales.unsqueeze(-1) - - # Save state for backward pass - ctx = basic_op_ctxs[0] - if ctx.requires_grad: - ctx.input_requires_grad = True - ctx.extra_input_requires_grad = extra_input.requires_grad - ctx.dtype = dtype - ctx.save_for_backward( - input_, - scales if ctx.input_requires_grad else None, - ) - - return out, [()] - - def fuser_backward( - self, - basic_op_ctxs: list[OperationContext], - grad_output: torch.Tensor, - *, - basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], - ) -> tuple[ - torch.Tensor, - Iterable[Iterable[Optional[torch.Tensor]]], - Iterable[Iterable[Optional[torch.Tensor]]], - ]: - ctx = basic_op_ctxs[0] - input_, scales = ctx.saved_tensors - input_ = maybe_dequantize(input_, ctx.dtype) - if scales is not None: - scales = maybe_dequantize(scales, ctx.dtype) - grad_output = maybe_dequantize(grad_output, ctx.dtype) - - # Remove gate interleaving if needed - swiglu_in = input_ - if self.gate_interleave_size is not None: - shape = swiglu_in.size() - swiglu_in = swiglu_in.reshape( - -1, - shape[-1] // (2 * self.gate_interleave_size), - 2, - self.gate_interleave_size, - ) - swiglu_in = swiglu_in.transpose(1, 2).contiguous() - swiglu_in = swiglu_in.view(shape) - - # Compute input grad - grad_input = None - if ctx.input_requires_grad: - grad_swiglu_out = grad_output * scales.unsqueeze(-1) - grad_swiglu_in = tex.dswiglu(grad_swiglu_out, swiglu_in, None) - grad_input = grad_swiglu_in - if self.gate_interleave_size is not None: - shape = grad_input.size() - grad_input = grad_input.reshape( - -1, - 2, - shape[-1] // (2 * self.gate_interleave_size), - self.gate_interleave_size, - ) - grad_input = grad_input.transpose(1, 2).contiguous() - grad_input = grad_input.view(shape) - - # Compute scales grad by recomputing SwiGLU - grad_extra_input = None - if ctx.extra_input_requires_grad: - swiglu_out = tex.swiglu(swiglu_in, None) - grad_extra_input = torch.linalg.vecdot(swiglu_out, grad_output) - - return grad_input, [()], [(grad_extra_input,)] diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py new file mode 100644 index 0000000000..03c08426e5 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/swiglu.py @@ -0,0 +1,418 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
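A minimal standalone sketch of the transform behind the ``glu_interleave_size`` option used in the module below: the ops de-interleave block-interleaved gate/linear halves into the concatenated layout expected by ``tex.swiglu`` and apply the inverse transform to the gradient. Function names here are illustrative only, not Transformer Engine API.

    import torch

    def deinterleave_glu(x: torch.Tensor, s: int) -> torch.Tensor:
        # [b0 b1 b2 b3 ...] (blocks of width s) -> [b0 b2 ... | b1 b3 ...]
        d = x.shape[-1]
        y = x.reshape(*x.shape[:-1], d // (2 * s), 2, s)
        return y.transpose(-3, -2).contiguous().reshape(x.shape)

    def reinterleave_glu(g: torch.Tensor, s: int) -> torch.Tensor:
        # Inverse transform, applied to the SwiGLU input gradient in backward
        d = g.shape[-1]
        y = g.reshape(*g.shape[:-1], 2, d // (2 * s), s)
        return y.transpose(-3, -2).contiguous().reshape(g.shape)

    x = torch.randn(4, 192)
    assert torch.equal(reinterleave_glu(deinterleave_glu(x, 32), 32), x)  # round trip

Handling the interleaving as a reshape on the input keeps ``tex.swiglu``/``tex.dswiglu`` unchanged, at the cost of an extra copy and of skipping the previous op's grad-output quantizer when interleaving is active.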
+ +"""Fusible operation for multiplying with extra input tensor.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +import transformer_engine_torch as tex +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...tensor import Float8CurrentScalingQuantizer, Quantizer +from ...utils import clear_tensor_data +from ..op import BasicOperation, OperationContext +from .._common import maybe_dequantize + +__all__ = ["SwiGLU", "ClampedSwiGLU", "ScaledSwiGLU"] + + +class SwiGLU(BasicOperation): + r"""Swish gated linear unit + + The input tensor is split into chunks :math:``a`` and :math:``b`` + along the last dimension and the following is computed: + + .. math:: + + \text{GEGLU}(a,b) = \text{SiLU}(a) * b + + where + + .. math:: + + \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)} + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:``a`` and + :math:``b``. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + The Sigmoid Linear Unit (SiLU) gating function is also known as + the swish function. See + ``GLU Variants Improve Transformer``__ + and ``Gaussian Error Linear Units (GELUs)``__. + + """ + + def __init__( + self, + *, + cache_quantized_input: bool = False, + glu_interleave_size: Optional[int] = None + ): + super().__init__() + self.cache_quantized_input: bool = cache_quantized_input + self.glu_interleave_size: Optional[int] = glu_interleave_size + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + ) -> torch.Tensor: + + # Compute dtype + dtype: torch.dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = input_.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise RuntimeError(f"Unsupported dtype ({dtype})") + + # Check input tensor + input_ = maybe_dequantize(input_.contiguous(), dtype) + + # Remove interleaving if needed + swiglu_in = input_ + if self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Launch kernel + out = tex.swiglu(swiglu_in, next_op_input_quantizer) + + # Quantize input to FP8 before caching if needed + if self.cache_quantized_input: + input_quantizer = Float8CurrentScalingQuantizer( + tex.DType.kFloat8E4M3, + input_.device, + ) + input_quantizer.set_usage(rowwise=True, columnwise=False) + input_ = input_quantizer(input_) + + # Save state for backward pass + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(input_) + ctx.save_for_backward(input_) + ctx.dtype = dtype + ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer + + return out + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (input_,) = ctx.saved_tensors + + # Make sure tensors have correct dtypes + x = maybe_dequantize(input_.contiguous(), ctx.dtype) + dy = maybe_dequantize(grad_output.contiguous(), ctx.dtype) + + # Remove interleaving if needed + swiglu_in = x + if 
self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Quantizer for grad input + quantizer = ctx.prev_op_grad_output_quantizer + if self.glu_interleave_size is not None: + quantizer = None + + # Launch kernel + grad_swiglu_in = tex.dswiglu(dy, swiglu_in, quantizer) + + # Apply interleaving if needed + dx = grad_swiglu_in + if self.glu_interleave_size is not None: + shape = dx.size() + dx = dx.reshape( + -1, + 2, + shape[-1] // (2 * self.glu_interleave_size), + self.glu_interleave_size, + ) + dx = dx.transpose(1, 2).contiguous() + dx = dx.view(shape) + + # Clear input tensor if possible + clear_tensor_data(input_) + + return dx, () + + +class ClampedSwiGLU(BasicOperation): + r"""GPT-OSS + Implementation based on ``GPT-OSS``__. + + This activation has two differences compared to the original SwiGLU + 1. Both gate and pre-activations are clipped based on parameter limit. + 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. + + .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is differnt + from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor. + + Parameters + ---------- + limit : float + The clamp limit. + alpha : float + The scaling factor for the sigmoid function used in the activation. + cache_quantized_input : bool, default = ``False`` + Quantize input tensor when caching for use in the backward pass. + """ + + def __init__( + self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False + ): + super().__init__() + self.limit: float = limit + self.alpha: float = alpha + self.cache_quantized_input: bool = cache_quantized_input + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + ) -> torch.Tensor: + + # Compute dtype + dtype: torch.dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = input_.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise RuntimeError(f"Unsupported dtype ({dtype})") + + # Check input tensor + x = maybe_dequantize(input_.contiguous(), dtype) + + # Launch kernel + y = tex.clamped_swiglu( + x, + next_op_input_quantizer, + limit=self.limit, + alpha=self.alpha, + ) + + # Quantize input to FP8 before caching if needed + if self.cache_quantized_input: + input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device) + input_quantizer.set_usage(rowwise=True, columnwise=False) + x = input_quantizer(x) + + # Save state for backward pass + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x) + ctx.save_for_backward(x) + ctx.dtype = dtype + ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer + + return y + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (input_,) = ctx.saved_tensors + + # Make sure tensors have correct dtypes + x = maybe_dequantize(input_.contiguous(), ctx.dtype) + dy = maybe_dequantize(grad_output.contiguous(), ctx.dtype) + + # Launch kernel + dx = tex.clamped_dswiglu( + dy, + x, + 
ctx.prev_op_grad_output_quantizer, + limit=self.limit, + alpha=self.alpha, + ) + + # Clear input tensor if possible + clear_tensor_data(input_) + + return dx, () + + +class ScaledSwiGLU(BasicOperation): + """SwiGLU with post-scaling + + If the SwiGLU output has shape ``(d_1, ..., d_n)``, it is + multiplied with an extra input tensor of shape + ``(d_1, ..., d_{n-1})``. + + """ + + # Operation expects scales + num_extra_inputs: int = 1 + + def __init__(self, glu_interleave_size: Optional[int] = None): + super().__init__() + self.glu_interleave_size: Optional[int] = glu_interleave_size + + def op_forward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs) -> None: + raise RuntimeError( + "{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_backward` instead of `op_backward`." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + extra_input = basic_op_extra_inputs[0][0] + + # Determine compute dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + elif isinstance(input_, torch.Tensor): + dtype = input_.dtype + else: + dtype = extra_input.dtype + + # Make sure inputs are in correct dtype + input_ = maybe_dequantize(input_, dtype) + scales = maybe_dequantize(extra_input, dtype) + + # Remove gate interleaving if needed + swiglu_in = input_ + if self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Compute scaled SwiGLU + swiglu_out = tex.swiglu(swiglu_in, None) + out = swiglu_out * scales.unsqueeze(-1) + + # Save state for backward pass + ctx = basic_op_ctxs[0] + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(input_) + ctx.input_requires_grad = True + ctx.extra_input_requires_grad = extra_input.requires_grad + ctx.dtype = dtype + ctx.save_for_backward( + input_, + scales if ctx.input_requires_grad else None, + ) + + return out, [()] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + ctx = basic_op_ctxs[0] + input_, scales = ctx.saved_tensors + input_ = maybe_dequantize(input_, ctx.dtype) + if scales is not None: + scales = maybe_dequantize(scales, ctx.dtype) + grad_output = maybe_dequantize(grad_output, ctx.dtype) + + # Remove gate interleaving if needed + swiglu_in = input_ + if self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = 
swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Compute input grad + grad_input = None + if ctx.input_requires_grad: + grad_swiglu_out = grad_output * scales.unsqueeze(-1) + grad_swiglu_in = tex.dswiglu(grad_swiglu_out, swiglu_in, None) + grad_input = grad_swiglu_in + if self.glu_interleave_size is not None: + shape = grad_input.size() + grad_input = grad_input.reshape( + -1, + 2, + shape[-1] // (2 * self.glu_interleave_size), + self.glu_interleave_size, + ) + grad_input = grad_input.transpose(1, 2).contiguous() + grad_input = grad_input.view(shape) + + # Compute scales grad by recomputing SwiGLU + grad_extra_input = None + if ctx.extra_input_requires_grad: + swiglu_out = tex.swiglu(swiglu_in, None) + grad_extra_input = torch.linalg.vecdot(swiglu_out, grad_output) + + # Clear input tensor if possible + clear_tensor_data(ctx.saved_tensors[0]) # input_ + + return grad_input, [()], [(grad_extra_input,)] diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 78e87e2938..a4f5a06d16 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -26,7 +26,7 @@ class ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8(FusedOperation): """Fused op for MXFP8 GroupedLinear + ScaledSwiGLU + GroupedLinear - Uses experimental CuTe DSL kernel. + Uses experimental CuTe DSL kernel from cuDNN front-end. """ @@ -60,6 +60,37 @@ def __init__( ) -> None: super().__init__((fc1, swiglu, fc2)) + # Check for unsupported configurations + if not self.is_supported(): + self.grouped_gemm_swiglu_kernel() # Try triggering import error + raise RuntimeError( + f"{self.__class__.__name__} is not supported on this system." + ) + if fc1.in_features % 256 != 0 or fc1.in_features % 256 != 0: + raise ValueError( + f"Unsupported dims for FC1 (group_size={fc1.group_size}, " + f"in_features={fc1.in_features}, out_features={fc1.out_features})." + ) + if fc2.in_features % 256 != 0 or fc2.in_features % 256 != 0: + raise ValueError( + f"Unsupported dims for FC2 (group_size={fc2.group_size}, " + f"in_features={fc2.in_features}, out_features={fc2.out_features})." + ) + if fc1.out_features != 2 * fc2.in_features or fc1.group_size != fc2.group_size: + raise ValueError( + f"FC1 (group_size={fc1.group_size}, in_features={fc1.in_features}, " + f"out_features={fc1.out_features}) " + f"and FC2 (group_size={fc2.group_size}, in_features={fc2.in_features}, " + f"out_features={fc2.out_features}) do not match." + ) + if fc1.has_bias or fc2.has_bias: + raise ValueError("Fused kernel does not support bias.") + if swiglu.glu_interleave_size != 32: + raise ValueError( + "Fused kernel requires 32-wide GLU interleaving, " + "but got glu_interleave_size={swiglu.glu_interleave_size}." 
+ ) + def fuser_forward( self, basic_op_ctxs: list[OperationContext], @@ -385,7 +416,7 @@ def fuse_forward_ops( or window[2].out_features % 256 != 0 ): matches_pattern = False - elif window[1].gate_interleave_size != 32: + elif window[1].glu_interleave_size != 32: matches_pattern = False if matches_pattern: From 4c6c35fb54053096625e3990131813a7d22e0020 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 24 Jan 2026 06:40:44 +0000 Subject: [PATCH 32/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/ops/basic/swiglu.py | 5 +--- .../pytorch/ops/fused/forward_grouped_mlp.py | 26 ++++++++++++------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py index 03c08426e5..8f1068ba46 100644 --- a/transformer_engine/pytorch/ops/basic/swiglu.py +++ b/transformer_engine/pytorch/ops/basic/swiglu.py @@ -52,10 +52,7 @@ class SwiGLU(BasicOperation): """ def __init__( - self, - *, - cache_quantized_input: bool = False, - glu_interleave_size: Optional[int] = None + self, *, cache_quantized_input: bool = False, glu_interleave_size: Optional[int] = None ): super().__init__() self.cache_quantized_input: bool = cache_quantized_input diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index a4f5a06d16..fec9d6789e 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -35,6 +35,7 @@ class ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8(FusedOperation): def grouped_gemm_swiglu_kernel(cls) -> Callable: """Fused kernel for grouped GEMM, SwiGLU, and post-multiplication.""" from cudnn import grouped_gemm_swiglu_wrapper_sm100 + return grouped_gemm_swiglu_wrapper_sm100 @classmethod @@ -63,9 +64,7 @@ def __init__( # Check for unsupported configurations if not self.is_supported(): self.grouped_gemm_swiglu_kernel() # Try triggering import error - raise RuntimeError( - f"{self.__class__.__name__} is not supported on this system." 
- ) + raise RuntimeError(f"{self.__class__.__name__} is not supported on this system.") if fc1.in_features % 256 != 0 or fc1.in_features % 256 != 0: raise ValueError( f"Unsupported dims for FC1 (group_size={fc1.group_size}, " @@ -121,9 +120,8 @@ def fuser_forward( # Check which grads are required requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs) input_requires_grad = requires_grad - weight_requires_grad = ( - requires_grad - and (fc1_op.weight0.requires_grad or fc2_op.weight0.requires_grad) + weight_requires_grad = requires_grad and ( + fc1_op.weight0.requires_grad or fc2_op.weight0.requires_grad ) # Quantizers @@ -213,16 +211,24 @@ def fuser_forward( # Pack weight tensors fc1_w_data = torch.stack([w._rowwise_data for w in fc1_weights]) fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) - fc1_w_data = fc1_w_data.view(group_size, fc1_weight_shape[0] // 64, 2, 32, fc1_weight_shape[1]) + fc1_w_data = fc1_w_data.view( + group_size, fc1_weight_shape[0] // 64, 2, 32, fc1_weight_shape[1] + ) fc1_w_data = fc1_w_data.flip(2).contiguous() # Swap SwiGLU gate/activation fc1_w_data = fc1_w_data.view(group_size, fc1_weight_shape[0], fc1_weight_shape[1]) fc1_w_data = fc1_w_data.permute(1, 2, 0) fc1_w_scales = torch.stack([w._rowwise_scale_inv for w in fc1_weights]) fc1_w_scales = fc1_w_scales.view(dtype=torch.float8_e8m0fnu) - fc1_w_scales = fc1_w_scales.view(group_size, fc1_weight_shape[0] // 64, 2, 32, fc1_weight_shape[1] // 32) + fc1_w_scales = fc1_w_scales.view( + group_size, fc1_weight_shape[0] // 64, 2, 32, fc1_weight_shape[1] // 32 + ) fc1_w_scales = fc1_w_scales.flip(2).contiguous() # Swap SwiGLU gate/activation - fc1_w_scales = fc1_w_scales.view(group_size, fc1_weight_shape[0] // 128, 4, 32, fc1_weight_shape[1] // 128, 4) - fc1_w_scales = fc1_w_scales.permute(0, 1, 4, 3, 2, 5).contiguous() # Convert to swizzled layout + fc1_w_scales = fc1_w_scales.view( + group_size, fc1_weight_shape[0] // 128, 4, 32, fc1_weight_shape[1] // 128, 4 + ) + fc1_w_scales = fc1_w_scales.permute( + 0, 1, 4, 3, 2, 5 + ).contiguous() # Convert to swizzled layout fc1_w_scales = fc1_w_scales.permute(3, 4, 1, 5, 2, 0) # Kernel tile logic From caf580bba9fec7955046ae1afea8d6fad4fcab75 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 24 Jan 2026 06:58:07 +0000 Subject: [PATCH 33/45] Remove MultiplyExtraInput op Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 64 -------- .../pytorch/ops/basic/__init__.py | 1 - .../pytorch/ops/basic/multiply_extra_input.py | 149 ------------------ 3 files changed, 214 deletions(-) delete mode 100644 transformer_engine/pytorch/ops/basic/multiply_extra_input.py diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 4d5ef0b226..b1dd7b69eb 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2114,70 +2114,6 @@ def test_grouped_linear( else: assert b_test.grad is None - @pytest.mark.parametrize( - "input_shape,extra_input_shape", - ( - ((3, 4, 5), (3, 4, 5)), - ((6, 7), ()), - ((), (8, 9)), - ((10, 11, 12), (11, 1)), - ((1, 15), (13, 14, 15)), - ), - ) - @pytest.mark.parametrize("input_requires_grad", (False, True)) - @pytest.mark.parametrize("extra_input_requires_grad", (False, True)) - def test_multiply_extra_input( - self, - *, - input_shape: Iterable[int], - extra_input_shape: Iterable[int], - dtype: torch.dtype = torch.float32, - device: torch.device = "cuda", - input_requires_grad: bool, - extra_input_requires_grad: bool, - ) -> None: - """Multiply two tensors""" - - # Random data - 
x1_ref, x1_test = make_reference_and_test_tensors( - input_shape, - test_dtype=dtype, - test_device=device, - requires_grad=input_requires_grad, - ) - x2_ref, x2_test = make_reference_and_test_tensors( - extra_input_shape, - test_dtype=dtype, - test_device=device, - requires_grad=extra_input_requires_grad, - ) - - # Plain PyTorch implementation - y_ref = x1_ref * x2_ref - if input_requires_grad or extra_input_requires_grad: - torch.square(y_ref).sum().backward() - - # Implementation with fusible operation - op = te_ops.MultiplyExtraInput() - y_test = op(x1_test, x2_test) - if input_requires_grad or extra_input_requires_grad: - torch.square(y_test).sum().backward() - - # Check results - tols = dtype_tols(dtype) - y_test = y_test.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(y_test, y_ref, **tols) - if input_requires_grad: - dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(dx1_test, x1_ref.grad, **tols) - else: - assert x1_test.grad is None - if extra_input_requires_grad: - dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(dx2_test, x2_ref.grad, **tols) - else: - assert x2_test.grad is None - @pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128))) @pytest.mark.parametrize("glu_interleave_size", (None, 32)) @pytest.mark.parametrize("input_requires_grad", (False, True)) diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index f317deca12..32da121cce 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -27,7 +27,6 @@ from .l2normalization import L2Normalization from .layer_norm import LayerNorm from .make_extra_output import MakeExtraOutput -from .multiply_extra_input import MultiplyExtraInput from .quantize import Quantize from .reduce_scatter import ReduceScatter from .reshape import Reshape diff --git a/transformer_engine/pytorch/ops/basic/multiply_extra_input.py b/transformer_engine/pytorch/ops/basic/multiply_extra_input.py deleted file mode 100644 index f9dfef4d81..0000000000 --- a/transformer_engine/pytorch/ops/basic/multiply_extra_input.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Fusible operation for multiplying with extra input tensor.""" - -from __future__ import annotations -from collections.abc import Iterable -from typing import Any, Optional - -import torch - -from ...tensor import Quantizer -from ..op import BasicOperation, OperationContext -from .._common import maybe_dequantize - - -def _reduce_broadcast_dims( - x: torch.Tensor, - target_shape: Iterable[int], -) -> torch.Tensor: - """Reduce a tensor down to a target shape. - - The input tensor shape and target shape are assumed to be - broadcast-compatible. In other words, a tensor with the target - shape can be broadcast to match the input tensor shape. - - """ - shape = tuple(x.size()) - target_shape = tuple(target_shape) - - # Return immediately if tensor already has correct shape - if shape == target_shape: - return x - - # Determine reduction dimensions - reduce_dims = [] - if len(shape) < len(target_shape): - raise ValueError( - f"Invalid target shape (shape={shape} cannot be broadcast to shape={target_shape})." 
- ) - if len(shape) > len(target_shape): - reduce_dims.extend(range(len(shape) - len(target_shape))) - for idx in range(-len(target_shape), 0): - if shape[idx] == target_shape[idx]: - pass - elif target_shape[idx] != 1: - raise ValueError( - f"Invalid target shape (shape={shape} cannot be broadcast to shape={target_shape})." - ) - else: - reduce_dims.append(idx) - - # Perform reduction - return x.sum(reduce_dims).reshape(target_shape) - - -class MultiplyExtraInput(BasicOperation): - """Multiply with extra input tensor. - - If the tensor shapes do not match, they will follow NumPy - broadcasting semantics. - - """ - - # Operation expects extra input tensor - num_extra_inputs: int = 1 - - def op_forward(self, *args, **kwargs) -> None: - raise RuntimeError( - "{self.__class__.__name__} operation has " - f"{self.num_extra_inputs} extra tensor inputs " - f"and {self.num_extra_outputs} extra tensor outputs. " - "It overrides `fuser_forward` instead of `op_forward`." - ) - - def op_backward(self, *args, **kwargs) -> None: - raise RuntimeError( - "{self.__class__.__name__} operation has " - f"{self.num_extra_inputs} extra tensor inputs " - f"and {self.num_extra_outputs} extra tensor outputs. " - "It overrides `fuser_backward` instead of `op_backward`." - ) - - def fuser_forward( - self, - basic_op_ctxs: list[OperationContext], - input_: torch.Tensor, - *, - basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - prev_op_grad_output_quantizer: Optional[Quantizer], - next_op_input_quantizer: Optional[Quantizer], - basic_op_kwargs: list[dict[str, Any]], - ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: - extra_input = basic_op_extra_inputs[0][0] - - # Determine compute dtype - if torch.is_autocast_enabled(): - dtype = torch.get_autocast_dtype("cuda") - elif isinstance(input_, torch.Tensor): - dtype = input_.dtype - else: - dtype = extra_input.dtype - - # Perform multiplication - x1 = maybe_dequantize(input_, dtype) - x2 = maybe_dequantize(extra_input, dtype) - output = input_ * extra_input - - # Save state for backward pass - ctx = basic_op_ctxs[0] - if ctx.requires_grad: - ctx.input_shape = x1.size() - ctx.extra_input_shape = extra_input.size() - ctx.input_requires_grad = True - ctx.extra_input_requires_grad = extra_input.requires_grad - ctx.save_for_backward( - x1 if ctx.extra_input_requires_grad else None, - x2 if ctx.input_requires_grad else None, - ) - - return output, [()] - - def fuser_backward( - self, - basic_op_ctxs: list[OperationContext], - grad_output: torch.Tensor, - *, - basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], - ) -> tuple[ - torch.Tensor, - Iterable[Iterable[Optional[torch.Tensor]]], - Iterable[Iterable[Optional[torch.Tensor]]], - ]: - ctx = basic_op_ctxs[0] - input_, extra_input = ctx.saved_tensors - grad_input = None - if ctx.input_requires_grad: - grad_input = _reduce_broadcast_dims( - grad_output * extra_input, - ctx.input_shape, - ) - grad_extra_input = None - if ctx.extra_input_requires_grad: - grad_extra_input = _reduce_broadcast_dims( - grad_output * input_, - ctx.extra_input_shape, - ) - return grad_input, [()], [(grad_extra_input,)] From 36e6918b929d678f37bc8d4ba1d9324ca00516ae Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 25 Jan 2026 00:33:42 +0000 Subject: [PATCH 34/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_fusible_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 08fddc0b65..b1dd7b69eb 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -3199,7 +3199,6 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols) torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols) - @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("quantization", _quantization_list) From ba28c6f89355ac269e1b2425efb43c77fa3515e7 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sun, 25 Jan 2026 00:50:55 +0000 Subject: [PATCH 35/45] Fix linter warnings Signed-off-by: Tim Moon --- transformer_engine/pytorch/ops/fused/__init__.py | 7 +++++-- .../pytorch/ops/fused/forward_grouped_mlp.py | 8 +++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 190234bac1..52cda9caf6 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -29,5 +29,8 @@ register_backward_fusion(BackwardActivationBias.fuse_backward_ops) register_backward_fusion(BackwardAddRMSNorm.fuse_backward_ops) -# Import specialized fusions -from .forward_grouped_mlp import ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 +# Import experimental fusions +# Note: Registration logic is non-trivial, so submodule handles it internally. +from .forward_grouped_mlp import ( # pylint: disable=wrong-import-position + ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, +) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index fec9d6789e..617da5bb52 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -14,7 +14,7 @@ import transformer_engine_torch as tex from ...cpp_extensions import general_grouped_gemm -from ...quantization import FP8GlobalStateManager +from ...quantization import Recipe from ...tensor import MXFP8Tensor, Quantizer from ...utils import get_device_compute_capability from ..basic import GroupedLinear, ScaledSwiGLU @@ -34,7 +34,7 @@ class ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8(FusedOperation): @functools.lru_cache(maxsize=None) def grouped_gemm_swiglu_kernel(cls) -> Callable: """Fused kernel for grouped GEMM, SwiGLU, and post-multiplication.""" - from cudnn import grouped_gemm_swiglu_wrapper_sm100 + from cudnn import grouped_gemm_swiglu_wrapper_sm100 # pylint: disable=no-name-in-module return grouped_gemm_swiglu_wrapper_sm100 @@ -102,7 +102,7 @@ def fuser_forward( ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: # Get basic operations - fc1_op, swiglu_op, fc2_op = self.basic_ops + fc1_op, _, fc2_op = self.basic_ops fc1_ctx, swiglu_ctx, fc2_ctx = basic_op_ctxs # Tensor properties @@ -173,13 +173,11 @@ def fuser_forward( fc2_ws = [] for w, quantizer in zip(fc1_weights, fc1_weight_quantizers): if not is_quantized_tensor(w): - quantizer = weight_quantizers[group_idx] quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) w = quantizer(w) fc1_ws.append(w) for w, quantizer in zip(fc2_weights, fc2_weight_quantizers): if not is_quantized_tensor(w): - quantizer = weight_quantizers[group_idx] quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) w = quantizer(w) fc2_ws.append(w) From 
575da6e9928c1f8af2ce40d9b35366f9d09c8ef9 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 26 Jan 2026 21:10:27 +0000 Subject: [PATCH 36/45] Review suggestions from @greptile-apps Signed-off-by: Tim Moon --- transformer_engine/pytorch/ops/basic/grouped_linear.py | 5 ++--- transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index ed8d5c6012..4c2ad3a00f 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -40,8 +40,8 @@ class GroupedLinear(BasicOperation): dimension, applying a separate ``torch.nn.Linear`` to each split, and concatenating along the first dimension. - Paramters - --------- + Parameters + ---------- group_size : int Number of linear transformations. in_features : int @@ -416,7 +416,6 @@ def fuser_forward( if not with_quantized_compute: w = maybe_dequantize(w, dtype) elif with_quantized_compute and not is_quantized_tensor(w): - quantizer = weight_quantizers[group_idx] quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) w = quantizer(w) ws.append(w) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 617da5bb52..279ab1fcc0 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -65,12 +65,12 @@ def __init__( if not self.is_supported(): self.grouped_gemm_swiglu_kernel() # Try triggering import error raise RuntimeError(f"{self.__class__.__name__} is not supported on this system.") - if fc1.in_features % 256 != 0 or fc1.in_features % 256 != 0: + if fc1.in_features % 256 != 0 or fc1.out_features % 256 != 0: raise ValueError( f"Unsupported dims for FC1 (group_size={fc1.group_size}, " f"in_features={fc1.in_features}, out_features={fc1.out_features})." ) - if fc2.in_features % 256 != 0 or fc2.in_features % 256 != 0: + if fc2.in_features % 256 != 0 or fc2.out_features % 256 != 0: raise ValueError( f"Unsupported dims for FC2 (group_size={fc2.group_size}, " f"in_features={fc2.in_features}, out_features={fc2.out_features})." 
From 46294be478f6551e2cf251283adc7529ddb2964e Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 26 Jan 2026 14:34:44 -0800 Subject: [PATCH 37/45] Apply suggestion from @greptile-apps[bot] Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/ops/basic/grouped_linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 4c2ad3a00f..0a16ffe57d 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -317,11 +317,11 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) if weight_quantizer is None: pass - elif is_quantized_tensor(getattr(self, "weight", None)): + elif is_quantized_tensor(getattr(self, f"weight{group_idx}", None)): # Make sure weight param has correct quantizer weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) weight_quantizer.internal = False - self.weight.update_quantizer(weight_quantizer.copy()) + getattr(self, f"weight{group_idx}").update_quantizer(weight_quantizer.copy()) else: # Use internal tensors if quantized weights will not be # exposed externally From fccb0bb726648eb31d94a40e2cf6d24476c33ef9 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 29 Jan 2026 02:50:12 +0000 Subject: [PATCH 38/45] Tweak variable names Signed-off-by: Tim Moon --- .../pytorch/ops/basic/grouped_linear.py | 76 +++++++++---------- .../pytorch/ops/fused/forward_grouped_mlp.py | 54 ++++++------- 2 files changed, 65 insertions(+), 65 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 0a16ffe57d..ac6b5665e3 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -42,7 +42,7 @@ class GroupedLinear(BasicOperation): Parameters ---------- - group_size : int + num_groups : int Number of linear transformations. in_features : int Inner dimension of input tensor. 
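The GroupedLinear docstring above describes a split / per-group linear / concatenate computation; in plain PyTorch that is roughly the following. This is an illustrative sketch that ignores quantization and the fused grouped GEMM, and the helper name is assumed rather than Transformer Engine API.

    import torch
    import torch.nn.functional as F

    def grouped_linear_reference(x, split_sizes, weights, biases=None):
        # split_sizes: list of ints summing to x.shape[0]; one (out, in) weight per group
        xs = torch.split(x, split_sizes)
        if biases is None:
            biases = [None] * len(weights)
        return torch.cat([F.linear(xg, w, b) for xg, w, b in zip(xs, weights, biases)])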
@@ -74,7 +74,7 @@ class GroupedLinear(BasicOperation): def __init__( self, - group_size: int, + num_groups: int, in_features: int, out_features: int, *, @@ -87,11 +87,11 @@ def __init__( super().__init__() # Weight tensor dimensions - self.group_size: int = group_size + self.num_groups: int = num_groups self.in_features: int = in_features self.out_features: int = out_features - if self.group_size <= 0: - raise ValueError(f"Invalid group size ({self.group_size})") + if self.num_groups <= 0: + raise ValueError(f"Invalid number of groups ({self.num_groups})") if self.in_features <= 0: raise ValueError(f"Invalid input size ({self.in_features})") if self.out_features <= 0: @@ -114,7 +114,7 @@ def __init__( # Register weights self.weight0: torch.nn.Parameter - for group_idx in range(self.group_size): + for group_idx in range(self.num_groups): weight_tensor = torch.empty( self.out_features, self.in_features, @@ -128,7 +128,7 @@ def __init__( # Register biases self.bias0: Optional[torch.nn.Parameter] - for group_idx in range(self.group_size): + for group_idx in range(self.num_groups): bias_tensor = None if bias: bias_tensor = torch.empty( @@ -148,9 +148,9 @@ def __init__( def num_quantizers(self, mode: str) -> int: if mode == "forward": - return 2 * self.group_size + return 2 * self.num_groups if mode == "backward": - return self.group_size + return self.num_groups return 0 @property @@ -167,7 +167,7 @@ def reset_parameters(self) -> None: device = canonicalize_device(None) # Initialize weights - for group_idx in range(self.group_size): + for group_idx in range(self.num_groups): weight = getattr(self, f"weight{group_idx}") # Allocate buffers if needed @@ -214,7 +214,7 @@ def reset_parameters(self) -> None: # Initialize biases if needed if self.bias0 is not None: with torch.no_grad(): - for group_idx in range(self.group_size): + for group_idx in range(self.num_groups): bias = getattr(self, f"bias{group_idx}") if not devices_match(bias.device, device): bias = torch.empty_like(bias, device=device) @@ -235,7 +235,7 @@ def pre_first_fuser_forward(self) -> None: device = self.weight0.device weight_requires_grad = self.weight0.requires_grad weight_tensor_type = type(self.weight0.data) - for group_idx in range(self.group_size): + for group_idx in range(self.num_groups): weight = getattr(self, f"weight{group_idx}") if weight.dtype != dtype: raise RuntimeError( @@ -259,7 +259,7 @@ def pre_first_fuser_forward(self) -> None: ) # Check that biases are consistent - for group_idx in range(self.group_size): + for group_idx in range(self.num_groups): bias = getattr(self, f"bias{group_idx}") if self.has_bias: if bias is None: @@ -291,7 +291,7 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: # Configure quantizer usages # Note: We cache the quantized input for backward pass, # but discard the quantized weights. 
- for group_idx in range(self.group_size): + for group_idx in range(self.num_groups): input_quantizer = self.get_quantizer("forward", 2 * group_idx) weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) grad_output_quantizer = self.get_quantizer("backward", group_idx) @@ -302,7 +302,7 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) - for group_idx in range(self.group_size): + for group_idx in range(self.num_groups): # Input/grad output quantizers use internal tensors input_quantizer = self.get_quantizer("forward", 2 * group_idx) grad_output_quantizer = self.get_quantizer("backward", group_idx) @@ -372,7 +372,7 @@ def fuser_forward( next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: - group_size = self.group_size + num_groups = self.num_groups has_bias = self.has_bias device = self.weight0.device @@ -382,12 +382,12 @@ def fuser_forward( weight_requires_grad = ctx.requires_grad and self.weight0.requires_grad # Quantizers - input_quantizers = [None] * group_size - weight_quantizers = [None] * group_size - grad_output_quantizers = [None] * group_size + input_quantizers = [None] * num_groups + weight_quantizers = [None] * num_groups + grad_output_quantizers = [None] * num_groups with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() if with_quantized_compute: - for group_idx in range(group_size): + for group_idx in range(num_groups): input_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx) weight_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx + 1) grad_output_quantizers[group_idx] = self.get_quantizer("backward", group_idx) @@ -401,14 +401,14 @@ def fuser_forward( # Extract split sizes from extra input split_sizes = basic_op_extra_inputs[0][0] split_sizes_int = [int(s) for s in split_sizes.tolist()] - if len(split_sizes_int) != group_size: - raise ValueError(f"Expected {group_size} splits, but got {len(split_sizes_int)}.") + if len(split_sizes_int) != num_groups: + raise ValueError(f"Expected {num_groups} splits, but got {len(split_sizes_int)}.") # Extract params - weights = [getattr(self, f"weight{idx}") for idx in range(group_size)] + weights = [getattr(self, f"weight{idx}") for idx in range(num_groups)] bs = None if has_bias: - bs = [maybe_dequantize(getattr(self, f"bias{idx}"), dtype) for idx in range(group_size)] + bs = [maybe_dequantize(getattr(self, f"bias{idx}"), dtype) for idx in range(num_groups)] # Convert weight dtype if needed ws = [] @@ -440,7 +440,7 @@ def fuser_forward( ws, xs, [out], - [None] * group_size, # quantization_params + [None] * num_groups, # quantization_params dtype, m_splits=split_sizes_int, bias=bs, @@ -451,7 +451,7 @@ def fuser_forward( # Prepare weight tensors for backward pass if not input_requires_grad: - ws = [None] * group_size + ws = [None] * num_groups elif with_quantized_compute: for w, weight_param in zip(ws, weights): if w is not weight_param: @@ -459,7 +459,7 @@ def fuser_forward( # Prepare input tensor for backward pass if not weight_requires_grad: - xs = [None] * group_size + xs = [None] * num_groups elif with_quantized_compute: for x in xs: x.update_usage(rowwise_usage=False, columnwise_usage=True) @@ -489,7 +489,7 @@ def fuser_backward( Iterable[Iterable[Optional[torch.Tensor]]], Iterable[Iterable[Optional[torch.Tensor]]], ]: - group_size = self.group_size + num_groups = 
self.num_groups has_bias = self.has_bias device = self.weight0.device @@ -497,14 +497,14 @@ def fuser_backward( ctx = basic_op_ctxs[0] saved_tensors = ctx.saved_tensors split_sizes, saved_tensors = saved_tensors[0], saved_tensors[1:] - xs, saved_tensors = saved_tensors[:group_size], saved_tensors[group_size:] - ws, saved_tensors = saved_tensors[:group_size], saved_tensors[group_size:] + xs, saved_tensors = saved_tensors[:num_groups], saved_tensors[num_groups:] + ws, saved_tensors = saved_tensors[:num_groups], saved_tensors[num_groups:] # Split grad output tensor and convert dtypes if needed split_sizes_int = [int(s) for s in split_sizes.tolist()] dy = maybe_dequantize(grad_output, ctx.dtype) dys = None - grad_biases = [None] * group_size + grad_biases = [None] * num_groups if ctx.with_quantized_compute: for quantizer in ctx.grad_output_quantizers: quantizer.set_usage( @@ -524,13 +524,13 @@ def fuser_backward( # Initialize grad weight grads accumulate_into_main_grad = self._accumulate_into_main_grad - grad_weights = [None] * group_size + grad_weights = [None] * num_groups if ctx.weight_requires_grad: if accumulate_into_main_grad: # Megatron-LM wgrad fusion # Note: Get grad tensors from params so we can # accumulate directly into it. - for group_idx in range(group_size): + for group_idx in range(num_groups): weight_param = getattr(self, f"weight{group_idx}") if hasattr(weight_param, "__fsdp_param__"): weight_param.main_grad = weight_param.get_main_grad() @@ -540,7 +540,7 @@ def fuser_backward( grad_weights[group_idx] = weight_param.main_grad else: weight_shape = ws[0].size() - for group_idx in range(group_size): + for group_idx in range(num_groups): grad_weights[group_idx] = torch.empty( weight_shape, dtype=ctx.dtype, @@ -563,7 +563,7 @@ def fuser_backward( ws, dys, [grad_input], - [None] * group_size, # quantization_params + [None] * num_groups, # quantization_params ctx.dtype, layout="NN", m_splits=split_sizes_int, @@ -577,7 +577,7 @@ def fuser_backward( xs, dys, grad_weights, - [None] * group_size, # quantization_params + [None] * num_groups, # quantization_params ctx.dtype, layout="NT", m_splits=split_sizes_int, @@ -591,8 +591,8 @@ def fuser_backward( # Megatron-LM wgrad fusion # Note: Return dummy tensor for grad weight if needed. if accumulate_into_main_grad: - grad_weights = [None] * group_size - for group_idx in range(group_size): + grad_weights = [None] * num_groups + for group_idx in range(num_groups): weight_param = getattr(self, f"weight{group_idx}") if hasattr(weight_param, "grad_added_to_main_grad"): weight_param.grad_added_to_main_grad = True diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 279ab1fcc0..19901cb4af 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -"""Fused operation for forward GEMM + scale + add.""" +"""Fused operation for MoE grouped MLP.""" from __future__ import annotations from collections.abc import Callable, Iterable @@ -67,19 +67,19 @@ def __init__( raise RuntimeError(f"{self.__class__.__name__} is not supported on this system.") if fc1.in_features % 256 != 0 or fc1.out_features % 256 != 0: raise ValueError( - f"Unsupported dims for FC1 (group_size={fc1.group_size}, " + f"Unsupported dims for FC1 (num_groups={fc1.num_groups}, " f"in_features={fc1.in_features}, out_features={fc1.out_features})." 
) if fc2.in_features % 256 != 0 or fc2.out_features % 256 != 0: raise ValueError( - f"Unsupported dims for FC2 (group_size={fc2.group_size}, " + f"Unsupported dims for FC2 (num_groups={fc2.num_groups}, " f"in_features={fc2.in_features}, out_features={fc2.out_features})." ) - if fc1.out_features != 2 * fc2.in_features or fc1.group_size != fc2.group_size: + if fc1.out_features != 2 * fc2.in_features or fc1.num_groups != fc2.num_groups: raise ValueError( - f"FC1 (group_size={fc1.group_size}, in_features={fc1.in_features}, " + f"FC1 (num_groups={fc1.num_groups}, in_features={fc1.in_features}, " f"out_features={fc1.out_features}) " - f"and FC2 (group_size={fc2.group_size}, in_features={fc2.in_features}, " + f"and FC2 (num_groups={fc2.num_groups}, in_features={fc2.in_features}, " f"out_features={fc2.out_features}) do not match." ) if fc1.has_bias or fc2.has_bias: @@ -110,7 +110,7 @@ def fuser_forward( assert len(in_shape) == 2, f"Expected 2D input tensor, got shape={in_shape}." fc1_weight_shape = (fc1_op.out_features, fc1_op.in_features) fc2_weight_shape = (fc2_op.out_features, fc2_op.in_features) - group_size = fc1_op.group_size + num_groups = fc1_op.num_groups device = fc1_op.weight0.device if torch.is_autocast_enabled(): dtype = torch.get_autocast_dtype("cuda") @@ -125,13 +125,13 @@ def fuser_forward( ) # Quantizers - fc1_input_quantizers = [None] * group_size - fc1_weight_quantizers = [None] * group_size - fc1_grad_output_quantizers = [None] * group_size - fc2_input_quantizers = [None] * group_size - fc2_weight_quantizers = [None] * group_size - fc2_grad_output_quantizers = [None] * group_size - for idx in range(group_size): + fc1_input_quantizers = [None] * num_groups + fc1_weight_quantizers = [None] * num_groups + fc1_grad_output_quantizers = [None] * num_groups + fc2_input_quantizers = [None] * num_groups + fc2_weight_quantizers = [None] * num_groups + fc2_grad_output_quantizers = [None] * num_groups + for idx in range(num_groups): fc1_input_quantizers[idx] = fc1_op.get_quantizer("forward", 2 * idx) fc1_weight_quantizers[idx] = fc1_op.get_quantizer("forward", 2 * idx + 1) fc1_grad_output_quantizers[idx] = fc1_op.get_quantizer("backward", idx) @@ -151,8 +151,8 @@ def fuser_forward( ) split_sizes = fc1_split_sizes split_sizes_cpu = [int(s) for s in split_sizes.tolist()] - if len(split_sizes_cpu) != group_size: - raise ValueError(f"Expected {group_size} splits, but got {len(split_sizes_cpu)}.") + if len(split_sizes_cpu) != num_groups: + raise ValueError(f"Expected {num_groups} splits, but got {len(split_sizes_cpu)}.") split_sizes = split_sizes.to(dtype=torch.int, device=device) split_points = torch.zeros( split_sizes.numel() + 1, @@ -165,8 +165,8 @@ def fuser_forward( scales = basic_op_extra_inputs[1][0] # Extract params - fc1_weights = [getattr(fc1_op, f"weight{idx}") for idx in range(group_size)] - fc2_weights = [getattr(fc2_op, f"weight{idx}") for idx in range(group_size)] + fc1_weights = [getattr(fc1_op, f"weight{idx}") for idx in range(num_groups)] + fc2_weights = [getattr(fc2_op, f"weight{idx}") for idx in range(num_groups)] # Convert weight dtype if needed fc1_ws = [] @@ -210,19 +210,19 @@ def fuser_forward( fc1_w_data = torch.stack([w._rowwise_data for w in fc1_weights]) fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) fc1_w_data = fc1_w_data.view( - group_size, fc1_weight_shape[0] // 64, 2, 32, fc1_weight_shape[1] + num_groups, fc1_weight_shape[0] // 64, 2, 32, fc1_weight_shape[1] ) fc1_w_data = fc1_w_data.flip(2).contiguous() # Swap SwiGLU gate/activation - fc1_w_data = 
fc1_w_data.view(group_size, fc1_weight_shape[0], fc1_weight_shape[1]) + fc1_w_data = fc1_w_data.view(num_groups, fc1_weight_shape[0], fc1_weight_shape[1]) fc1_w_data = fc1_w_data.permute(1, 2, 0) fc1_w_scales = torch.stack([w._rowwise_scale_inv for w in fc1_weights]) fc1_w_scales = fc1_w_scales.view(dtype=torch.float8_e8m0fnu) fc1_w_scales = fc1_w_scales.view( - group_size, fc1_weight_shape[0] // 64, 2, 32, fc1_weight_shape[1] // 32 + num_groups, fc1_weight_shape[0] // 64, 2, 32, fc1_weight_shape[1] // 32 ) fc1_w_scales = fc1_w_scales.flip(2).contiguous() # Swap SwiGLU gate/activation fc1_w_scales = fc1_w_scales.view( - group_size, fc1_weight_shape[0] // 128, 4, 32, fc1_weight_shape[1] // 128, 4 + num_groups, fc1_weight_shape[0] // 128, 4, 32, fc1_weight_shape[1] // 128, 4 ) fc1_w_scales = fc1_w_scales.permute( 0, 1, 4, 3, 2, 5 @@ -259,7 +259,7 @@ def fuser_forward( fc1_w_scales, tile_idx_to_expert_idx, num_non_exiting_tiles, - torch.ones(group_size, dtype=dtype, device=device), # alpha_tensor + torch.ones(num_groups, dtype=dtype, device=device), # alpha_tensor torch.ones(1, dtype=dtype, device=device), # norm_const_tensor scales.detach().reshape(-1, 1, 1), split_points, @@ -297,7 +297,7 @@ def fuser_forward( # Construct MXFP8 tensors for FC2 fc2_xs = [] - for group_idx in range(group_size): + for group_idx in range(num_groups): x = MXFP8Tensor( shape=(split_sizes_cpu[group_idx], fc2_weight_shape[1]), dtype=dtype, @@ -319,10 +319,10 @@ def fuser_forward( fc2_ws, fc2_xs, [fc2_out], - [None] * group_size, # quantization_params + [None] * num_groups, # quantization_params dtype, m_splits=split_sizes_cpu, - bias=[None] * group_size, + bias=[None] * num_groups, use_bias=False, single_output=True, ) @@ -411,7 +411,7 @@ def fuse_forward_ops( matches_pattern = False elif window[0].has_bias or window[2].has_bias: matches_pattern = False - elif window[0].group_size != window[2].group_size: + elif window[0].num_groups != window[2].num_groups: matches_pattern = False elif ( window[0].in_features % 256 != 0 From 4259e275710a371c6137ab0f923d2e93dbd46de5 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Thu, 29 Jan 2026 19:08:25 -0800 Subject: [PATCH 39/45] Fix f-strings Review suggestions from @greptile-apps Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/ops/basic/grouped_linear.py | 4 ++-- transformer_engine/pytorch/ops/basic/swiglu.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index ac6b5665e3..3a8c21c625 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -348,7 +348,7 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: def op_forward(self, *args, **kwargs): raise RuntimeError( - "{self.__class__.__name__} operation has " + f"{self.__class__.__name__} operation has " f"{self.num_extra_inputs} extra tensor inputs " f"and {self.num_extra_outputs} extra tensor outputs. " "It overrides `fuser_forward` instead of `op_forward`." 
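
The grouped GEMM calls above are driven by per-group row counts: the split-size tensor becomes a plain Python list for m_splits, and the fused MoE path additionally builds a split-point vector with a leading zero. A short sketch of that bookkeeping, assuming a cumulative sum for the split points (the patch's exact construction is not reproduced here):

    import torch

    split_sizes = torch.tensor([96, 32, 64, 64], dtype=torch.int32)

    # Per-group row counts as plain ints, as passed through m_splits.
    split_sizes_cpu = [int(s) for s in split_sizes.tolist()]

    # Cumulative split points with a leading zero: [0, 96, 128, 192, 256].
    split_points = torch.zeros(split_sizes.numel() + 1, dtype=torch.int32)
    split_points[1:] = torch.cumsum(split_sizes, dim=0)

    assert split_points[-1].item() == sum(split_sizes_cpu)
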
@@ -356,7 +356,7 @@ def op_forward(self, *args, **kwargs): def op_backward(self, *args, **kwargs): raise RuntimeError( - "{self.__class__.__name__} operation has " + f"{self.__class__.__name__} operation has " f"{self.num_extra_inputs} extra tensor inputs " f"and {self.num_extra_outputs} extra tensor outputs. " "It overrides `fuser_backward` instead of `op_backward`." diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py index 8f1068ba46..f8917db033 100644 --- a/transformer_engine/pytorch/ops/basic/swiglu.py +++ b/transformer_engine/pytorch/ops/basic/swiglu.py @@ -285,7 +285,7 @@ def __init__(self, glu_interleave_size: Optional[int] = None): def op_forward(self, *args, **kwargs) -> None: raise RuntimeError( - "{self.__class__.__name__} operation has " + f"{self.__class__.__name__} operation has " f"{self.num_extra_inputs} extra tensor inputs " f"and {self.num_extra_outputs} extra tensor outputs. " "It overrides `fuser_forward` instead of `op_forward`." @@ -293,7 +293,7 @@ def op_forward(self, *args, **kwargs) -> None: def op_backward(self, *args, **kwargs) -> None: raise RuntimeError( - "{self.__class__.__name__} operation has " + f"{self.__class__.__name__} operation has " f"{self.num_extra_inputs} extra tensor inputs " f"and {self.num_extra_outputs} extra tensor outputs. " "It overrides `fuser_backward` instead of `op_backward`." From 2442d34e3b057b527ea5ed4a3907f5a785c6a405 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 30 Jan 2026 04:15:24 +0000 Subject: [PATCH 40/45] Fix bug when grouped MLP is not being trained Signed-off-by: Tim Moon --- .../pytorch/ops/fused/forward_grouped_mlp.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 19901cb4af..7d27074edc 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -328,8 +328,12 @@ def fuser_forward( ) # Prepare input tensors for backward pass - for x in itertools.chain(fc1_xs, fc2_xs): - x.update_usage(rowwise_usage=False, columnwise_usage=True) + if not weight_requires_grad: + fc1_xs = [None] * num_groups + fc2_xs = [None] * num_groups + else: + for x in itertools.chain(fc1_xs, fc2_xs): + x.update_usage(rowwise_usage=False, columnwise_usage=True) # Save state for backward pass if requires_grad: From a7351e5b1b7841d7b315ca989ca359b57b1c9d77 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 30 Jan 2026 17:41:47 -0800 Subject: [PATCH 41/45] Fix f-string Review suggestion from @greptile-apps Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 7d27074edc..45d58c1d7b 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -87,7 +87,7 @@ def __init__( if swiglu.glu_interleave_size != 32: raise ValueError( "Fused kernel requires 32-wide GLU interleaving, " - "but got glu_interleave_size={swiglu.glu_interleave_size}." 
+ f"but got glu_interleave_size={swiglu.glu_interleave_size}." ) def fuser_forward( From 9be1c49ca0770b45d3a81764fbb8308fe70ef716 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 31 Jan 2026 02:55:15 +0000 Subject: [PATCH 42/45] Replace explicit concat with optional concat Signed-off-by: Tim Moon --- transformer_engine/pytorch/module/_common.py | 19 ++++++++++++------- .../pytorch/ops/fused/forward_grouped_mlp.py | 9 +++++---- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 88b58a353a..e39a53bd38 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -125,14 +125,19 @@ def forward( return torch.cat(tensors, dim=dim) data_ptr += tensor.size(dim) * data_ptr_stride + # Out-of-place concatenation when view tensors have different storage + # Note: This works around an edge case with the split_quantize + # function, which might allocate a buffer and construct + # subviews. However, in order to reduce CPU overheads, these + # views are configured manually outside of PyTorch. PyTorch + # doesn't know these views share the same memory, and it + # blocks us from reconstructing the full tensor because it + # thinks we are accessing out-of-bounds memory. + if tensors[0].untyped_storage().nbytes() < out_shape[dim] * data_ptr_stride: + return torch.cat(tensors, dim=dim) + # No-op concatenation - out = tensors[0].new() - out.set_( - tensors[0].untyped_storage(), - tensors[0].storage_offset(), - out_shape, - strides, - ) + out = tensors[0].as_strided(out_shape, strides) out.requires_grad = any(tensor.requires_grad for tensor in tensors) return out diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 45d58c1d7b..c544ac6420 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -14,6 +14,7 @@ import transformer_engine_torch as tex from ...cpp_extensions import general_grouped_gemm +from ...module._common import noop_cat from ...quantization import Recipe from ...tensor import MXFP8Tensor, Quantizer from ...utils import get_device_compute_capability @@ -191,10 +192,10 @@ def fuser_forward( fc1_xs = tex.split_quantize(fc1_x, split_sizes_cpu, fc1_input_quantizers) # Pack data tensors - fc1_x_data = torch.cat([x._rowwise_data for x in fc1_xs]) + fc1_x_data = noop_cat([x._rowwise_data for x in fc1_xs]) fc1_x_data = fc1_x_data.view(dtype=torch.float8_e4m3fn) fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0) - fc1_x_scales = torch.cat([x._rowwise_scale_inv for x in fc1_xs]) + fc1_x_scales = noop_cat([x._rowwise_scale_inv for x in fc1_xs]) fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e8m0fnu) fc1_x_scales = fc1_x_scales.view( 1, @@ -207,7 +208,7 @@ def fuser_forward( fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0) # Pack weight tensors - fc1_w_data = torch.stack([w._rowwise_data for w in fc1_weights]) + fc1_w_data = noop_cat([w._rowwise_data for w in fc1_weights]) fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) fc1_w_data = fc1_w_data.view( num_groups, fc1_weight_shape[0] // 64, 2, 32, fc1_weight_shape[1] @@ -215,7 +216,7 @@ def fuser_forward( fc1_w_data = fc1_w_data.flip(2).contiguous() # Swap SwiGLU gate/activation fc1_w_data = fc1_w_data.view(num_groups, fc1_weight_shape[0], fc1_weight_shape[1]) fc1_w_data = fc1_w_data.permute(1, 2, 0) - 
fc1_w_scales = torch.stack([w._rowwise_scale_inv for w in fc1_weights]) + fc1_w_scales = noop_cat([w._rowwise_scale_inv for w in fc1_weights]) fc1_w_scales = fc1_w_scales.view(dtype=torch.float8_e8m0fnu) fc1_w_scales = fc1_w_scales.view( num_groups, fc1_weight_shape[0] // 64, 2, 32, fc1_weight_shape[1] // 32 From 2c6d6be675e4cace07fc387681d8d836f89ebfb8 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 3 Feb 2026 00:30:05 +0000 Subject: [PATCH 43/45] Add comments for CuTe DSL expected tensor shapes Signed-off-by: Tim Moon --- .../pytorch/ops/fused/forward_grouped_mlp.py | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index c544ac6420..06fe3fca93 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -192,6 +192,14 @@ def fuser_forward( fc1_xs = tex.split_quantize(fc1_x, split_sizes_cpu, fc1_input_quantizers) # Pack data tensors + # Note: Fused kernel expects tensor with non-contiguous + # logical dims. + # Data actual shape: (1, sum(m), k) + # Scale actual shape: (1, sum(m)/128, k/128, 32 (block row), + # 4 (block row), 4 (block col)) + # Data logical shape: (sum(m), k, 1) + # Scale logical shape: (32 (block row), 4 (block row), + # sum(m)/128, 4 (block col), k/128, 1) fc1_x_data = noop_cat([x._rowwise_data for x in fc1_xs]) fc1_x_data = fc1_x_data.view(dtype=torch.float8_e4m3fn) fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0) @@ -208,6 +216,14 @@ def fuser_forward( fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0) # Pack weight tensors + # Note: Fused kernel expects tensor with non-contiguous + # logical dims. + # Data actual shape: (num_groups, n, k) + # Scale actual shape: (num_groups, n/128, k/128, 32 (block row), + # 4 (block row), 4 (block col)) + # Data logical shape: (n, k, num_groups) + # Scale logical shape: (32 (block row), 4 (block row), n/128, + # 4 (block col), k/128, num_groups) fc1_w_data = noop_cat([w._rowwise_data for w in fc1_weights]) fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) fc1_w_data = fc1_w_data.view( @@ -224,7 +240,7 @@ def fuser_forward( fc1_w_scales = fc1_w_scales.flip(2).contiguous() # Swap SwiGLU gate/activation fc1_w_scales = fc1_w_scales.view( num_groups, fc1_weight_shape[0] // 128, 4, 32, fc1_weight_shape[1] // 128, 4 - ) + ) # Unswizzled layout fc1_w_scales = fc1_w_scales.permute( 0, 1, 4, 3, 2, 5 ).contiguous() # Convert to swizzled layout @@ -274,6 +290,14 @@ def fuser_forward( ) # Unpack kernel outputs + # Note: Fused kernel outputs tensors with non-contiguous + # logical dims. 
+ # Row-wise data logical shape: (sum(m), k, 1) + # Row-wise scale logical shape: (32 (block row), 4 (block row), + # sum(m)/128, 4 (block col), k/128, 1) + # Column-wise data logical shape: (sum(m), k, 1) + # Column-wise scale logical shape: (32 (block col), 4 (block col), + # k/128, 4 (block row), sum(m)/128, 1) swiglu_in = fc1_kernel_out["c_tensor"] swiglu_in = swiglu_in.permute(2, 0, 1) swiglu_in = swiglu_in.view(in_shape[0], fc1_weight_shape[0] // 64, 2, 32) From 65cb77d09615581327efb1da7888ac12c64e27b4 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 4 Feb 2026 03:38:34 +0000 Subject: [PATCH 44/45] Support contiguous weights in grouped linear op Signed-off-by: Tim Moon --- .../pytorch/ops/basic/grouped_linear.py | 172 ++++++++++++++---- .../pytorch/ops/fused/forward_grouped_mlp.py | 28 ++- 2 files changed, 147 insertions(+), 53 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 3a8c21c625..2fda7d9abc 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -5,7 +5,7 @@ """Fusible operation for bias.""" from __future__ import annotations -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Sequence import contextlib import math from typing import Any, Optional @@ -22,12 +22,13 @@ get_dummy_wgrad, ) from ...quantization import FP8GlobalStateManager, Recipe -from ...tensor import Quantizer +from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer from ...utils import ( canonicalize_device, canonicalize_dtype, clear_tensor_data, devices_match, + round_up_to_nearest_multiple, ) from .._common import is_quantized_tensor, maybe_dequantize from ..op import BasicOperation, OperationContext @@ -118,7 +119,7 @@ def __init__( weight_tensor = torch.empty( self.out_features, self.in_features, - device=device, + device="meta", dtype=dtype, ) self.register_parameter( @@ -133,7 +134,7 @@ def __init__( if bias: bias_tensor = torch.empty( self.out_features, - device=device, + device="meta", dtype=dtype, ) bias_tensor = torch.nn.Parameter(bias_tensor) @@ -166,30 +167,35 @@ def reset_parameters(self) -> None: if device.type == "meta": device = canonicalize_device(None) - # Initialize weights - for group_idx in range(self.num_groups): - weight = getattr(self, f"weight{group_idx}") - - # Allocate buffers if needed - if is_quantized_tensor(weight): - weight = torch.empty( - weight.size(), - dtype=weight.dtype, - device=device, - ) - elif not devices_match(weight.device, device): - weight = torch.empty_like(weight, device=device) - - # Initialize values + # Initialize weight values + # Note: Allocate a single buffer in order to support grouped + # GEMM kernels that expect a single weight buffer. 
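
Allocating every group's weight as a subview of one packed buffer is what lets the noop_cat-style reconstruction earlier in the series skip the copy: adjacent views of a single allocation can be reassembled as a re-strided view. A stand-alone illustration of the idea (this is not the patch's noop_cat implementation):

    import torch

    packed = torch.randn(4, 8, 16)               # one allocation for all groups
    weights = [packed[idx] for idx in range(4)]  # per-group subviews, shared storage

    # Because the subviews tile the buffer contiguously, "concatenating" them is
    # just a view with the stacked shape and the buffer's strides -- no copy.
    restacked = weights[0].as_strided((4, 8, 16), packed.stride())
    assert restacked.data_ptr() == packed.data_ptr()
    assert torch.equal(restacked, torch.stack(weights))
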
+ packed_weights = torch.empty( + self.num_groups, + self.out_features, + self.in_features, + dtype=self.weight0.dtype, + device=device, + ) + weights = [packed_weights[idx] for idx in range(self.num_groups)] + for weight in weights: init_context = contextlib.nullcontext() if self._rng_state_tracker_function is not None: init_context = self._rng_state_tracker_function().fork() with init_context: torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) - # Quantize weight if needed - if self._with_quantized_weight: - quantizer = self.get_quantizer("forward", 2 * group_idx + 1) + # Quantize weights if needed + if self._with_quantized_weight: + + # Configure quantizers + quantizers = [ + self.get_quantizer("forward", 2 * idx + 1) + for idx in range(self.num_groups) + ] + with_rowwise_usage = True + with_columnwise_usage = torch.is_grad_enabled() + for quantizer in quantizers: if quantizer is None: raise RuntimeError( "Tried to quantize weight with deferred initialization " @@ -199,29 +205,123 @@ def reset_parameters(self) -> None: "performed within autocast." ) quantizer.set_usage( - rowwise=True, - columnwise=torch.is_grad_enabled(), + rowwise=with_rowwise_usage, + columnwise=with_columnwise_usage, ) quantizer.internal = False - with torch.no_grad(): - weight = quantizer(weight) - # Save updated parameters + # Quantize weights + weights = self._quantize_weights(weights, quantizers) + + # Register weights + for group_idx, weight in enumerate(weights): if not isinstance(weight, torch.nn.Parameter): weight = torch.nn.Parameter(weight) setattr(self, f"weight{group_idx}", weight) # Initialize biases if needed if self.bias0 is not None: - with torch.no_grad(): - for group_idx in range(self.num_groups): - bias = getattr(self, f"bias{group_idx}") - if not devices_match(bias.device, device): - bias = torch.empty_like(bias, device=device) - bias.zero_() - if not isinstance(bias, torch.nn.Parameter): - bias = torch.nn.Parameter(bias) - setattr(self, f"bias{group_idx}", bias) + packed_biases = torch.zeros( + self.num_groups, + self.out_features, + dtype=self.bias0.dtype, + device=device, + ) + for group_idx in range(self.num_groups): + bias = torch.nn.Parameter(packed_biases[group_idx]) + setattr(self, f"bias{group_idx}", bias) + + def _quantize_weights( + self, + weights: Sequence[torch.Tensor], + quantizers: Sequence[Quantizer], + ) -> Sequence[torch.Tensor]: + """Construct quantized weight tensors.""" + + # Manually construct MXFP8 weights + if isinstance(quantizers[0], MXFP8Quantizer): + return self._quantize_weights_mxfp8(weights, quantizers) + + # Use quantizers to construct quantized weights + with torch.no_grad(): + return [ + quantizer(weight) + for quantizer, weight in zip(quantizers, weights) + ] + + def _quantize_weights_mxfp8( + self, + weights: Sequence[torch.Tensor], + quantizers: Sequence[Quantizer], + ) -> Sequence[MXFP8Tensor]: + """Construct MXFP8 weight tensors. + + Instead of allocating separate buffers for each weight tensor, + this function constructs large buffers and assigns subviews to + each tensor. This is intended to support grouped GEMM kernels + that expect packed buffers. 
+ + """ + + # Tensor dimensions + num_groups = len(weights) + out_features, in_features = weights[0].size() + packed_shape = (num_groups, out_features, in_features) + unpacked_shape = (out_features, in_features) + + # Tensor attributes + device = weights[0].device + dtype = weights[0].dtype + requires_grad = torch.is_grad_enabled() + with_rowwise_usage = quantizers[0].rowwise_usage + with_columnwise_usage = quantizers[0].columnwise_usage + + # Construct packed buffers + rowwise_data = [None] * num_groups + rowwise_scales = [None] * num_groups + columnwise_data = [None] * num_groups + columnwise_scales = [None] * num_groups + if with_rowwise_usage: + scale_shape = ( + num_groups, + round_up_to_nearest_multiple(out_features, 128), + round_up_to_nearest_multiple(in_features // 32, 4), + ) + packed_data = torch.empty(packed_shape, dtype=torch.uint8, device=device) + packed_scales = torch.empty(scale_shape, dtype=torch.uint8, device=device) + rowwise_data = [packed_data[idx] for idx in range(num_groups)] + rowwise_scales = [packed_scales[idx] for idx in range(num_groups)] + if with_columnwise_usage: + scale_shape = ( + num_groups, + round_up_to_nearest_multiple(out_features // 32, 4), + round_up_to_nearest_multiple(in_features, 128), + ) + packed_data = torch.empty(packed_shape, dtype=torch.uint8, device=device) + packed_scales = torch.empty(scale_shape, dtype=torch.uint8, device=device) + columnwise_data = [packed_data[idx] for idx in range(num_groups)] + columnwise_scales = [packed_scales[idx] for idx in range(num_groups)] + + # Construct MXFP8 tensors and cast to MXFP8 + out = [] + with torch.no_grad(): + for group_idx in range(num_groups): + weight = MXFP8Tensor( + shape=unpacked_shape, + dtype=dtype, + fp8_dtype=dtype, + rowwise_data=rowwise_data[group_idx], + rowwise_scale_inv=rowwise_scales[group_idx], + columnwise_data=columnwise_data[group_idx], + columnwise_scale_inv=columnwise_scales[group_idx], + quantizer=quantizers[group_idx], + requires_grad=requires_grad, + with_gemm_swizzled_scales=False, + ) + weight.copy_(weights[group_idx]) + out.append(weight) + + return out def pre_first_fuser_forward(self) -> None: super().pre_first_fuser_forward() diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 06fe3fca93..ab07df47a9 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -165,23 +165,17 @@ def fuser_forward( # Extract post-scales from extra input scales = basic_op_extra_inputs[1][0] - # Extract params - fc1_weights = [getattr(fc1_op, f"weight{idx}") for idx in range(num_groups)] - fc2_weights = [getattr(fc2_op, f"weight{idx}") for idx in range(num_groups)] - - # Convert weight dtype if needed - fc1_ws = [] - fc2_ws = [] - for w, quantizer in zip(fc1_weights, fc1_weight_quantizers): - if not is_quantized_tensor(w): + # Extract params and quantize to MXFP8 if needed + fc1_ws = [getattr(fc1_op, f"weight{idx}") for idx in range(num_groups)] + fc2_ws = [getattr(fc2_op, f"weight{idx}") for idx in range(num_groups)] + if not is_quantized_tensor(fc1_ws[0]): + for quantizer in fc1_weight_quantizers: quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) - w = quantizer(w) - fc1_ws.append(w) - for w, quantizer in zip(fc2_weights, fc2_weight_quantizers): - if not is_quantized_tensor(w): + fc1_ws = fc1_op._quantize_weights_mxfp8(fc1_ws, fc1_weight_quantizers) + if not is_quantized_tensor(fc2_ws[0]): + for 
quantizer in fc2_weight_quantizers: quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) - w = quantizer(w) - fc2_ws.append(w) + fc2_ws = fc2_op._quantize_weights_mxfp8(fc2_ws, fc2_weight_quantizers) # Split input tensor and convert dtypes if needed fc1_x = maybe_dequantize(input_, dtype) @@ -224,7 +218,7 @@ def fuser_forward( # Data logical shape: (n, k, num_groups) # Scale logical shape: (32 (block row), 4 (block row), n/128, # 4 (block col), k/128, num_groups) - fc1_w_data = noop_cat([w._rowwise_data for w in fc1_weights]) + fc1_w_data = noop_cat([w._rowwise_data for w in fc1_ws]) fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) fc1_w_data = fc1_w_data.view( num_groups, fc1_weight_shape[0] // 64, 2, 32, fc1_weight_shape[1] @@ -232,7 +226,7 @@ def fuser_forward( fc1_w_data = fc1_w_data.flip(2).contiguous() # Swap SwiGLU gate/activation fc1_w_data = fc1_w_data.view(num_groups, fc1_weight_shape[0], fc1_weight_shape[1]) fc1_w_data = fc1_w_data.permute(1, 2, 0) - fc1_w_scales = noop_cat([w._rowwise_scale_inv for w in fc1_weights]) + fc1_w_scales = noop_cat([w._rowwise_scale_inv for w in fc1_ws]) fc1_w_scales = fc1_w_scales.view(dtype=torch.float8_e8m0fnu) fc1_w_scales = fc1_w_scales.view( num_groups, fc1_weight_shape[0] // 64, 2, 32, fc1_weight_shape[1] // 32 From 2e577d16fc6119a682b1dd120abe9021c4062e26 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 03:39:49 +0000 Subject: [PATCH 45/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ops/basic/grouped_linear.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 2fda7d9abc..75edc19c08 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -190,8 +190,7 @@ def reset_parameters(self) -> None: # Configure quantizers quantizers = [ - self.get_quantizer("forward", 2 * idx + 1) - for idx in range(self.num_groups) + self.get_quantizer("forward", 2 * idx + 1) for idx in range(self.num_groups) ] with_rowwise_usage = True with_columnwise_usage = torch.is_grad_enabled() @@ -244,10 +243,7 @@ def _quantize_weights( # Use quantizers to construct quantized weights with torch.no_grad(): - return [ - quantizer(weight) - for quantizer, weight in zip(quantizers, weights) - ] + return [quantizer(weight) for quantizer, weight in zip(quantizers, weights)] def _quantize_weights_mxfp8( self,
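
For reference, the scale-inverse buffers sized in _quantize_weights_mxfp8 follow from one 8-bit scale per 32 elements along the quantized dimension, padded up to the tile multiples used above. A quick worked check with stand-in dimensions and a local round-up helper (mirroring, not importing, the patch's round_up_to_nearest_multiple):

    def round_up(value: int, multiple: int) -> int:
        """Local stand-in for the utils helper used in the patch."""
        return ((value + multiple - 1) // multiple) * multiple

    num_groups, out_features, in_features = 8, 512, 256

    # Row-wise quantization: scales along in_features, one per 32 elements.
    rowwise_scale_shape = (
        num_groups,
        round_up(out_features, 128),       # 512
        round_up(in_features // 32, 4),    # 8
    )
    # Column-wise quantization: scales along out_features instead.
    columnwise_scale_shape = (
        num_groups,
        round_up(out_features // 32, 4),   # 16
        round_up(in_features, 128),        # 256
    )
    print(rowwise_scale_shape)     # (8, 512, 8)
    print(columnwise_scale_shape)  # (8, 16, 256)
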