
Conversation


@faradawn faradawn commented Feb 2, 2026

Description

Create a MoE tutorial for TE. The model used is Mixtral 8x7B.

View the rendered notebook here: https://github.com/faradawn/TransformerEngine/blob/add-moe-example/docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb

Fixes #2573

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added a new tutorial notebook, docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb, showing how to accelerate the Hugging Face Mixtral MoE block with TE's GroupedLinear and moe_permute/moe_unpermute

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>

greptile-apps bot commented Feb 2, 2026

Greptile Overview

Greptile Summary

This PR adds a tutorial demonstrating how to integrate Mixtral's MoE (Mixture of Experts) layers with Transformer Engine's GroupedLinear module for efficient expert processing. The tutorial includes a complete implementation of TEMixtralSparseMoeBlock that uses TE's MoE primitives (moe_permute, moe_unpermute) and GroupedLinear for parallel expert computation.

Key changes:

  • Created TEMixtralSparseMoeBlock class that wraps Mixtral MoE with TE's GroupedLinear
  • Uses moe_permute/unpermute for efficient token routing
  • Combines gate and up projections in a single GroupedLinear layer for SwiGLU
  • Includes working test code with a mock config (a minimal skeleton of the block is sketched below)
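
For orientation, here is a minimal sketch of how such a block can be wired up. This is an illustrative skeleton, not the notebook's exact code: it assumes Hugging Face's MixtralConfig field names (num_local_experts, num_experts_per_tok, hidden_size, intermediate_size) and TE's GroupedLinear constructor taking (num_gemms, in_features, out_features).

```python
import torch.nn as nn
import transformer_engine.pytorch as te

class TEMixtralSparseMoeBlockSketch(nn.Module):
    """Illustrative skeleton only: a router plus two GroupedLinear layers shared by all experts."""

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_local_experts   # 8 for Mixtral 8x7B
        self.top_k = config.num_experts_per_tok       # 2 for Mixtral 8x7B
        hidden, ffn = config.hidden_size, config.intermediate_size

        # Router producing per-token expert logits.
        self.gate = nn.Linear(hidden, self.num_experts, bias=False)
        # Gate and up projections fused into one grouped GEMM (output width 2 * ffn).
        self.experts_gate_up = te.GroupedLinear(self.num_experts, hidden, 2 * ffn, bias=False)
        # Down projection back to the hidden size.
        self.experts_down = te.GroupedLinear(self.num_experts, ffn, hidden, bias=False)
```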

Critical issue found:

  • The m_splits calculation on lines 98-102 contains a logic error that incorrectly multiplies token counts by top_k, which will cause dimension mismatches when passed to GroupedLinear. This needs to be fixed before the tutorial can work correctly.

Confidence Score: 2/5

  • This PR has a critical bug that will cause runtime errors
  • The m_splits calculation contains a logic error that double-counts tokens by multiplying by top_k, which will cause dimension mismatches in GroupedLinear operations and lead to runtime failures
  • The notebook file docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb requires fixing the m_splits calculation logic before it can work correctly

Important Files Changed

| Filename | Overview |
| --- | --- |
| docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb | Added MoE tutorial with Mixtral integration, but contains a critical bug in the m_splits calculation that will cause runtime errors |

Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant MoE as TEMixtralSparseMoeBlock
    participant Router as Router (gate)
    participant Permute as te.moe_permute
    participant GLinear as GroupedLinear
    participant Unpermute as te.moe_unpermute

    User->>MoE: forward(hidden_states)
    MoE->>MoE: Flatten to [num_tokens, hidden_dim]
    
    MoE->>Router: Get expert assignments
    Router-->>MoE: router_logits [num_tokens, num_experts]
    MoE->>MoE: Compute routing_weights & select top_k experts
    
    MoE->>Permute: moe_permute(tokens, selected_experts)
    Permute-->>MoE: permuted_tokens, row_id_map
    
    MoE->>MoE: Calculate m_splits per expert
    
    MoE->>GLinear: experts_gate_up(permuted_tokens, m_splits)
    GLinear-->>MoE: intermediate [combined gate+up projections]
    
    MoE->>MoE: Apply SwiGLU: silu(gate) * up
    
    MoE->>GLinear: experts_down(intermediate_act, m_splits)
    GLinear-->>MoE: expert_outputs
    
    MoE->>Unpermute: moe_unpermute(expert_outputs, row_id_map, routing_weights)
    Unpermute-->>MoE: final_hidden_states
    
    MoE->>MoE: Reshape to [batch, seq_len, hidden_dim]
    MoE-->>User: final_hidden_states, router_logits
```
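
Correspondingly, a forward pass following this diagram could look roughly like the sketch below. It continues the illustrative TEMixtralSparseMoeBlockSketch above and uses the corrected m_splits computation; the te.moe_permute/te.moe_unpermute argument names follow the notebook's usage, and the [gate | up] output ordering of the fused projection is an assumption.

```python
import torch
import torch.nn.functional as F
import transformer_engine.pytorch as te

def forward(self, hidden_states):                     # method of TEMixtralSparseMoeBlockSketch
    batch, seq_len, hidden_dim = hidden_states.shape
    tokens = hidden_states.view(-1, hidden_dim)       # [num_tokens, hidden_dim]
    num_tokens = tokens.shape[0]

    # Route: softmax over expert logits, keep the top_k experts per token.
    router_logits = self.gate(tokens)                 # [num_tokens, num_experts]
    probs = F.softmax(router_logits, dim=-1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(probs, self.top_k, dim=-1)
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

    # Replicate each token once per selected expert and group the copies by expert.
    permuted, row_id_map = te.moe_permute(
        tokens,
        selected_experts.to(torch.int32),
        num_out_tokens=num_tokens * self.top_k,
        max_token_num=num_tokens * self.top_k,
    )

    # Number of routed rows per expert; must sum to permuted.shape[0].
    m_splits = torch.bincount(selected_experts.flatten(),
                              minlength=self.num_experts).tolist()

    # Fused gate/up grouped GEMM, SwiGLU activation, then grouped down projection.
    gate_up = self.experts_gate_up(permuted, m_splits)
    gate, up = gate_up.chunk(2, dim=-1)               # assumed [gate | up] layout
    expert_out = self.experts_down(F.silu(gate) * up, m_splits)

    # Scatter rows back to the original token order, weighted by routing probabilities.
    out = te.moe_unpermute(expert_out, row_id_map, routing_weights.to(expert_out.dtype))
    return out.view(batch, seq_len, hidden_dim), router_logits
```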


@greptile-apps greptile-apps bot left a comment


1 file reviewed, 2 comments


Comment on lines +98 to +102
```
"        # Calculate m_splits: number of tokens assigned to each expert\n",
"        m_splits = []\n",
"        for expert_idx in range(self.num_experts):\n",
"            expert_mask = (selected_experts == expert_idx).any(dim=-1)\n",
"            m_splits.append(expert_mask.sum().item() * self.top_k)\n",
```

Logic error in m_splits calculation. The current approach counts tokens incorrectly by multiplying by top_k after already considering all expert assignments.

The issue: expert_mask already captures ALL tokens that selected this expert (across all top-k positions), so multiplying by self.top_k double-counts.

For example, if token 0 selects experts [1, 3] and token 1 selects experts [1, 2], then for expert 1: expert_mask will be [True, True] (sum=2). Multiplying by top_k=2 gives 4, but only 2 tokens actually go to expert 1.
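
To see the double-counting numerically, here is a standalone check of that example (a hypothetical snippet, not from the notebook):

```python
import torch

# token 0 -> experts [1, 3]; token 1 -> experts [1, 2]
selected_experts = torch.tensor([[1, 3], [1, 2]])
top_k, num_experts = 2, 4

buggy = [(selected_experts == e).any(dim=-1).sum().item() * top_k for e in range(num_experts)]
fixed = [(selected_experts == e).any(dim=-1).sum().item() for e in range(num_experts)]

print(buggy)  # [0, 4, 2, 2] -> sums to 8, but only num_tokens * top_k = 4 routed rows exist
print(fixed)  # [0, 2, 1, 1] -> sums to 4, matching the permuted token count
```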

Suggested change
```diff
-"        # Calculate m_splits: number of tokens assigned to each expert\n",
-"        m_splits = []\n",
-"        for expert_idx in range(self.num_experts):\n",
-"            expert_mask = (selected_experts == expert_idx).any(dim=-1)\n",
-"            m_splits.append(expert_mask.sum().item() * self.top_k)\n",
+        # Calculate m_splits: number of tokens assigned to each expert
+        m_splits = []
+        for expert_idx in range(self.num_experts):
+            expert_mask = (selected_experts == expert_idx).any(dim=-1)
+            m_splits.append(expert_mask.sum().item())
```
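
A vectorized equivalent (an optional alternative, assuming each token's top-k experts are distinct, which torch.topk guarantees) avoids the Python loop entirely:

```python
# Count routed rows per expert directly from the flattened expert indices.
m_splits = torch.bincount(selected_experts.flatten(),
                          minlength=self.num_experts).tolist()
```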

Comment on lines +91 to +95
```
"        permuted_tokens, row_id_map = te.moe_permute(\n",
"            hidden_states_flat,\n",
"            selected_experts.to(torch.int32),\n",
"            num_out_tokens=None,  # Auto-calculate\n",
"            max_token_num=num_tokens\n",
```

Setting num_out_tokens to None is fine for auto-calculation, but when using top_k > 1, the expected output token count should be num_tokens times top_k since each token is routed to multiple experts.
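
A hedged adjustment along those lines is sketched below, keeping the notebook's argument names; whether max_token_num should also be sized to the expanded count is an assumption worth checking against the moe_permute documentation.

```python
permuted_tokens, row_id_map = te.moe_permute(
    hidden_states_flat,
    selected_experts.to(torch.int32),
    num_out_tokens=num_tokens * self.top_k,  # one output row per (token, expert) pair
    max_token_num=num_tokens * self.top_k,   # assumption: buffer sized to the expanded count
)
```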



Development

Successfully merging this pull request may close these issues.

Add examples for MoE models (non-Megatron)
