
Conversation

@timmoon10
Collaborator

@timmoon10 timmoon10 commented Jan 24, 2026

Description

This PR adds a grouped linear op, which can be used in the grouped MLP block in Mixture-of-Experts models. It also adds an experimental fused operation for a grouped MLP block, using a CuTe DSL kernel that computes an MXFP8 grouped GEMM and SwiGLU.
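For orientation, here is a rough sketch of how the new ops could compose into an MoE grouped MLP block with te_ops.Sequential. The module names come from this PR, but the constructor arguments and forward signature shown here are assumptions for illustration, not the finalized API:

import torch
import transformer_engine.pytorch.ops as te_ops

# Hypothetical composition; constructor arguments and the forward call are
# assumptions for illustration, not the API introduced by this PR.
num_groups, hidden_size, ffn_size = 8, 1024, 4096
grouped_mlp = te_ops.Sequential(
    te_ops.GroupedLinear(num_groups, hidden_size, 2 * ffn_size),  # FC1: gate + linear units
    te_ops.ScaledSwiGLU(),                                        # SwiGLU with post-scaling
    te_ops.GroupedLinear(num_groups, ffn_size, hidden_size),      # FC2
)

x = torch.randn(4096, hidden_size, device="cuda", dtype=torch.bfloat16)
split_sizes = [512] * num_groups            # tokens routed to each group/expert
probs = torch.rand(4096, 1, device="cuda")  # per-token scales for the post-scaled SwiGLU
# y = grouped_mlp(x, split_sizes, probs)    # argument plumbing is an assumption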

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add a grouped linear operation
  • Add a post-scaled SwiGLU op with support for interleaving SwiGLU gate and linear units (see the layout sketch after this list)
  • Add a fused operation for grouped MLP
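
Where the interleaving is relevant: instead of the FC1 output holding all gate units followed by all linear units, the two can alternate in fixed-width chunks (32-wide, per the review summary below). A minimal plain-PyTorch sketch of de-interleaving such an output and applying the post-scaled SwiGLU, assuming that chunked layout:

import torch
import torch.nn.functional as F

def deinterleave_swiglu(h, scale, chunk=32):
    """h: FC1 output whose last dim alternates 32 gate channels, 32 linear
    channels, and so on (the exact layout is an assumption for illustration).
    Returns the post-scaled SwiGLU output."""
    n = h.shape[-1]
    h = h.reshape(*h.shape[:-1], n // (2 * chunk), 2, chunk)
    gate, lin = h.unbind(dim=-2)                     # separate interleaved chunks
    gate = gate.reshape(*gate.shape[:-2], n // 2)
    lin = lin.reshape(*lin.shape[:-2], n // 2)
    return F.silu(gate) * lin * scale                # SwiGLU with post-scaling

x = torch.randn(16, 256)
probs = torch.rand(16, 1)
y = deinterleave_swiglu(x, probs)
assert y.shape == (16, 128)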

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

timmoon10 and others added 30 commits January 7, 2026 00:15
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Test is too permissive since the test should still be failing. The weights are not properly interleaved yet.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 marked this pull request as ready for review January 25, 2026 01:00
@timmoon10
Collaborator Author

/te-ci pytorch L1

@greptile-apps
Contributor

greptile-apps bot commented Jan 25, 2026

Greptile Overview

Greptile Summary

Adds a grouped linear operation and an experimental fused grouped MLP for Mixture-of-Experts models. The implementation includes a new GroupedLinear operation that splits the input tensor and applies a separate linear transformation to each group, a ScaledSwiGLU activation with post-scaling support, and an experimental fused operation ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 that uses a CuTe DSL kernel from cuDNN to compute FC1 + SwiGLU + FC2 in fewer kernel launches on SM100+ GPUs with MXFP8 quantization.

Key changes:

  • GroupedLinear supports MXFP8 quantization with packed weight buffers for efficient grouped GEMMs
  • ScaledSwiGLU enables post-scaling with optional 32-wide gate/activation interleaving
  • Experimental fusion uses CuTe DSL kernel (SM100+ only) to compute grouped GEMM + SwiGLU + post-scale in a single kernel
  • Helper function noop_cat added for efficient tensor concatenation without copying when tensors are already contiguous in memory (see the sketch after this list)
  • Comprehensive test coverage added for both operations
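
The noop_cat bullet above may be easier to see with a sketch. The idea, as described here, is to detect when the inputs are already adjacent slices of one contiguous buffer and return a view instead of copying; this is an illustrative guess at the idea, not the helper's actual implementation, and it only handles concatenation along dim 0:

import torch

def noop_cat_sketch(tensors):
    """Concatenate along dim 0 without copying when the inputs are already
    back-to-back slices of one contiguous buffer (illustrative only)."""
    base = tensors[0]
    ptr = base.data_ptr()
    for t in tensors:
        if not t.is_contiguous() or t.data_ptr() != ptr:
            return torch.cat(tensors)          # not adjacent: fall back to a copy
        ptr += t.numel() * t.element_size()
    out_shape = (sum(t.shape[0] for t in tensors),) + tuple(base.shape[1:])
    # All chunks share one buffer and are adjacent: reinterpret without a copy.
    return base.as_strided(out_shape, base.stride(), base.storage_offset())

buf = torch.randn(6, 4)
parts = list(buf.chunk(3, dim=0))
merged = noop_cat_sketch(parts)
assert merged.data_ptr() == buf.data_ptr()     # no data movement happened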

Issues previously reported have been addressed:
All previously flagged issues (undefined variables, duplicate condition checks, typos, missing f-string prefixes, incorrect attribute access, gradient accumulation flag handling) appear to have been fixed in the current version.

Confidence Score: 4/5

  • Safe to merge with minor considerations - the experimental fusion is properly gated behind SM100+ checks and MXFP8 recipe detection
  • The implementation is well-structured with comprehensive tests, proper error handling, and hardware capability checks. All previously reported issues have been addressed. Score is 4 rather than 5 because the experimental CuTe DSL fusion is complex and hardware-specific (SM100+ only), requiring thorough hardware testing that cannot be verified from code review alone.
  • The experimental fusion in forward_grouped_mlp.py requires SM100+ hardware validation. All other files appear production-ready.

Important Files Changed

Filename | Overview
transformer_engine/pytorch/ops/basic/grouped_linear.py | Implements the grouped linear operation with MXFP8 quantization support, including parameter initialization and gradient accumulation
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | Implements the experimental fused operation for grouped MLP using a CuTe DSL kernel for MXFP8, SM100+ only
transformer_engine/pytorch/ops/basic/swiglu.py | Adds the ScaledSwiGLU operation with post-scaling support and optional gate/activation interleaving

Sequence Diagram

sequenceDiagram
    participant User
    participant Sequential as te_ops.Sequential
    participant GroupedLinear1 as GroupedLinear (FC1)
    participant ScaledSwiGLU
    participant GroupedLinear2 as GroupedLinear (FC2)
    participant FusedOp as ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8
    participant CuTeKernel as CuDNN CuTe DSL Kernel

    Note over User,CuTeKernel: Regular Path (No Fusion)
    User->>Sequential: forward(input, split_sizes, scales)
    Sequential->>GroupedLinear1: forward(input, split_sizes)
    GroupedLinear1->>GroupedLinear1: Split input by split_sizes
    GroupedLinear1->>GroupedLinear1: Quantize to MXFP8 if enabled
    GroupedLinear1->>GroupedLinear1: general_grouped_gemm(weights, inputs)
    GroupedLinear1-->>Sequential: fc1_output
    Sequential->>ScaledSwiGLU: forward(fc1_output, scales)
    ScaledSwiGLU->>ScaledSwiGLU: Remove gate interleaving if needed
    ScaledSwiGLU->>ScaledSwiGLU: Compute SwiGLU activation
    ScaledSwiGLU->>ScaledSwiGLU: Apply post-scaling (output * scales)
    ScaledSwiGLU-->>Sequential: swiglu_output
    Sequential->>GroupedLinear2: forward(swiglu_output, split_sizes)
    GroupedLinear2->>GroupedLinear2: Split input by split_sizes
    GroupedLinear2->>GroupedLinear2: Quantize to MXFP8 if enabled
    GroupedLinear2->>GroupedLinear2: general_grouped_gemm(weights, inputs)
    GroupedLinear2-->>Sequential: final_output
    Sequential-->>User: final_output

    Note over User,CuTeKernel: Fused Path (MXFP8 + SM100+)
    User->>Sequential: forward(input, split_sizes, scales)
    Sequential->>FusedOp: fuser_forward(input, split_sizes, scales)
    FusedOp->>FusedOp: Quantize FC1 inputs to MXFP8
    FusedOp->>FusedOp: Pack FC1 data/scales with gate swapping
    FusedOp->>CuTeKernel: grouped_gemm_swiglu_wrapper_sm100()
    Note right of CuTeKernel: Single kernel:<br/>FC1 GEMM + SwiGLU + post-scale
    CuTeKernel-->>FusedOp: FC2 inputs (MXFP8, row+col quantized)
    FusedOp->>FusedOp: Unpack FC2 inputs and undo gate swap
    FusedOp->>FusedOp: Construct MXFP8Tensor objects
    FusedOp->>FusedOp: general_grouped_gemm(FC2 weights, FC2 inputs)
    FusedOp-->>Sequential: final_output
    Sequential-->>User: final_output
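
For readers who prefer code to a sequence diagram, the regular (unfused) path above roughly corresponds to the following plain-PyTorch loop over groups. It is an illustration only; the [gate | linear] layout of the FC1 output and the shape of the scales are assumptions:

import torch
import torch.nn.functional as F

def grouped_mlp_reference(x, split_sizes, scales, fc1_weights, fc2_weights):
    """Unfused reference for the sequence diagram above (illustration only)."""
    outputs = []
    for xg, sg, w1, w2 in zip(x.split(split_sizes), scales.split(split_sizes),
                              fc1_weights, fc2_weights):
        h = xg @ w1.t()                       # FC1 GEMM for one group
        gate, lin = h.chunk(2, dim=-1)        # layout assumption: [gate | linear]
        act = F.silu(gate) * lin              # SwiGLU
        act = act * sg                        # post-scale by routing probabilities
        outputs.append(act @ w2.t())          # FC2 GEMM for one group
    return torch.cat(outputs)

# The fused path replaces the per-group FC1 GEMM + SwiGLU + post-scale with a
# single MXFP8 CuTe DSL kernel launch on SM100+ GPUs, as shown in the diagram.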

quantizer.optimize_for_gemm = True
fc1_xs = tex.split_quantize(fc1_x, split_sizes_cpu, fc1_input_quantizers)

# Pack data tensors
Member

Maybe a silly question: is this packing and unpacking code just for verification, or will it be in the final version?

Collaborator Author

I'm working on getting rid of the concatenations, but the permutes are no-ops. The kernel API expects tensors with non-contiguous dims: https://github.com/NVIDIA/cudnn-frontend/blob/main/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py#L240-L245
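
A quick way to see why the permutes cost nothing: torch.Tensor.permute only swaps strides and returns a view, which gives exactly the non-contiguous layout the kernel API linked above asks for:

import torch

x = torch.randn(8, 16)
y = x.permute(1, 0)                  # stride swap only; no copy, no kernel launch
assert y.data_ptr() == x.data_ptr()  # same underlying buffer
assert not y.is_contiguous()         # non-contiguous view, as the kernel API expects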

)

# Fused kernel for FC1 + SwiGLU + post-scale
fc1_kernel_out = self.grouped_gemm_swiglu_kernel()(
Contributor

After SwiGLU, the output usually needs to be multiplied by permuted_probs. Is this weighted SwiGLU supported?

Review suggestions from @greptile-apps

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Review suggestion from @greptile-apps

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Contributor

@greptile-apps greptile-apps bot left a comment

3 files reviewed, 1 comment

Comment on lines +537 to +539
accumulate_into_main_grad = not getattr(
    weight_param, "overwrite_main_grad", False
)
Contributor

accumulate_into_main_grad is reassigned in the loop, so the last group's setting applies to all groups in the GEMM call on line 576. If different weight groups have different overwrite_main_grad settings, this causes incorrect gradient accumulation behavior. Either check consistency across groups or use per-group flags.
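
A minimal sketch of the consistency check suggested here, assuming the flags are collected while iterating over the weight parameters (weight_params and the surrounding loop are illustrative, not the PR's actual variable names):

# Illustrative only: collect one flag per weight group instead of reassigning
# a single variable inside the loop, then require the groups to agree.
accumulate_flags = [
    not getattr(weight_param, "overwrite_main_grad", False)
    for weight_param in weight_params
]
if len(set(accumulate_flags)) > 1:
    raise ValueError(
        "overwrite_main_grad must be set consistently across all weight groups"
    )
accumulate_into_main_grad = accumulate_flags[0]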

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

3 files reviewed, no comments

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

3 files reviewed, no comments

Contributor

@greptile-apps greptile-apps bot left a comment

3 files reviewed, no comments


Labels

performance Performance issues


3 participants