Add examples for MoE models - Mixtral in TE #2642
base: main
Conversation
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Greptile Overview
Greptile Summary
This PR adds a tutorial demonstrating how to integrate Mixtral's MoE (Mixture of Experts) layers with Transformer Engine's GroupedLinear and moe_permute/moe_unpermute APIs.
Critical issue found: the m_splits calculation in the tutorial double-counts tokens (see the review comment below).
Confidence Score: 2/5
Important Files Changed: docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb
Sequence Diagram
sequenceDiagram
participant User as User Code
participant MoE as TEMixtralSparseMoeBlock
participant Router as Router (gate)
participant Permute as te.moe_permute
participant GLinear as GroupedLinear
participant Unpermute as te.moe_unpermute
User->>MoE: forward(hidden_states)
MoE->>MoE: Flatten to [num_tokens, hidden_dim]
MoE->>Router: Get expert assignments
Router-->>MoE: router_logits [num_tokens, num_experts]
MoE->>MoE: Compute routing_weights & select top_k experts
MoE->>Permute: moe_permute(tokens, selected_experts)
Permute-->>MoE: permuted_tokens, row_id_map
MoE->>MoE: Calculate m_splits per expert
MoE->>GLinear: experts_gate_up(permuted_tokens, m_splits)
GLinear-->>MoE: intermediate [combined gate+up projections]
MoE->>MoE: Apply SwiGLU: silu(gate) * up
MoE->>GLinear: experts_down(intermediate_act, m_splits)
GLinear-->>MoE: expert_outputs
MoE->>Unpermute: moe_unpermute(expert_outputs, row_id_map, routing_weights)
Unpermute-->>MoE: final_hidden_states
MoE->>MoE: Reshape to [batch, seq_len, hidden_dim]
MoE-->>User: final_hidden_states, router_logits
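The flow above can be summarized as a minimal forward-pass sketch. This is not the notebook's exact code: the attribute names (self.gate, self.experts_gate_up, self.experts_down, self.top_k, self.num_experts), the HF-Mixtral-style router normalization, and the assumption that the gate and up projections are concatenated along the last dimension are taken or inferred from the diagram and the review snippets below.

```python
# Minimal sketch of TEMixtralSparseMoeBlock.forward as described in the diagram.
# Attribute names and the gate/up concatenation layout are assumptions, not the PR's exact code.
import torch
import torch.nn.functional as F
import transformer_engine.pytorch as te

def forward(self, hidden_states):
    batch, seq_len, hidden_dim = hidden_states.shape
    hidden_states_flat = hidden_states.view(-1, hidden_dim)        # [num_tokens, hidden_dim]
    num_tokens = hidden_states_flat.shape[0]

    # Router: score experts and keep the top_k per token (HF-Mixtral-style normalization)
    router_logits = self.gate(hidden_states_flat)                  # [num_tokens, num_experts]
    routing_weights = F.softmax(router_logits, dim=-1)
    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

    # Group token copies by expert so each expert sees a contiguous slice
    permuted_tokens, row_id_map = te.moe_permute(
        hidden_states_flat,
        selected_experts.to(torch.int32),
        num_out_tokens=num_tokens * self.top_k,  # one output row per (token, expert) assignment
        max_token_num=num_tokens,
    )

    # Split sizes for GroupedLinear: rows routed to each expert
    m_splits = torch.bincount(selected_experts.view(-1), minlength=self.num_experts).tolist()

    # Expert MLPs: fused gate+up projection, SwiGLU activation, down projection
    gate_up = self.experts_gate_up(permuted_tokens, m_splits)
    gate, up = gate_up.chunk(2, dim=-1)                            # assumes [gate | up] layout
    intermediate_act = F.silu(gate) * up
    expert_outputs = self.experts_down(intermediate_act, m_splits)

    # Scatter rows back to their original positions, weighted by routing_weights
    final_hidden_states = te.moe_unpermute(expert_outputs, row_id_map, routing_weights)

    return final_hidden_states.view(batch, seq_len, hidden_dim), router_logits
```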
1 file reviewed, 2 comments
| " # Calculate m_splits: number of tokens assigned to each expert\n", | ||
| " m_splits = []\n", | ||
| " for expert_idx in range(self.num_experts):\n", | ||
| " expert_mask = (selected_experts == expert_idx).any(dim=-1)\n", | ||
| " m_splits.append(expert_mask.sum().item() * self.top_k)\n", |
Logic error in the m_splits calculation: the code multiplies by top_k even though expert_mask already covers all top-k assignments.
The issue: expert_mask already captures ALL tokens that selected this expert (across all top-k positions), so multiplying by self.top_k double-counts.
For example, if token 0 selects experts [1, 3] and token 1 selects experts [1, 2], then for expert 1 expert_mask will be [True, True] (sum = 2). Multiplying by top_k = 2 gives 4, but only 2 tokens actually go to expert 1.
| " # Calculate m_splits: number of tokens assigned to each expert\n", | |
| " m_splits = []\n", | |
| " for expert_idx in range(self.num_experts):\n", | |
| " expert_mask = (selected_experts == expert_idx).any(dim=-1)\n", | |
| " m_splits.append(expert_mask.sum().item() * self.top_k)\n", | |
| # Calculate m_splits: number of tokens assigned to each expert | |
| m_splits = [] | |
| for expert_idx in range(self.num_experts): | |
| expert_mask = (selected_experts == expert_idx).any(dim=-1) | |
| m_splits.append(expert_mask.sum().item()) |
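Since torch.topk returns distinct experts per token, the same split sizes can also be computed in one vectorized call. A small standalone sketch (not part of the PR) using the reviewer's example:

```python
import torch

# Reviewer's example: token 0 selects experts [1, 3], token 1 selects experts [1, 2]
selected_experts = torch.tensor([[1, 3], [1, 2]])
num_experts = 4

# Count every (token, expert) assignment; with distinct top-k picks this matches the suggested fix
m_splits = torch.bincount(selected_experts.view(-1), minlength=num_experts).tolist()
print(m_splits)  # [0, 2, 1, 1] -> expert 1 receives 2 rows, not 4
```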
| " permuted_tokens, row_id_map = te.moe_permute(\n", | ||
| " hidden_states_flat,\n", | ||
| " selected_experts.to(torch.int32),\n", | ||
| " num_out_tokens=None, # Auto-calculate\n", | ||
| " max_token_num=num_tokens\n", |
Setting num_out_tokens to None is fine for auto-calculation, but with top_k > 1 the expected output token count is num_tokens * top_k, since each token is routed to top_k experts.
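If the tutorial wants to be explicit rather than rely on auto-calculation, the call could pass that count directly. A sketch reusing the variable names from the hunk above (only the num_out_tokens value differs from the PR's code):

```python
# Sketch: make the expected output row count explicit when top_k > 1.
# Each of the num_tokens inputs produces top_k permuted rows, one per selected expert.
permuted_tokens, row_id_map = te.moe_permute(
    hidden_states_flat,
    selected_experts.to(torch.int32),
    num_out_tokens=num_tokens * self.top_k,  # instead of None
    max_token_num=num_tokens,
)
```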
Description
Creates a MoE tutorial for TE. The model used is Mixtral 8x7B.
View the rendered notebook here: https://github.com/faradawn/TransformerEngine/blob/add-moe-example/docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb
Fixes #2573
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: