Add DeviceContext into Iris main class #347
Conversation
Pull request overview
This pull request adds a DeviceContext class to the main Iris module to provide a cleaner, object-oriented API for device-side Iris operations. The PR follows the "Gluon pattern" by encapsulating rank, world_size, and heap_bases in a single context object that can be initialized on the device from an encoded tensor.
Changes:
- Introduces a `DeviceContext` aggregate class in `iris/iris.py` with methods for load, store, get, put, copy, and all atomic operations
- Adds an `Iris.get_device_context()` method to create an encoded context tensor containing rank, world size, and heap bases
- Moves `DeviceContext` from `iris/x/core.py` to `iris/iris.py` and updates all iris.x collective operations to import from the new location
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| iris/iris.py | Adds DeviceContext class with OOP API and get_device_context() method for creating encoded context tensors |
| iris/__init__.py | Exports DeviceContext from main iris module |
| iris/x/core.py | Removes skeletal DeviceContext definition and updates all exports |
| iris/x/__init__.py | Updates documentation examples and removes DeviceContext from exports |
| iris/x/all_reduce.py | Updates import to use DeviceContext from iris.iris instead of iris.x.core |
| iris/x/all_gather.py | Updates import to use DeviceContext from iris.iris instead of iris.x.core |
| iris/x/all_to_all.py | Updates import to use DeviceContext from iris.iris instead of iris.x.core |
| iris/x/gather.py | Updates import to use DeviceContext from iris.iris instead of iris.x.core |
| iris/x/reduce_scatter.py | Updates import to use DeviceContext from iris.iris instead of iris.x.core |
| tests/unittests/test_device_context.py | Adds comprehensive unit tests for DeviceContext methods including load, store, get, put, atomic operations |
| examples/06_message_passing/message_passing_device_context.py | Adds example demonstrating producer-consumer pattern using DeviceContext API |
| examples/common/utils.py | Minor style fix removing redundant parentheses from lambda expression |
Comments suppressed due to low confidence (2)
iris/x/__init__.py:28
- The example code shows calling collective operations as methods on DeviceContext (e.g., `ctx.all_reduce(...)`, `ctx.all_gather(...)`), but DeviceContext doesn't have these methods. The collective operations are standalone functions that take a DeviceContext as a parameter. The correct usage is `iris.x.all_reduce(tile, src_view, dst_view, ctx)`, not `ctx.all_reduce(tile, src_view, dst_view)`.
>>> ctx = DeviceContext.initialize(context_tensor, rank, world_size)
>>>
>>> # Call collectives on ctx directly (default algorithms)
>>> ctx.all_reduce(tile, src_view, dst_view)
>>> ctx.all_gather(tile, src_view, dst_view, dim=0)
>>> ctx.all_to_all(tile, src_view, dst_view, N_per_rank)
>>> ctx.reduce_scatter(tile, src_view, dst_view)
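The distinction the reviewer draws (standalone functions that take the context as a parameter, rather than methods on the context) can be sketched in plain Python. `Ctx` and `all_reduce` here are stand-ins for illustration, not the real `iris.x` Triton implementations:

```python
# Plain-Python sketch of the two calling conventions discussed above.
# The real iris.x collectives are Triton device functions; this only
# contrasts "free function taking ctx" with "method on ctx".
class Ctx:
    def __init__(self, rank, world_size):
        self.rank = rank
        self.world_size = world_size

# Shape the review recommends: a standalone function with ctx as a parameter
def all_reduce(values, ctx):
    # stand-in reduction: sum the per-rank contributions
    return sum(values)

ctx = Ctx(rank=0, world_size=2)
result = all_reduce([1, 2, 3], ctx)  # all_reduce(..., ctx), not ctx.all_reduce(...)
```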
iris/x/__init__.py:45
- The example code shows calling `ctx.all_reduce(...)` as a method on DeviceContext, but DeviceContext doesn't have an all_reduce method. The correct usage is to call the standalone function, `iris.x.all_reduce_ring(tile, src_view, dst_view, ctx)`, or use one of the other all_reduce variants.
>>> ctx = DeviceContext.initialize(context_tensor, rank, world_size)
>>>
>>> # Use ring algorithm
>>> config = iris.x.AllReduceConfig("ring")
>>> ctx.all_reduce(tile, src_view, dst_view, config=config)
>>>
>>> # Use spinlock with locks
>>> config = iris.x.AllReduceConfig("spinlock", locks_ptr)
>>> tile_id = pid_m * num_tiles_n + pid_n
>>> ctx.all_reduce(tile, src_view, dst_view, config=config, tile_id=tile_id)
```python
# Convert heap_bases to a list for concatenation
heap_bases_list = self.heap_bases.tolist()

# Create context tensor: [cur_rank, world_size, heap_base_0, heap_base_1, ...]
context_data = [self.cur_rank, self.num_ranks] + heap_bases_list
context_tensor = torch.tensor(context_data, dtype=torch.int64, device=self.device)
```
Copilot AI · Feb 3, 2026
Using tolist() to convert heap_bases from tensor to Python list causes a device-to-host transfer, which can be slow. Since the heap_bases are already on device, consider using tensor concatenation instead: context_tensor = torch.cat([torch.tensor([self.cur_rank, self.num_ranks], dtype=torch.int64, device=self.device), self.heap_bases]). This avoids the unnecessary round-trip through the host.
Suggested change:

```diff
-# Convert heap_bases to a list for concatenation
-heap_bases_list = self.heap_bases.tolist()
-# Create context tensor: [cur_rank, world_size, heap_base_0, heap_base_1, ...]
-context_data = [self.cur_rank, self.num_ranks] + heap_bases_list
-context_tensor = torch.tensor(context_data, dtype=torch.int64, device=self.device)
+# Create a small tensor for [cur_rank, world_size] on the device
+rank_world = torch.tensor(
+    [self.cur_rank, self.num_ranks],
+    dtype=torch.int64,
+    device=self.device,
+)
+# Ensure heap_bases is int64 and concatenate on device:
+# [cur_rank, world_size, heap_base_0, heap_base_1, ...]
+heap_bases_int64 = self.heap_bases.to(dtype=torch.int64)
+context_tensor = torch.cat([rank_world, heap_bases_int64])
```
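Stripped of torch and GPU specifics, the layout produced by both approaches can be checked in plain Python. The values below are stand-ins, not real heap pointers:

```python
# Plain-Python sketch (no GPU needed) showing that list concatenation and
# on-device concatenation yield the same layout:
# [cur_rank, world_size, heap_base_0, heap_base_1, ...]
cur_rank, num_ranks = 1, 4
heap_bases = [1000, 2000, 3000, 4000]  # stand-ins for device heap pointers

# Original approach: host-side list concatenation (mirrors tolist() + list add)
context_data = [cur_rank, num_ranks] + list(heap_bases)

# Suggested approach mirrors torch.cat([rank_world, heap_bases]) on device
rank_world = [cur_rank, num_ranks]
context_cat = rank_world + heap_bases

assert context_data == context_cat  # same encoding, no host round-trip needed
```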
```diff
 def to_json(self, filename, gpu_freq):
-    cycles_to_us = lambda cycles: (cycles / gpu_freq)
+    cycles_to_us = lambda cycles: cycles / gpu_freq
```
Copilot AI · Feb 3, 2026
This change (removing parentheses from lambda expression) appears to be an unrelated style fix that should ideally be in a separate commit or PR. While it's a minor improvement, mixing unrelated changes makes it harder to track the purpose of each modification.
Suggested change:

```diff
-    cycles_to_us = lambda cycles: cycles / gpu_freq
+    def cycles_to_us(cycles):
+        return cycles / gpu_freq
```
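The suggested named function behaves identically to the assigned lambda (PEP 8 recommends `def` over assigning a lambda to a name). A minimal runnable sketch, where the `gpu_freq` value is hypothetical and the cycles-per-microsecond unit is an assumption:

```python
# Hypothetical GPU frequency in cycles per microsecond (assumed unit).
gpu_freq = 1000.0

# Named function, equivalent to: cycles_to_us = lambda cycles: cycles / gpu_freq
def cycles_to_us(cycles):
    return cycles / gpu_freq

elapsed_us = cycles_to_us(2500.0)  # 2500 cycles at 1000 cycles/us -> 2.5 us
```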
```python
pid = tl.program_id(0)
partner = int((cur_rank + num_ranks // 2) % num_ranks)

old = ctx.atomic_cas(flag + pid, 0, 1, to_rank=partner, sem="release", scope="sys")
```
Copilot AI · Feb 3, 2026
Variable `old` is not used.
Suggested change:

```diff
-    old = ctx.atomic_cas(flag + pid, 0, 1, to_rank=partner, sem="release", scope="sys")
+    ctx.atomic_cas(flag + pid, 0, 1, to_rank=partner, sem="release", scope="sys")
```
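For intuition, the compare-and-swap semantics relied on here (the operation returns the old value, and the store happens only when the old value matches the expected one) can be modeled host-side. This lock-based `AtomicFlag` is a sketch of the semantics only, not how `ctx.atomic_cas` is implemented on the GPU:

```python
import threading

class AtomicFlag:
    """Host-side model of a single CAS-guarded flag word."""

    def __init__(self):
        self._value = 0
        self._lock = threading.Lock()

    def cas(self, expected, new):
        # Like atomic CAS: return the OLD value; write only if old == expected.
        with self._lock:
            old = self._value
            if old == expected:
                self._value = new
            return old

flag = AtomicFlag()
first = flag.cas(0, 1)   # succeeds: flag was 0, becomes 1, returns 0
second = flag.cas(0, 1)  # fails: flag is already 1, returns 1, value unchanged
```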
```python
import iris
from iris import DeviceContext
```
Copilot AI · Feb 3, 2026
Module 'iris' is imported with both 'import' and 'import from'.
Suggested change:

```diff
-import iris
 from iris import DeviceContext
```
@mawad-amd I've opened a new pull request, #351, to work on those changes. Once the pull request is ready, I'll request review from you.
…tern (#351) Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Motivation

Provide a cleaner, object-oriented API for device-side Iris operations by introducing `DeviceContext` to the main Iris module, following the Gluon pattern. This eliminates the need to explicitly pass `heap_bases` to every operation and provides a more ergonomic interface for kernel development.

Technical Details

Core API:
- `DeviceContext` aggregate class added to `iris/iris.py` with an `initialize()` static method
- Encapsulates `rank`, `world_size`, and `heap_bases` in a single device-side context object
- Methods for `load`, `store`, `get`, `put`, `copy`, and all atomics (`add`, `sub`, `cas`, `xchg`, `xor`, `and`, `or`, `min`, `max`)
- `Iris.get_device_context()` method that returns an encoded context tensor (format: `[rank, world_size, heap_base_0, ...]`)

Usage example:
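As a hedged illustration of the decoding that `DeviceContext.initialize()` would perform on the encoded tensor, here is a plain-Python mock. It is a sketch only: the real class is a Triton device-side aggregate, and everything beyond the `[rank, world_size, heap_base_0, ...]` layout stated above is an assumption:

```python
# Plain-Python mock of DeviceContext decoding, assuming the
# [rank, world_size, heap_base_0, ...] layout described in this PR.
# The real implementation is a Triton aggregate, not this host-side class.
class DeviceContextMock:
    def __init__(self, rank, world_size, heap_bases):
        self.rank = rank
        self.world_size = world_size
        self.heap_bases = heap_bases

    @staticmethod
    def initialize(context_tensor):
        # context_tensor layout: [rank, world_size, heap_base_0, ...]
        rank = context_tensor[0]
        world_size = context_tensor[1]
        return DeviceContextMock(rank, world_size, list(context_tensor[2:]))

ctx = DeviceContextMock.initialize([0, 2, 4096, 8192])
# ctx now carries rank 0, world size 2, and the two heap base pointers
```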
Infrastructure updates:
- Updated all `iris.x` collective operations (`all_reduce.py`, `all_gather.py`, `all_to_all.py`, `gather.py`, `reduce_scatter.py`) to import `DeviceContext` from the main module
- Removed the skeletal `DeviceContext` from `iris/x/core.py`
- Exported `DeviceContext` from `iris/__init__.py` for direct import

Test Plan
- Added `tests/unittests/test_device_context.py` covering:
  - `load()`, `store()`, `get()`, `put()` operations with multiple dtypes and block sizes
  - Atomic operations (`atomic_add`, `atomic_cas`)
- Added `examples/06_message_passing/message_passing_device_context.py` demonstrating a producer-consumer pattern using the `DeviceContext` API

Test Result
Submission Checklist