[PyTorch] Pad V when Q/V head dims differ (MLA) for THD #2629
Conversation
Greptile Summary
Adds support for Multi-head Latent Attention (MLA) in the THD format when Q and V have different head dimensions. When V's head dimension is smaller than Q's, the implementation pads V to match Q's head dimension before the attention computation, then trims the output back to the original V dimension.
Confidence Score: 4/5
Sequence Diagram
```mermaid
sequenceDiagram
    participant Caller
    participant DotProductAttention
    participant AttentionBackend as Attention Backend<br/>(Flash/Fused/Unfused)
    Caller->>DotProductAttention: forward(Q, K, V)<br/>head_dim_qk=128, head_dim_v=64
    Note over DotProductAttention: Check THD format &<br/>head dim mismatch
    alt head_dim_v < head_dim_qk
        DotProductAttention->>DotProductAttention: Save orig_v_dim = 64
        DotProductAttention->>DotProductAttention: Pad V: 64 → 128<br/>Set pad_v_for_thd = True
        DotProductAttention->>DotProductAttention: Update head_dim_v = 128
    end
    DotProductAttention->>AttentionBackend: attention(Q, K, V_padded)
    AttentionBackend-->>DotProductAttention: attn_out (head_dim=128)
    alt pad_v_for_thd == True
        DotProductAttention->>DotProductAttention: _trim_thd_output()
        Note over DotProductAttention: Reshape using head_dim_v (128)<br/>Trim to orig_v_dim (64)
        DotProductAttention->>DotProductAttention: attn_out[..., :64]
    end
    DotProductAttention-->>Caller: Return trimmed output<br/>(head_dim=64)
```
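The pad-and-trim flow can be summarized in a short, self-contained sketch. This is an illustrative approximation, not the PR's actual code: `pad_v_if_needed` is a hypothetical helper name, while `orig_v_dim`, `pad_v_for_thd`, and `_trim_thd_output` mirror the names in the diagram above.

```python
import torch
import torch.nn.functional as F


def pad_v_if_needed(v: torch.Tensor, head_dim_qk: int):
    """Zero-pad the last dim of V up to head_dim_qk; return (v, orig_v_dim, pad_v_for_thd)."""
    orig_v_dim = v.shape[-1]
    if orig_v_dim < head_dim_qk:
        # F.pad takes (left, right) amounts for the last dimension; pad only on the right
        return F.pad(v, (0, head_dim_qk - orig_v_dim)), orig_v_dim, True
    return v, orig_v_dim, False


def _trim_thd_output(attn_out: torch.Tensor, num_heads: int, head_dim_v: int, orig_v_dim: int):
    """Reshape the flat [total_tokens, num_heads * head_dim_v] output and drop the padded tail."""
    total_tokens = attn_out.shape[0]
    trimmed = attn_out.reshape(total_tokens, num_heads, head_dim_v)[..., :orig_v_dim]
    return trimmed.reshape(total_tokens, num_heads * orig_v_dim)
```

Because the extra columns of the padded V only influence the output columns beyond the original V dimension, trimming recovers exactly the result that would have been produced with the original V.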
2 files reviewed, 2 comments
Pull request overview
This PR adds support for Multi-head Latent Attention (MLA) with mismatched Q/V head dimensions in the THD format (packed variable-length sequences laid out as [total_tokens, heads, head_dim]). When the value tensor has a smaller head dimension than the query/key tensors, the code pads the value tensor to match the Q/K head dimension, runs the attention operation, and then trims the output back to the original V dimension.
Changes:
- Added padding logic for V tensor when head dimensions differ in THD format
- Implemented trimming function to restore correct output dimensions after attention
- Added test case for THD attention with mismatched Q/V head dimensions (a standalone sanity check of the pad-and-trim equivalence is sketched below)
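As a quick sanity check of why the approach is lossless (independent of the PR's actual test in tests/pytorch/attention/test_attention.py): the attention output is attention_weights @ V, so zero-padded columns of V only produce extra output columns that the trim step discards. The sketch below verifies this with plain PyTorch SDPA, which already accepts a V head dimension smaller than Q/K's; shapes are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, s, d_qk, d_v = 2, 4, 16, 128, 64  # matches the 128/64 example in the diagram

q = torch.randn(b, h, s, d_qk)
k = torch.randn(b, h, s, d_qk)
v = torch.randn(b, h, s, d_v)

# Reference: PyTorch SDPA allows V's head dim (Ev) to differ from Q/K's (E)
ref = F.scaled_dot_product_attention(q, k, v)

# Pad V with zeros up to d_qk, run attention, then trim the output back to d_v
v_padded = F.pad(v, (0, d_qk - d_v))
out = F.scaled_dot_product_attention(q, k, v_padded)[..., :d_v]

torch.testing.assert_close(out, ref)  # padded columns contribute nothing after trimming
```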
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | Implements padding of V tensor before attention and trimming of output after attention for THD format with mismatched Q/V head dimensions |
| tests/pytorch/attention/test_attention.py | Adds test case to verify THD attention works with different Q/V head dimensions |
Signed-off-by: Hollow Man <hollowman@opensuse.org>
1 file reviewed, 1 comment
Description
For MLA, V needs to be padded when the Q/V head dims differ in the THD format.
Similar to NVIDIA/Megatron-LM#3003
Fixes NVIDIA/Megatron-LM#1698
Type of change
Changes
Please list the changes introduced in this PR:
- Pad V to the Q/K head dimension before attention when head_dim_v < head_dim_qk in THD format, and trim the output back to the original V dimension afterwards (transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py)
- Add a test for THD attention with mismatched Q/V head dimensions (tests/pytorch/attention/test_attention.py)
Checklist: