Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
from .context import (
ArrayContext,
ArrayContextFactory,
CSRMatrix,
SparseMatrix,
tag_axes,
)
from .impl.jax import EagerJAXArrayContext
Expand Down Expand Up @@ -129,6 +131,7 @@
"ArrayOrScalarT",
"ArrayT",
"BcastUntilActxArray",
"CSRMatrix",
"CommonSubexpressionTag",
"ContainerOrScalarT",
"EagerJAXArrayContext",
Expand All @@ -144,6 +147,7 @@
"ScalarLike",
"SerializationKey",
"SerializedContainer",
"SparseMatrix",
"dataclass_array_container",
"deserialize_container",
"flat_size_and_dtype",
Expand Down
214 changes: 213 additions & 1 deletion arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@

.. autoclass:: ArrayContext

.. autoclass:: SparseMatrix
.. autoclass:: CSRMatrix

.. autofunction:: tag_axes

.. class:: P
Expand Down Expand Up @@ -114,39 +117,47 @@
"""


import dataclasses
from abc import ABC, abstractmethod
from collections.abc import Callable, Hashable, Mapping
from typing import (
TYPE_CHECKING,
Any,
ParamSpec,
TypeAlias,
cast,
overload,
)
from warnings import warn

from typing_extensions import Self, TypeIs, override

from pytools import memoize_method
from pytools.tag import normalize_tags

# FIXME: remove sometime, this import was used in grudge in July 2025.
from .typing import ArrayOrArithContainerTc as ArrayOrArithContainerTc
from arraycontext.container.traversal import (
rec_map_container,
)


if TYPE_CHECKING:
import numpy as np
from numpy.typing import DTypeLike

import loopy
from pytools.tag import ToTagSetConvertible
from pytools.tag import Tag, ToTagSetConvertible

from .fake_numpy import BaseFakeNumpyNamespace
from .typing import (
Array,
ArrayContainerT,
ArrayOrArithContainerOrScalarT,
ArrayOrContainer,
ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT,
ArrayOrScalar,
ContainerOrScalarT,
NumpyOrContainerOrScalar,
ScalarLike,
Expand All @@ -155,6 +166,45 @@

P = ParamSpec("P")

_EMPTY_TAG_SET: frozenset[Tag] = frozenset()


@dataclasses.dataclass(frozen=True, eq=False, repr=False)
class SparseMatrix(ABC):
# FIXME: Type for shape?
shape: Any
tags: frozenset[Tag] = dataclasses.field(kw_only=True)
axes: tuple[ToTagSetConvertible, ...] = dataclasses.field(kw_only=True)
_actx: ArrayContext = dataclasses.field(kw_only=True)

@abstractmethod
def __matmul__(self, other: ArrayOrContainer) -> ArrayOrContainer:
...


@dataclasses.dataclass(frozen=True, eq=False, repr=False)
class CSRMatrix(SparseMatrix):
elem_values: Array
elem_col_indices: Array
row_starts: Array

@override
def __matmul__(self, other: ArrayOrContainer) -> ArrayOrContainer:
def _matmul(ary: ArrayOrScalar) -> ArrayOrScalar:
# FIXME: Should this do something to scalars? e.g., promote to uniform
# array?
assert self._actx.is_array_type(ary)
prg = self._actx._get_csr_matmul_prg(len(ary.shape))
out_ary = self._actx.call_loopy(
prg, elem_values=self.elem_values,
elem_col_indices=self.elem_col_indices,
row_starts=self.row_starts, array=ary)["out"]
# FIXME
# return self.tag(tagged, out_ary)
return out_ary

return cast("ArrayOrContainer", rec_map_container(_matmul, other))


# {{{ ArrayContext

Expand All @@ -172,6 +222,8 @@
.. automethod:: to_numpy
.. automethod:: call_loopy
.. automethod:: einsum
.. automethod:: make_csr_matrix
.. automethod:: sparse_matmul
.. attribute:: np

Provides access to a namespace that serves as a work-alike to
Expand Down Expand Up @@ -424,6 +476,166 @@
)["out"]
return self.tag(tagged, out_ary)

# FIXME: Not sure what type annotations to use for shape
def make_csr_matrix(
self,
shape,

Check warning on line 482 in arraycontext/context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type annotation is missing for parameter "shape" (reportMissingParameterType)

Check warning on line 482 in arraycontext/context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of parameter "shape" is unknown (reportUnknownParameterType)
elem_values: Array,
elem_col_indices: Array,
row_starts: Array,
*,
tags: ToTagSetConvertible = _EMPTY_TAG_SET,
axes: tuple[ToTagSetConvertible, ...] | None = None) -> CSRMatrix:
"""Return a sparse matrix in compressed sparse row (CSR) format, to be used
with :meth:`sparse_matmul`.

:arg shape: the (two-dimensional) shape of the matrix
:arg elem_values: a one-dimensional array containing the values of all of the
nonzero entries of the matrix, grouped by row.
:arg elem_col_indices: a one-dimensional array containing the column index
values corresponding to each entry in *elem_values*.
:arg row_starts: a one-dimensional array of length `nrows+1`, where each entry
gives the starting index in *elem_values* and *elem_col_indices* for the
given row, with the last entry being equal to `nrows`.
"""
tags = normalize_tags(tags)

if axes is None:
axes = (frozenset(), frozenset())

return CSRMatrix(
shape, elem_values, elem_col_indices, row_starts,
tags=tags, axes=axes,
_actx=self)

@memoize_method
def _get_csr_matmul_prg(self, out_ndim: int) -> loopy.TranslationUnit:
import numpy as np

import loopy as lp

out_extra_inames = tuple(f"i{n}" for n in range(1, out_ndim))
out_inames = ("irow", *out_extra_inames)
out_inames_set = frozenset(out_inames)

out_extra_shape_comp_names = tuple(f"n{n}" for n in range(1, out_ndim))
out_shape_comp_names = ("nrows", *out_extra_shape_comp_names)

domains: list[str] = []
domains.append(
"{ [" + ",".join(out_inames) + "] : "
+ " and ".join(
f"0 <= {iname} < {shape_comp_name}"
for iname, shape_comp_name in zip(
out_inames, out_shape_comp_names, strict=True))
+ " }")
domains.append(
"{ [iel] : iel_lbound <= iel < iel_ubound }")

temporary_variables: Mapping[str, lp.TemporaryVariable] = {
"iel_lbound": lp.TemporaryVariable(
"iel_lbound",
shape=(),
address_space=lp.AddressSpace.GLOBAL,
# FIXME: Need to do anything with tags?
),
"iel_ubound": lp.TemporaryVariable(
"iel_ubound",
shape=(),
address_space=lp.AddressSpace.GLOBAL,
# FIXME: Need to do anything with tags?
)}

from loopy.kernel.instruction import make_assignment
from pymbolic import var
# FIXME: Need tags for any of these?
instructions: list[lp.Assignment | lp.CallInstruction] = [
make_assignment(
(var("iel_lbound"),),
var("row_starts")[var("irow")],
id="insn0",
within_inames=out_inames_set),
make_assignment(
(var("iel_ubound"),),
var("row_starts")[var("irow") + 1],
id="insn1",
within_inames=out_inames_set),
make_assignment(
(var("out")[tuple(var(iname) for iname in out_inames)],),
lp.Reduction(
"sum",
(var("iel"),),
var("elem_values")[var("iel"),]
* var("array")[(
var("elem_col_indices")[var("iel"),],
*(var(iname) for iname in out_extra_inames))]),
id="insn2",
within_inames=out_inames_set,
depends_on=frozenset({"insn0", "insn1"}))]

from loopy.version import MOST_RECENT_LANGUAGE_VERSION

from .loopy import _DEFAULT_LOOPY_OPTIONS

knl = lp.make_kernel(
domains=domains,
instructions=instructions,
temporary_variables=temporary_variables,
kernel_data=[
lp.ValueArg("nrows", is_input=True),
lp.ValueArg("ncols", is_input=True),
lp.ValueArg("nels", is_input=True),
*(
lp.ValueArg(shape_comp_name, is_input=True)
for shape_comp_name in out_extra_shape_comp_names),
lp.GlobalArg("elem_values", shape=(var("nels"),), is_input=True),

Check warning on line 591 in arraycontext/context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "GlobalArg" is partially unknown   Type of "GlobalArg" is "(...) -> ArrayArg" (reportUnknownMemberType)
lp.GlobalArg("elem_col_indices", shape=(var("nels"),), is_input=True),

Check warning on line 592 in arraycontext/context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "GlobalArg" is partially unknown   Type of "GlobalArg" is "(...) -> ArrayArg" (reportUnknownMemberType)
lp.GlobalArg("row_starts", shape=lp.auto, is_input=True),

Check warning on line 593 in arraycontext/context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "GlobalArg" is partially unknown   Type of "GlobalArg" is "(...) -> ArrayArg" (reportUnknownMemberType)
lp.GlobalArg(

Check warning on line 594 in arraycontext/context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "GlobalArg" is partially unknown   Type of "GlobalArg" is "(...) -> ArrayArg" (reportUnknownMemberType)
"array",
shape=(
var("ncols"),
*(
var(shape_comp_name)
for shape_comp_name in out_extra_shape_comp_names),),
is_input=True),
lp.GlobalArg(

Check warning on line 602 in arraycontext/context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "GlobalArg" is partially unknown   Type of "GlobalArg" is "(...) -> ArrayArg" (reportUnknownMemberType)
"out",
shape=(
var("nrows"),
*(
var(shape_comp_name)
for shape_comp_name in out_extra_shape_comp_names),),
is_input=False),
...],
name="csr_matmul_kernel",
lang_version=MOST_RECENT_LANGUAGE_VERSION,
options=_DEFAULT_LOOPY_OPTIONS,
default_order=lp.auto,
default_offset=lp.auto,
# FIXME: Need to do anything with tags?
)

idx_dtype = knl.default_entrypoint.index_dtype

return lp.add_and_infer_dtypes(

Check warning on line 621 in arraycontext/context.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of "add_and_infer_dtypes" is partially unknown   Type of "add_and_infer_dtypes" is "(prog: Unknown, dtype_dict: Unknown, expect_completion: bool = False, kernel_name: Unknown | None = None) -> TranslationUnit" (reportUnknownMemberType)
knl,
{
",".join([
"ncols", "nrows", "nels",
*out_extra_shape_comp_names]): idx_dtype,
"elem_values,array,out": np.float64,
"elem_col_indices,row_starts": idx_dtype})

def sparse_matmul(
self, x1: SparseMatrix, x2: ArrayOrContainer) -> ArrayOrContainer:
"""Multiply a sparse matrix by an array.

:arg x1: the sparse matrix.
:arg x2: the array.
"""
return x1 @ x2

@abstractmethod
def clone(self) -> Self:
"""If possible, return a version of *self* that is semantically
Expand Down
27 changes: 25 additions & 2 deletions arraycontext/impl/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,16 @@
from typing import TYPE_CHECKING, cast

import numpy as np
from typing_extensions import override

from arraycontext.container.traversal import (
rec_map_container,
with_array_context,
)
from arraycontext.context import ArrayContext
from arraycontext.context import ArrayContext, CSRMatrix, SparseMatrix
from arraycontext.typing import (
Array,
ArrayOrContainer,
ArrayOrContainerOrScalar,
ArrayOrScalar,
ScalarLike,
Expand All @@ -51,7 +53,10 @@
if TYPE_CHECKING:
from collections.abc import Callable

from pytools.tag import ToTagSetConvertible
from pytools.tag import Tag, ToTagSetConvertible


_EMPTY_TAG_SET: frozenset[Tag] = frozenset()


class EagerJAXArrayContext(ArrayContext):
Expand Down Expand Up @@ -150,6 +155,24 @@
import jax.numpy as jnp
return jnp.einsum(spec, *args)

# FIXME: Not sure what type annotations to use for shape
@override
def make_csr_matrix(
self,
shape,

Check warning on line 162 in arraycontext/impl/jax/__init__.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type annotation is missing for parameter "shape" (reportMissingParameterType)

Check warning on line 162 in arraycontext/impl/jax/__init__.py

View workflow job for this annotation

GitHub Actions / basedpyright

Type of parameter "shape" is unknown (reportUnknownParameterType)
elem_values: Array,
elem_col_indices: Array,
row_starts: Array,
*,
tags: ToTagSetConvertible = _EMPTY_TAG_SET,
axes: tuple[ToTagSetConvertible, ...] | None = None) -> CSRMatrix:
raise NotImplementedError("Sparse matrices aren't yet supported with JAX.")

@override
def sparse_matmul(
self, x1: SparseMatrix, x2: ArrayOrContainer) -> ArrayOrContainer:
raise NotImplementedError("Sparse matrices aren't yet supported with JAX.")

def clone(self):
return type(self)()

Expand Down
Loading
Loading