Conversation

@Wohox (Contributor) commented Jan 22, 2026

Description

This PR adds get_backward_dw_params for TE modules, which helps callers manage the hooks registered on parameters.

In Megatron-LM, get_backward_dw_params is called once the wgrad CUDA graph has been executed. Currently, the backward post-hook of the wgrad computation is discarded, which causes parameters to skip gradient reduction.
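The failure mode can be illustrated with a toy sketch that stands in for TE's parameter hooks (plain Python, no torch; all names are hypothetical): a CUDA graph replay bypasses the Python-side machinery that normally fires post-accumulation hooks, so a hook that performs gradient reduction is silently skipped unless the caller re-triggers it.

```python
# Toy illustration of the bug: replaying a captured graph computes grads
# on-device but fires no Python hooks, so a grad-reduce hook registered on
# a parameter is silently skipped. All names here are hypothetical.

class Param:
    def __init__(self, name):
        self.name = name
        self._hooks = []  # e.g. a gradient all-reduce registered by the framework

    def register_post_accumulate_hook(self, fn):
        self._hooks.append(fn)

    def fire_hooks(self):
        for fn in self._hooks:
            fn(self)

reduced = []

def grad_reduce_hook(param):
    reduced.append(param.name)

def eager_wgrad(params):
    # Eager path: hooks fire after gradient accumulation.
    for p in params:
        p.fire_hooks()

def graph_replay_wgrad(params):
    # Graph replay: gradients are produced, but no hooks fire.
    pass

w = Param("weight")
w.register_post_accumulate_hook(grad_reduce_hook)

eager_wgrad([w])          # hook fires -> gradient gets reduced
graph_replay_wgrad([w])   # hook skipped -> gradient reduce is lost

print(reduced)  # only one entry despite two wgrad passes
```

After two wgrad passes, `reduced` holds a single entry: the replayed pass never reduced its gradient, which is exactly the parameter-skips-grad-reduce symptom described above.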

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Extract the wgrad hook execution in TransformerEngineBaseModule into a reusable _trigger_wgrad_accumulation_and_reduce_hooks() method
  • Trigger these hooks after replaying the wgrad CUDA graph in graph.py

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

greptile-apps bot (Contributor) commented Jan 22, 2026

Greptile Overview

Greptile Summary

This PR fixes a critical bug where weight gradient reduction hooks were not triggered after CUDA graph replay in Megatron-LM integration. The fix extracts the hook triggering logic into a reusable _trigger_wgrad_accumulation_and_reduce_hooks() method in TransformerEngineBaseModule, and calls it after replaying the wgrad CUDA graph in graph.py.

Key changes:

  • Extracted hook execution into _trigger_wgrad_accumulation_and_reduce_hooks() for reusability
  • Added hook triggering after bwd_dw_graphs[graph_idx].replay() to ensure gradient reduction occurs
  • Maintains proper iteration over visited_te_modules with need_backward_dw() check

The implementation is clean, follows existing patterns in the codebase, and addresses the earlier review concern about fuse_wgrad_accumulation by only triggering hooks when need_backward_dw() returns True.
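The shape of the fix as summarized above can be sketched in a few lines (a minimal sketch with hypothetical stand-ins for the TE/Megatron-LM objects, not TE's actual code): replay the wgrad graph, then walk the visited modules and manually fire the hooks the replay bypassed, guarded by need_backward_dw().

```python
# Sketch of the fixed replay path: after replaying the wgrad graph, walk
# the visited modules and re-trigger their wgrad hooks. All names are
# hypothetical stand-ins mirroring the summary above.

class FakeModule:
    def __init__(self, name, needs_dw):
        self.name = name
        self._needs_dw = needs_dw
        self.hooks_fired = 0

    def need_backward_dw(self):
        return self._needs_dw

    def _trigger_wgrad_accumulation_and_reduce_hooks(self):
        self.hooks_fired += 1

class FakeGraph:
    def replay(self):
        pass  # stands in for bwd_dw_graphs[graph_idx].replay()

def backward_dw(graph, visited_te_modules):
    graph.replay()
    # The fix: re-trigger the hooks that graph replay bypassed, guarded
    # by need_backward_dw() (this guard addresses the earlier
    # fuse_wgrad_accumulation review concern).
    for module in visited_te_modules:
        if module.need_backward_dw():
            module._trigger_wgrad_accumulation_and_reduce_hooks()

mods = [FakeModule("linear", True), FakeModule("norm", False)]
backward_dw(FakeGraph(), mods)
print([m.hooks_fired for m in mods])  # [1, 0]
```

Only the module for which need_backward_dw() returns True has its hooks fired, so modules whose gradients need no separate wgrad pass are left untouched.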

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The changes are well-scoped, address a specific bug, follow existing code patterns, and properly guard the hook triggering with need_backward_dw() checks
  • No files require special attention

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/module/base.py | Extracted hook triggering logic into a _trigger_wgrad_accumulation_and_reduce_hooks() method for reuse in the CUDA graph context |
| transformer_engine/pytorch/graph.py | Added hook triggering after wgrad CUDA graph replay so gradient reduction hooks are properly executed |

Sequence Diagram

```mermaid
sequenceDiagram
    participant MegatronLM
    participant GraphedCallable
    participant CUDAGraph
    participant TEModule
    participant Hooks

    Note over MegatronLM,Hooks: Wgrad CUDA Graph Execution Flow

    MegatronLM->>GraphedCallable: Call backward_dw()
    GraphedCallable->>CUDAGraph: Check if need_bwd_dw_graph

    alt CUDA graph needed
        CUDAGraph->>CUDAGraph: Replay bwd_dw_graph
        Note over CUDAGraph: Weight gradient computation executed

        CUDAGraph->>TEModule: Iterate visited_te_modules
        loop For each TE module
            TEModule->>TEModule: Check need_backward_dw()
            alt Module needs backward_dw
                TEModule->>TEModule: _trigger_wgrad_accumulation_and_reduce_hooks()
                TEModule->>Hooks: Execute registered hooks
                Note over Hooks: Gradient accumulation<br/>and reduction performed
            end
        end
    end

    Note over MegatronLM,Hooks: Without this PR: Hooks would be skipped<br/>after CUDA graph replay
```

greptile-apps bot (Contributor) commented Jan 22, 2026

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

@Wohox (Contributor, Author) commented Jan 22, 2026

@buptzyb @lhb8125 Please help review this PR, thanks!

greptile-apps bot left a comment: 1 file reviewed, 1 comment

Snippet under review (transformer_engine/pytorch/module/base.py):

```python
    Get the parameters for the backward weight gradient computation.
    """
    params = []
    params.append(noop_cat(self._get_weight_tensors()))
```
greptile-apps bot commented:

logic: in backward_dw() (lines 1520-1522), weight tensors are only accessed when not self.fuse_wgrad_accumulation, but this method unconditionally returns the weight parameters. Depending on Megatron-LM's usage, this could register hooks on parameters that should not have them when fuse_wgrad_accumulation=True.
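One way to address this concern is to mirror backward_dw()'s condition when collecting the parameters. A minimal sketch under that assumption (noop_cat and every other name here are trivial hypothetical stand-ins, not TE's implementation):

```python
# Hypothetical sketch: guard parameter collection on fuse_wgrad_accumulation,
# mirroring the condition the reviewer cites from backward_dw(). Stand-in
# implementations throughout.

def noop_cat(tensors):
    # Stand-in for TE's noop_cat: here it just returns the list unchanged.
    return tensors

class Module:
    def __init__(self, fuse_wgrad_accumulation, weights):
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self._weights = weights

    def _get_weight_tensors(self):
        return self._weights

    def get_backward_dw_params(self):
        """Collect wgrad params, skipping weight tensors when gradient
        accumulation is fused, so no hooks land on them."""
        params = []
        if not self.fuse_wgrad_accumulation:
            params.append(noop_cat(self._get_weight_tensors()))
        return params

fused = Module(True, ["w"])
unfused = Module(False, ["w"])
print(fused.get_backward_dw_params())    # []
print(unfused.get_backward_dw_params())  # [['w']]
```

With the guard in place, a module running with fuse_wgrad_accumulation=True returns no weight parameters, so no hooks can be registered on them.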

@Wohox (Author) replied: commit content reverted.

greptile-apps bot left a comment: No files reviewed, no comments


@Wohox (Contributor, Author) commented Jan 30, 2026

@ksivaman Can you help review this PR? It's a bug fix for #2376.

greptile-apps bot left a comment: 2 files reviewed, no comments

