@mawad-amd (Collaborator)
Motivation

Provide a cleaner, object-oriented API for device-side Iris operations by introducing DeviceContext to the main Iris module, following the Gluon pattern. This eliminates the need to explicitly pass heap_bases to every operation and provides a more ergonomic interface for kernel development.

Technical Details

Core API:

  • Added DeviceContext aggregate class to iris/iris.py with initialize() static method
  • Encapsulates rank, world_size, and heap_bases in a single device-side context object
  • Provides OOP methods for all operations: load, store, get, put, copy, and all atomics (add, sub, cas, xchg, xor, and, or, min, max)
  • Added Iris.get_device_context() method that returns an encoded context tensor (format: [rank, world_size, heap_base_0, ...])

Usage example:

# Host-side
import iris

ctx = iris.iris()
context_tensor = ctx.get_device_context()

# Device-side (ptr and counter are placeholder tensor arguments)
import triton
import triton.language as tl

from iris import DeviceContext

@triton.jit
def kernel(context_tensor, ptr, counter, rank: tl.constexpr, world_size: tl.constexpr):
    ctx = DeviceContext.initialize(context_tensor, rank, world_size)
    data = ctx.load(ptr, from_rank=1)      # cleaner API
    ctx.atomic_add(counter, 1, to_rank=1)  # no heap_bases needed

Infrastructure updates:

  • Updated iris.x collective operations (all_reduce.py, all_gather.py, all_to_all.py, gather.py, reduce_scatter.py) to import DeviceContext from the main module
  • Removed skeletal DeviceContext from iris/x/core.py
  • Exported DeviceContext from iris/__init__.py for direct import

Test Plan

  • Added unit tests in tests/unittests/test_device_context.py covering:
    • load(), store(), get(), put() operations with multiple dtypes and block sizes
    • Atomic operations (atomic_add, atomic_cas)
    • Context initialization and import paths
  • Added example in examples/06_message_passing/message_passing_device_context.py demonstrating producer-consumer pattern using DeviceContext API
  • All tests follow existing codebase patterns with proper parametrization and cleanup
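The producer-consumer handshake that the new example exercises can be modeled on the host in plain Python. The sketch below is a hypothetical analogue, not the Iris device API: a shared list stands in for the symmetric heap, and a simple function stands in for the compare-and-swap that `ctx.atomic_cas` performs on a remote rank:

```python
# Plain-Python model of the producer-consumer flag handshake.
# The real example uses ctx.atomic_cas / ctx.load against a partner rank;
# here a shared list stands in for the symmetric heap.

def atomic_cas(mem, idx, expected, new):
    """Compare-and-swap on a list cell; returns the old value."""
    old = mem[idx]
    if old == expected:
        mem[idx] = new
    return old

flags = [0]   # one flag per "program id"
data = [None]

# Producer: publish data, then release the flag 0 -> 1.
data[0] = 42
assert atomic_cas(flags, 0, 0, 1) == 0

# Consumer: spin until the flag is observed as 1, then read data.
while flags[0] != 1:
    pass
assert data[0] == 42
```

In the real kernel the release/acquire ordering comes from the `sem` and `scope` arguments to `atomic_cas`; this model only captures the control flow.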

Test Result

  • New tests are passing

@github-actions bot added labels in-progress, iris on Feb 3, 2026
@mawad-amd marked this pull request as ready for review February 3, 2026 19:48
Copilot AI review requested due to automatic review settings February 3, 2026 19:48

Copilot AI left a comment:
Pull request overview

This pull request adds a DeviceContext class to the main Iris module to provide a cleaner, object-oriented API for device-side Iris operations. The PR follows the "Gluon pattern" by encapsulating rank, world_size, and heap_bases in a single context object that can be initialized on the device from an encoded tensor.

Changes:

  • Introduces DeviceContext aggregate class in iris/iris.py with methods for load, store, get, put, copy, and all atomic operations
  • Adds Iris.get_device_context() method to create an encoded context tensor containing rank, world size, and heap bases
  • Moves DeviceContext from iris/x/core.py to iris/iris.py and updates all iris.x collective operations to import from the new location

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 6 comments.

Summary per file:

  • iris/iris.py: Adds DeviceContext class with OOP API and get_device_context() method for creating encoded context tensors
  • iris/__init__.py: Exports DeviceContext from the main iris module
  • iris/x/core.py: Removes the skeletal DeviceContext definition and updates exports
  • iris/x/__init__.py: Updates documentation examples and removes DeviceContext from exports
  • iris/x/all_reduce.py, all_gather.py, all_to_all.py, gather.py, reduce_scatter.py: Update imports to use DeviceContext from iris.iris instead of iris.x.core
  • tests/unittests/test_device_context.py: Adds comprehensive unit tests for DeviceContext methods, including load, store, get, put, and atomic operations
  • examples/06_message_passing/message_passing_device_context.py: Adds example demonstrating a producer-consumer pattern using the DeviceContext API
  • examples/common/utils.py: Minor style fix removing redundant parentheses from a lambda expression
Comments suppressed due to low confidence (2)

iris/x/__init__.py:28

  • The example code shows calling collective operations as methods on DeviceContext (e.g., ctx.all_reduce(...), ctx.all_gather(...)), but DeviceContext doesn't have these methods. The collective operations are standalone functions that take a DeviceContext as a parameter. The correct usage is iris.x.all_reduce(tile, src_view, dst_view, ctx) not ctx.all_reduce(tile, src_view, dst_view).
    >>>     ctx = DeviceContext.initialize(context_tensor, rank, world_size)
    >>>
    >>>     # Call collectives on ctx directly (default algorithms)
    >>>     ctx.all_reduce(tile, src_view, dst_view)
    >>>     ctx.all_gather(tile, src_view, dst_view, dim=0)
    >>>     ctx.all_to_all(tile, src_view, dst_view, N_per_rank)
    >>>     ctx.reduce_scatter(tile, src_view, dst_view)

iris/x/__init__.py:45

  • The example code shows calling ctx.all_reduce(...) as a method on DeviceContext, but DeviceContext doesn't have an all_reduce method. The correct usage is to call the standalone function: iris.x.all_reduce_ring(tile, src_view, dst_view, ctx) or use one of the other all_reduce variants.
    >>>     ctx = DeviceContext.initialize(context_tensor, rank, world_size)
    >>>
    >>>     # Use ring algorithm
    >>>     config = iris.x.AllReduceConfig("ring")
    >>>     ctx.all_reduce(tile, src_view, dst_view, config=config)
    >>>
    >>>     # Use spinlock with locks
    >>>     config = iris.x.AllReduceConfig("spinlock", locks_ptr)
    >>>     tile_id = pid_m * num_tiles_n + pid_n
    >>>     ctx.all_reduce(tile, src_view, dst_view, config=config, tile_id=tile_id)

Comment on lines +1177 to +1182
# Convert heap_bases to a list for concatenation
heap_bases_list = self.heap_bases.tolist()

# Create context tensor: [cur_rank, world_size, heap_base_0, heap_base_1, ...]
context_data = [self.cur_rank, self.num_ranks] + heap_bases_list
context_tensor = torch.tensor(context_data, dtype=torch.int64, device=self.device)
Copilot AI commented on Feb 3, 2026:
Using tolist() to convert heap_bases from tensor to Python list causes a device-to-host transfer, which can be slow. Since the heap_bases are already on device, consider using tensor concatenation instead: context_tensor = torch.cat([torch.tensor([self.cur_rank, self.num_ranks], dtype=torch.int64, device=self.device), self.heap_bases]). This avoids the unnecessary round-trip through the host.

Suggested change:

-   # Convert heap_bases to a list for concatenation
-   heap_bases_list = self.heap_bases.tolist()
-
-   # Create context tensor: [cur_rank, world_size, heap_base_0, heap_base_1, ...]
-   context_data = [self.cur_rank, self.num_ranks] + heap_bases_list
-   context_tensor = torch.tensor(context_data, dtype=torch.int64, device=self.device)
+   # Create a small tensor for [cur_rank, world_size] on the device
+   rank_world = torch.tensor(
+       [self.cur_rank, self.num_ranks],
+       dtype=torch.int64,
+       device=self.device,
+   )
+
+   # Ensure heap_bases is int64 and concatenate on device:
+   # [cur_rank, world_size, heap_base_0, heap_base_1, ...]
+   heap_bases_int64 = self.heap_bases.to(dtype=torch.int64)
+   context_tensor = torch.cat([rank_world, heap_bases_int64])


    def to_json(self, filename, gpu_freq):
-       cycles_to_us = lambda cycles: (cycles / gpu_freq)
+       cycles_to_us = lambda cycles: cycles / gpu_freq
Copilot AI commented on Feb 3, 2026:

This change (removing redundant parentheses from a lambda expression) appears to be an unrelated style fix that should ideally be in a separate commit or PR. While it's a minor improvement, mixing unrelated changes makes it harder to track the purpose of each modification.

Suggested change:

-   cycles_to_us = lambda cycles: cycles / gpu_freq
+   def cycles_to_us(cycles):
+       return cycles / gpu_freq

pid = tl.program_id(0)
partner = int((cur_rank + num_ranks // 2) % num_ranks)

old = ctx.atomic_cas(flag + pid, 0, 1, to_rank=partner, sem="release", scope="sys")
Copilot AI commented on Feb 3, 2026:
Variable old is not used.

Suggested change:

-   old = ctx.atomic_cas(flag + pid, 0, 1, to_rank=partner, sem="release", scope="sys")
+   ctx.atomic_cas(flag + pid, 0, 1, to_rank=partner, sem="release", scope="sys")

Comment on lines +8 to +11
import iris
from iris import DeviceContext


Copilot AI commented on Feb 3, 2026:
Module 'iris' is imported with both 'import' and 'import from'.

Suggested change:

-   import iris
-   from iris import DeviceContext
+   from iris import DeviceContext

Copilot AI commented Feb 4, 2026:

@mawad-amd I've opened a new pull request, #351, to work on those changes. Once the pull request is ready, I'll request review from you.

Copilot AI and others added 2 commits February 3, 2026 21:46
…tern (#351)

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>