[PyTorch] Add grouped linear op and experimental fusion for grouped MLP #2622
base: main
Conversation
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order. Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
The test is too permissive, since it should still be failing: the weights are not properly interleaved yet. Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch L1
Greptile Overview
Greptile Summary
Adds a grouped linear operation and an experimental fused grouped MLP for Mixture-of-Experts models. Key changes:
Issues previously reported have been addressed.
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant Sequential as te_ops.Sequential
participant GroupedLinear1 as GroupedLinear (FC1)
participant ScaledSwiGLU
participant GroupedLinear2 as GroupedLinear (FC2)
participant FusedOp as ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8
participant CuTeKernel as CuDNN CuTe DSL Kernel
Note over User,CuTeKernel: Regular Path (No Fusion)
User->>Sequential: forward(input, split_sizes, scales)
Sequential->>GroupedLinear1: forward(input, split_sizes)
GroupedLinear1->>GroupedLinear1: Split input by split_sizes
GroupedLinear1->>GroupedLinear1: Quantize to MXFP8 if enabled
GroupedLinear1->>GroupedLinear1: general_grouped_gemm(weights, inputs)
GroupedLinear1-->>Sequential: fc1_output
Sequential->>ScaledSwiGLU: forward(fc1_output, scales)
ScaledSwiGLU->>ScaledSwiGLU: Remove gate interleaving if needed
ScaledSwiGLU->>ScaledSwiGLU: Compute SwiGLU activation
ScaledSwiGLU->>ScaledSwiGLU: Apply post-scaling (output * scales)
ScaledSwiGLU-->>Sequential: swiglu_output
Sequential->>GroupedLinear2: forward(swiglu_output, split_sizes)
GroupedLinear2->>GroupedLinear2: Split input by split_sizes
GroupedLinear2->>GroupedLinear2: Quantize to MXFP8 if enabled
GroupedLinear2->>GroupedLinear2: general_grouped_gemm(weights, inputs)
GroupedLinear2-->>Sequential: final_output
Sequential-->>User: final_output
Note over User,CuTeKernel: Fused Path (MXFP8 + SM100+)
User->>Sequential: forward(input, split_sizes, scales)
Sequential->>FusedOp: fuser_forward(input, split_sizes, scales)
FusedOp->>FusedOp: Quantize FC1 inputs to MXFP8
FusedOp->>FusedOp: Pack FC1 data/scales with gate swapping
FusedOp->>CuTeKernel: grouped_gemm_swiglu_wrapper_sm100()
Note right of CuTeKernel: Single kernel:<br/>FC1 GEMM + SwiGLU + post-scale
CuTeKernel-->>FusedOp: FC2 inputs (MXFP8, row+col quantized)
FusedOp->>FusedOp: Unpack FC2 inputs and undo gate swap
FusedOp->>FusedOp: Construct MXFP8Tensor objects
FusedOp->>FusedOp: general_grouped_gemm(FC2 weights, FC2 inputs)
FusedOp-->>Sequential: final_output
Sequential-->>User: final_output
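For reference, a minimal sketch of how the unfused path in the diagram might be composed. This is an assumption-heavy illustration: the op names `GroupedLinear` and `ScaledSwiGLU`, their constructor arguments, and the `forward(input, split_sizes, scales)` signature are inferred from the diagram and may not match the final API of this PR exactly.

```python
# Illustrative only: composes the unfused grouped-MLP path shown in the diagram.
# Assumes te_ops.GroupedLinear / te_ops.ScaledSwiGLU exist with the signatures
# implied by the diagram, and that te_ops.Sequential forwards the extra
# split_sizes / scales arguments to the ops that accept them.
import torch
import transformer_engine.pytorch.ops as te_ops

num_experts, hidden, ffn = 8, 1024, 4096

grouped_mlp = te_ops.Sequential(
    te_ops.GroupedLinear(num_experts, hidden, 2 * ffn),  # FC1: gate + up projections
    te_ops.ScaledSwiGLU(),                                # SwiGLU + post-scale by routing probs
    te_ops.GroupedLinear(num_experts, ffn, hidden),       # FC2
)

x = torch.randn(4096, hidden, device="cuda", dtype=torch.bfloat16)
split_sizes = [512] * num_experts                 # tokens routed to each expert
scales = torch.rand(4096, 1, device="cuda")       # per-token routing probabilities

y = grouped_mlp(x, split_sizes, scales)
```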
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
quantizer.optimize_for_gemm = True
fc1_xs = tex.split_quantize(fc1_x, split_sizes_cpu, fc1_input_quantizers)
# Pack data tensors
Maybe a silly question: is this packing and unpacking code just for verification, or will it be in the final version?
I'm working on getting rid of the concatenations, but the permutes are no-ops. The kernel API expects tensors with non-contiguous dims: https://github.com/NVIDIA/cudnn-frontend/blob/main/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py#L240-L245
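To illustrate the point about the permutes: in PyTorch, `permute` only swaps strides and returns a view, so no data is moved; the result simply has non-contiguous dims, which is what the kernel API expects. A small standalone check:

```python
import torch

x = torch.randn(16, 32)
xt = x.permute(1, 0)  # a view with swapped strides; no copy is made

print(xt.data_ptr() == x.data_ptr())  # True: same underlying storage
print(xt.is_contiguous())             # False: non-contiguous dims, as the kernel API expects
```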
)
# Fused kernel for FC1 + SwiGLU + post-scale
fc1_kernel_out = self.grouped_gemm_swiglu_kernel()(
After SwiGLU, the output usually needs to be multiplied with permuted_probs. Is this weighted SwiGLU supported?
Yes, the probs are passed into the kernel here: https://github.com/timmoon10/TransformerEngine/blob/46294be478f6551e2cf251283adc7529ddb2964e/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py#L264
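For readers unfamiliar with the weighted variant, a plain PyTorch reference of what "SwiGLU + post-scale" computes. The function name and the simple gate/up split are illustrative assumptions; the fused kernel handles the actual gate interleaving internally.

```python
import torch
import torch.nn.functional as F

def scaled_swiglu_reference(fc1_out: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
    """Reference SwiGLU with post-scaling by routing probabilities.

    fc1_out: [tokens, 2 * ffn] with gate and up projections concatenated
    probs:   [tokens, 1] per-token routing weights (permuted_probs)
    """
    gate, up = fc1_out.chunk(2, dim=-1)
    return F.silu(gate) * up * probs
```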
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Review suggestions from @greptile-apps Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Review suggestion from @greptile-apps Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
3 files reviewed, 1 comment
accumulate_into_main_grad = not getattr(
    weight_param, "overwrite_main_grad", False
)
accumulate_into_main_grad is reassigned in the loop, so the last group's setting applies to all groups in the GEMM call on line 576. If different weight groups have different overwrite_main_grad settings, this causes incorrect gradient accumulation behavior. Either check consistency across groups or use per-group flags.
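One possible shape of the consistency check suggested above, sketched outside the actual code. The `weight_params` iterable is hypothetical; in the real loop the flag would be derived once before the grouped GEMM call.

```python
# Hypothetical fix sketch: derive a single flag up front and require the
# groups to agree, instead of letting the last iteration's value win.
flags = [
    not getattr(weight_param, "overwrite_main_grad", False)
    for weight_param in weight_params  # hypothetical name for the per-group params
]
if len(set(flags)) > 1:
    raise ValueError(
        "Inconsistent overwrite_main_grad settings across weight groups; "
        "the grouped GEMM accumulates with a single flag."
    )
accumulate_into_main_grad = flags[0]
```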
Signed-off-by: Tim Moon <tmoon@nvidia.com>
3 files reviewed, no comments
Signed-off-by: Tim Moon <tmoon@nvidia.com>
3 files reviewed, no comments
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
3 files reviewed, no comments
Description
This PR adds a grouped linear op, which can be used in the grouped MLP block in Mixture-of-Experts models. It also adds an experimental fused operation for a grouped MLP block, using a CuTe DSL kernel that computes an MXFP8 grouped GEMM and SwiGLU.
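As a mental model, the grouped linear op applies a separate weight matrix to each contiguous chunk of the input. A naive PyTorch reference is sketched below; the actual op uses general_grouped_gemm (and optional MXFP8 quantization) rather than a Python loop, so this is illustrative only.

```python
import torch

def grouped_linear_reference(x, weights, split_sizes):
    """Naive reference for a grouped linear: one matmul per group.

    x:           [total_tokens, in_features]
    weights:     list of [out_features, in_features] tensors, one per group
    split_sizes: tokens per group, summing to total_tokens
    """
    outs = [
        chunk @ w.t()
        for chunk, w in zip(torch.split(x, split_sizes, dim=0), weights)
    ]
    return torch.cat(outs, dim=0)
```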
Type of change
Changes
Checklist: