diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py
index cc577f55..82028207 100644
--- a/arraycontext/__init__.py
+++ b/arraycontext/__init__.py
@@ -73,6 +73,8 @@
 from .context import (
     ArrayContext,
     ArrayContextFactory,
+    CSRMatrix,
+    SparseMatrix,
     tag_axes,
 )
 from .impl.jax import EagerJAXArrayContext
@@ -129,6 +131,7 @@
     "ArrayOrScalarT",
     "ArrayT",
     "BcastUntilActxArray",
+    "CSRMatrix",
     "CommonSubexpressionTag",
     "ContainerOrScalarT",
     "EagerJAXArrayContext",
@@ -144,6 +147,7 @@
     "ScalarLike",
     "SerializationKey",
     "SerializedContainer",
+    "SparseMatrix",
     "dataclass_array_container",
     "deserialize_container",
     "flat_size_and_dtype",
diff --git a/arraycontext/context.py b/arraycontext/context.py
index e0b2b907..57e45d75 100644
--- a/arraycontext/context.py
+++ b/arraycontext/context.py
@@ -78,6 +78,9 @@
 .. autoclass:: ArrayContext
 
+.. autoclass:: SparseMatrix
+.. autoclass:: CSRMatrix
+
 .. autofunction:: tag_axes
 
 .. class:: P
@@ -114,6 +117,7 @@
 """
 
+import dataclasses
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Hashable, Mapping
 from typing import (
@@ -121,6 +125,7 @@
     Any,
     ParamSpec,
     TypeAlias,
+    cast,
     overload,
 )
 from warnings import warn
@@ -128,9 +133,13 @@
 from typing_extensions import Self, TypeIs, override
 
 from pytools import memoize_method
+from pytools.tag import normalize_tags
 
 # FIXME: remove sometime, this import was used in grudge in July 2025.
 from .typing import ArrayOrArithContainerTc as ArrayOrArithContainerTc
+from arraycontext.container.traversal import (
+    rec_map_container,
+)
 
 
 if TYPE_CHECKING:
@@ -138,15 +147,17 @@
     from numpy.typing import DTypeLike
 
     import loopy
-    from pytools.tag import ToTagSetConvertible
+    from pytools.tag import Tag, ToTagSetConvertible
 
     from .fake_numpy import BaseFakeNumpyNamespace
     from .typing import (
         Array,
         ArrayContainerT,
         ArrayOrArithContainerOrScalarT,
+        ArrayOrContainer,
         ArrayOrContainerOrScalar,
         ArrayOrContainerOrScalarT,
+        ArrayOrScalar,
         ContainerOrScalarT,
         NumpyOrContainerOrScalar,
         ScalarLike,
@@ -155,6 +166,45 @@
 P = ParamSpec("P")
 
+_EMPTY_TAG_SET: frozenset[Tag] = frozenset()
+
+
+@dataclasses.dataclass(frozen=True, eq=False, repr=False)
+class SparseMatrix(ABC):
+    # FIXME: Type for shape?
+    shape: Any
+    tags: frozenset[Tag] = dataclasses.field(kw_only=True)
+    axes: tuple[ToTagSetConvertible, ...] = dataclasses.field(kw_only=True)
+    _actx: ArrayContext = dataclasses.field(kw_only=True)
+
+    @abstractmethod
+    def __matmul__(self, other: ArrayOrContainer) -> ArrayOrContainer:
+        ...
+
+
+@dataclasses.dataclass(frozen=True, eq=False, repr=False)
+class CSRMatrix(SparseMatrix):
+    elem_values: Array
+    elem_col_indices: Array
+    row_starts: Array
+
+    @override
+    def __matmul__(self, other: ArrayOrContainer) -> ArrayOrContainer:
+        def _matmul(ary: ArrayOrScalar) -> ArrayOrScalar:
+            # FIXME: Should this do something to scalars? e.g., promote to uniform
+            # array?
+            assert self._actx.is_array_type(ary)
+            prg = self._actx._get_csr_matmul_prg(len(ary.shape))
+            out_ary = self._actx.call_loopy(
+                prg, elem_values=self.elem_values,
+                elem_col_indices=self.elem_col_indices,
+                row_starts=self.row_starts, array=ary)["out"]
+            # FIXME
+            # return self.tag(tagged, out_ary)
+            return out_ary
+
+        return cast("ArrayOrContainer", rec_map_container(_matmul, other))
+
 
 # {{{ ArrayContext
 
@@ -172,6 +222,8 @@ class ArrayContext(ABC):
     .. automethod:: to_numpy
     .. automethod:: call_loopy
     .. automethod:: einsum
+    .. automethod:: make_csr_matrix
+    .. automethod:: sparse_matmul
 
     .. attribute:: np
 
         Provides access to a namespace that serves as a work-alike to
@@ -424,6 +476,166 @@ def einsum(self,
                 )["out"]
         return self.tag(tagged, out_ary)
 
+    # FIXME: Not sure what type annotations to use for shape
+    def make_csr_matrix(
+            self,
+            shape,
+            elem_values: Array,
+            elem_col_indices: Array,
+            row_starts: Array,
+            *,
+            tags: ToTagSetConvertible = _EMPTY_TAG_SET,
+            axes: tuple[ToTagSetConvertible, ...] | None = None) -> CSRMatrix:
+        """Return a sparse matrix in compressed sparse row (CSR) format, to be used
+        with :meth:`sparse_matmul`.
+
+        :arg shape: the (two-dimensional) shape of the matrix.
+        :arg elem_values: a one-dimensional array containing the values of all of
+            the nonzero entries of the matrix, grouped by row.
+        :arg elem_col_indices: a one-dimensional array containing the column index
+            of each entry in *elem_values*.
+        :arg row_starts: a one-dimensional array of length ``nrows + 1`` whose
+            *i*-th entry gives the index in *elem_values* and *elem_col_indices*
+            at which row *i* starts, with the last entry equal to the total number
+            of stored entries.
+        :arg tags: tags to attach to the resulting matrix.
+        :arg axes: a tuple of tag sets, one per axis of the matrix.
+        """
+        tags = normalize_tags(tags)
+
+        if axes is None:
+            axes = (frozenset(), frozenset())
+
+        return CSRMatrix(
+            shape, elem_values, elem_col_indices, row_starts,
+            tags=tags, axes=axes,
+            _actx=self)
+
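To make the CSR layout described above concrete, here is a small illustrative sketch. It is not part of the patch itself; it assumes `actx` is any array context that implements the `make_csr_matrix` method added here, and the matrix values are made up for the example.

    import numpy as np

    # Dense matrix used only for illustration:
    #     [[10,  0, 20],
    #      [ 0,  0,  0],
    #      [ 0, 30,  0]]
    # Nonzero values, grouped by row:
    elem_values = actx.from_numpy(np.array([10., 20., 30.]))
    # Column index of each stored value:
    elem_col_indices = actx.from_numpy(np.array([0, 2, 1], dtype=np.int32))
    # Row i occupies entries row_starts[i]:row_starts[i+1]; the last entry equals
    # the number of stored values (3 here), not the number of rows:
    row_starts = actx.from_numpy(np.array([0, 2, 2, 3], dtype=np.int32))

    mat = actx.make_csr_matrix(
        shape=(3, 3),
        elem_values=elem_values,
        elem_col_indices=elem_col_indices,
        row_starts=row_starts)

    x = actx.from_numpy(np.array([1., 2., 3.]))
    # Either spelling computes [10*1 + 20*3, 0, 30*2] = [70, 0, 60]:
    y = mat @ x
    y = actx.sparse_matmul(mat, x)
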
+    @memoize_method
+    def _get_csr_matmul_prg(self, out_ndim: int) -> loopy.TranslationUnit:
+        import numpy as np
+
+        import loopy as lp
+
+        out_extra_inames = tuple(f"i{n}" for n in range(1, out_ndim))
+        out_inames = ("irow", *out_extra_inames)
+        out_inames_set = frozenset(out_inames)
+
+        out_extra_shape_comp_names = tuple(f"n{n}" for n in range(1, out_ndim))
+        out_shape_comp_names = ("nrows", *out_extra_shape_comp_names)
+
+        domains: list[str] = []
+        domains.append(
+            "{ [" + ",".join(out_inames) + "] : "
+            + " and ".join(
+                f"0 <= {iname} < {shape_comp_name}"
+                for iname, shape_comp_name in zip(
+                    out_inames, out_shape_comp_names, strict=True))
+            + " }")
+        domains.append(
+            "{ [iel] : iel_lbound <= iel < iel_ubound }")
+
+        temporary_variables: Mapping[str, lp.TemporaryVariable] = {
+            "iel_lbound": lp.TemporaryVariable(
+                "iel_lbound",
+                shape=(),
+                address_space=lp.AddressSpace.GLOBAL,
+                # FIXME: Need to do anything with tags?
+                ),
+            "iel_ubound": lp.TemporaryVariable(
+                "iel_ubound",
+                shape=(),
+                address_space=lp.AddressSpace.GLOBAL,
+                # FIXME: Need to do anything with tags?
+                )}
+
+        from loopy.kernel.instruction import make_assignment
+        from pymbolic import var
+        # FIXME: Need tags for any of these?
+        instructions: list[lp.Assignment | lp.CallInstruction] = [
+            make_assignment(
+                (var("iel_lbound"),),
+                var("row_starts")[var("irow")],
+                id="insn0",
+                within_inames=out_inames_set),
+            make_assignment(
+                (var("iel_ubound"),),
+                var("row_starts")[var("irow") + 1],
+                id="insn1",
+                within_inames=out_inames_set),
+            make_assignment(
+                (var("out")[tuple(var(iname) for iname in out_inames)],),
+                lp.Reduction(
+                    "sum",
+                    (var("iel"),),
+                    var("elem_values")[var("iel"),]
+                    * var("array")[(
+                        var("elem_col_indices")[var("iel"),],
+                        *(var(iname) for iname in out_extra_inames))]),
+                id="insn2",
+                within_inames=out_inames_set,
+                depends_on=frozenset({"insn0", "insn1"}))]
+
+        from loopy.version import MOST_RECENT_LANGUAGE_VERSION
+
+        from .loopy import _DEFAULT_LOOPY_OPTIONS
+
+        knl = lp.make_kernel(
+            domains=domains,
+            instructions=instructions,
+            temporary_variables=temporary_variables,
+            kernel_data=[
+                lp.ValueArg("nrows", is_input=True),
+                lp.ValueArg("ncols", is_input=True),
+                lp.ValueArg("nels", is_input=True),
+                *(
+                    lp.ValueArg(shape_comp_name, is_input=True)
+                    for shape_comp_name in out_extra_shape_comp_names),
+                lp.GlobalArg("elem_values", shape=(var("nels"),), is_input=True),
+                lp.GlobalArg("elem_col_indices", shape=(var("nels"),), is_input=True),
+                lp.GlobalArg("row_starts", shape=lp.auto, is_input=True),
+                lp.GlobalArg(
+                    "array",
+                    shape=(
+                        var("ncols"),
+                        *(
+                            var(shape_comp_name)
+                            for shape_comp_name in out_extra_shape_comp_names),),
+                    is_input=True),
+                lp.GlobalArg(
+                    "out",
+                    shape=(
+                        var("nrows"),
+                        *(
+                            var(shape_comp_name)
+                            for shape_comp_name in out_extra_shape_comp_names),),
+                    is_input=False),
+                ...],
+            name="csr_matmul_kernel",
+            lang_version=MOST_RECENT_LANGUAGE_VERSION,
+            options=_DEFAULT_LOOPY_OPTIONS,
+            default_order=lp.auto,
+            default_offset=lp.auto,
+            # FIXME: Need to do anything with tags?
+            )
+
+        idx_dtype = knl.default_entrypoint.index_dtype
+
+        return lp.add_and_infer_dtypes(
+            knl,
+            {
+                ",".join([
+                    "ncols", "nrows", "nels",
+                    *out_extra_shape_comp_names]): idx_dtype,
+                "elem_values,array,out": np.float64,
+                "elem_col_indices,row_starts": idx_dtype})
+
+    def sparse_matmul(
+            self, x1: SparseMatrix, x2: ArrayOrContainer) -> ArrayOrContainer:
+        """Multiply a sparse matrix by an array or array container.
+
+        :arg x1: the sparse matrix.
+        :arg x2: the array or array container to multiply. Containers are
+            multiplied leaf array by leaf array.
+        """
+        return x1 @ x2
+
     @abstractmethod
     def clone(self) -> Self:
         """If possible, return a version of *self* that is semantically
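As an aid to reading the loopy kernel above (not part of the patch): a plain-numpy sketch of the same contraction, with names chosen to mirror the kernel's arguments. For each output row, the values stored for that row are multiplied against the matching rows of `array` and summed.

    import numpy as np

    def csr_matmul_reference(elem_values, elem_col_indices, row_starts, array):
        # out[irow, ...] = sum over iel in [row_starts[irow], row_starts[irow+1])
        #                  of elem_values[iel] * array[elem_col_indices[iel], ...]
        nrows = len(row_starts) - 1
        out = np.zeros((nrows, *array.shape[1:]), dtype=array.dtype)
        for irow in range(nrows):
            iel_lbound = row_starts[irow]
            iel_ubound = row_starts[irow + 1]
            for iel in range(iel_lbound, iel_ubound):
                out[irow] += elem_values[iel] * array[elem_col_indices[iel]]
        return out
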
+ """ + return x1 @ x2 + @abstractmethod def clone(self) -> Self: """If possible, return a version of *self* that is semantically diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index 73f29cf5..1b5fb8e6 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -33,14 +33,16 @@ from typing import TYPE_CHECKING, cast import numpy as np +from typing_extensions import override from arraycontext.container.traversal import ( rec_map_container, with_array_context, ) -from arraycontext.context import ArrayContext +from arraycontext.context import ArrayContext, CSRMatrix, SparseMatrix from arraycontext.typing import ( Array, + ArrayOrContainer, ArrayOrContainerOrScalar, ArrayOrScalar, ScalarLike, @@ -51,7 +53,10 @@ if TYPE_CHECKING: from collections.abc import Callable - from pytools.tag import ToTagSetConvertible + from pytools.tag import Tag, ToTagSetConvertible + + +_EMPTY_TAG_SET: frozenset[Tag] = frozenset() class EagerJAXArrayContext(ArrayContext): @@ -150,6 +155,24 @@ def einsum(self, spec, *args, arg_names=None, tagged=()): import jax.numpy as jnp return jnp.einsum(spec, *args) + # FIXME: Not sure what type annotations to use for shape + @override + def make_csr_matrix( + self, + shape, + elem_values: Array, + elem_col_indices: Array, + row_starts: Array, + *, + tags: ToTagSetConvertible = _EMPTY_TAG_SET, + axes: tuple[ToTagSetConvertible, ...] | None = None) -> CSRMatrix: + raise NotImplementedError("Sparse matrices aren't yet supported with JAX.") + + @override + def sparse_matmul( + self, x1: SparseMatrix, x2: ArrayOrContainer) -> ArrayOrContainer: + raise NotImplementedError("Sparse matrices aren't yet supported with JAX.") + def clone(self): return type(self)() diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index 265d7a53..44e33cdf 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -33,12 +33,15 @@ THE SOFTWARE. """ +from dataclasses import dataclass +from functools import cached_property from typing import TYPE_CHECKING, Any, cast, overload import numpy as np from typing_extensions import override import loopy as lp +from pytools.tag import normalize_tags from arraycontext.container.traversal import ( rec_map_array_container as rec_map_array_container, @@ -47,10 +50,12 @@ ) from arraycontext.context import ( ArrayContext, + CSRMatrix as _BaseCSRMatrix, UntransformedCodeWarning, ) from arraycontext.typing import ( Array, + ArrayOrContainer, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, ContainerOrScalarT, @@ -60,12 +65,17 @@ if TYPE_CHECKING: + import scipy.sparse + from pymbolic import Scalar - from pytools.tag import ToTagSetConvertible + from pytools.tag import Tag, ToTagSetConvertible from arraycontext.typing import ArrayContainerT +_EMPTY_TAG_SET: frozenset[Tag] = frozenset() + + class NumpyNonObjectArrayMetaclass(type): @override def __instancecheck__(cls, instance: object) -> bool: @@ -76,6 +86,29 @@ class NumpyNonObjectArray(metaclass=NumpyNonObjectArrayMetaclass): pass +@dataclass(frozen=True, eq=False, repr=False) +class CSRMatrix(_BaseCSRMatrix): + @cached_property + def _np_matrix(self) -> scipy.sparse.csr_matrix: + assert isinstance(self.elem_values, np.ndarray) + assert isinstance(self.elem_col_indices, np.ndarray) + assert isinstance(self.row_starts, np.ndarray) + # FIXME: Not sure if the scipy dependency is OK or if it should just use the + # call_loopy fallback? 
+        # call_loopy fallback? Currently getting errors with the loopy version:
+        # loopy.diagnostic.LoopyError: One of the kernels in the program has
+        # been preprocessed, cannot modify target now.
+        from scipy.sparse import csr_matrix
+        return csr_matrix(
+            (self.elem_values, self.elem_col_indices, self.row_starts),
+            shape=self.shape)
+
+    @override
+    def __matmul__(self, other: ArrayOrContainer) -> ArrayOrContainer:
+        return cast(
+            "ArrayOrContainer",
+            rec_map_container(lambda ary: self._np_matrix @ ary, other))
+
+
 class NumpyArrayContext(ArrayContext):
     """
     A :class:`ArrayContext` that uses :class:`numpy.ndarray` to represent arrays.
@@ -199,6 +232,25 @@ def tag_axis(self,
     def einsum(self, spec, *args, arg_names=None, tagged=()):
         return np.einsum(spec, *args, optimize="optimal")
 
+    # FIXME: Not sure what type annotations to use for shape
+    @override
+    def make_csr_matrix(
+            self,
+            shape,
+            elem_values: Array,
+            elem_col_indices: Array,
+            row_starts: Array,
+            *,
+            tags: ToTagSetConvertible = _EMPTY_TAG_SET,
+            axes: tuple[ToTagSetConvertible, ...] | None = None) -> CSRMatrix:
+        tags = normalize_tags(tags)
+        if axes is None:
+            axes = (frozenset(), frozenset())
+        return CSRMatrix(
+            shape, elem_values, elem_col_indices, row_starts,
+            tags=tags, axes=axes,
+            _actx=self)
+
     @property
     def permits_inplace_modification(self):
         return True
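For reference (not part of the patch): the `(elem_values, elem_col_indices, row_starts)` triple stored on `CSRMatrix` is exactly the `(data, indices, indptr)` triple expected by the scipy constructor used in `_np_matrix` above. A standalone sketch with made-up values:

    import numpy as np
    import scipy.sparse as sp

    # Same illustrative 3x3 matrix as in the earlier sketch: 10 and 20 in row 0,
    # 30 in row 2.
    elem_values = np.array([10., 20., 30.])
    elem_col_indices = np.array([0, 2, 1], dtype=np.int32)
    row_starts = np.array([0, 2, 2, 3], dtype=np.int32)

    # scipy names the same three arrays (data, indices, indptr).
    mat = sp.csr_matrix(
        (elem_values, elem_col_indices, row_starts), shape=(3, 3))
    assert np.allclose(mat @ np.array([1., 2., 3.]), [70., 0., 60.])
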
diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py
index 6e883f55..8eb8baf1 100644
--- a/arraycontext/impl/pytato/__init__.py
+++ b/arraycontext/impl/pytato/__init__.py
@@ -54,6 +54,7 @@
 import abc
 import sys
 from dataclasses import dataclass
+from functools import cached_property
 from typing import TYPE_CHECKING, Any, cast
 
 import numpy as np
@@ -68,13 +69,16 @@
 )
 from arraycontext.context import (
     ArrayContext,
+    CSRMatrix as _BaseCSRMatrix,
     P,
+    SparseMatrix,
     UntransformedCodeWarning,
 )
 from arraycontext.metadata import NameHint
 from arraycontext.typing import (
     Array,
     ArrayOrArithContainerOrScalarT,
+    ArrayOrContainer,
     ArrayOrContainerOrScalarT,
     ArrayOrScalar,
     ScalarLike,
@@ -140,6 +144,30 @@ class _NotOnlyDataWrappers(Exception):  # noqa: N818
     pass
 
 
+@dataclass(frozen=True, eq=False, repr=False)
+class CSRMatrix(_BaseCSRMatrix):
+    @cached_property
+    def _pt_matrix(self) -> pt.CSRMatrix:
+        import pytato as pt
+        assert isinstance(self.elem_values, pt.Array)
+        assert isinstance(self.elem_col_indices, pt.Array)
+        assert isinstance(self.row_starts, pt.Array)
+        return pt.make_csr_matrix(
+            self.shape, self.elem_values, self.elem_col_indices, self.row_starts,
+            tags=self.tags, axes=self.axes)
+
+    @override
+    def __matmul__(self, other: ArrayOrContainer) -> ArrayOrContainer:
+        def _matmul(ary: ArrayOrScalar) -> ArrayOrScalar:
+            # FIXME: Should this do something to scalars? e.g., promote to uniform
+            # array?
+            import pytato as pt
+            assert isinstance(ary, pt.Array)
+            return self._pt_matrix @ ary
+
+        return cast("ArrayOrContainer", rec_map_container(_matmul, other))
+
+
 # {{{ _BasePytatoArrayContext
 
 class _BasePytatoArrayContext(ArrayContext, abc.ABC):
@@ -833,6 +861,7 @@ def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
             dag = pt.transform.materialize_with_mpms(dag)
         return dag
 
+    @override
     def einsum(self, spec, *args, arg_names=None, tagged=()):
         import pytato as pt
 
@@ -876,6 +905,25 @@ def preprocess_arg(name, arg):
             for name, arg in zip(arg_names, args, strict=True)
             ]).tagged(_preprocess_array_tags(tagged))
 
+    # FIXME: Not sure what type annotations to use for shape
+    @override
+    def make_csr_matrix(
+            self,
+            shape,
+            elem_values: Array,
+            elem_col_indices: Array,
+            row_starts: Array,
+            *,
+            tags: ToTagSetConvertible = _EMPTY_TAG_SET,
+            axes: tuple[ToTagSetConvertible, ...] | None = None) -> CSRMatrix:
+        if axes is None:
+            axes = (frozenset(), frozenset())
+        return CSRMatrix(
+            shape, elem_values, elem_col_indices, row_starts,
+            # FIXME: Do I need to call _preprocess_array_tags on axes?
+            tags=tags, axes=axes,
+            _actx=self)
+
     def clone(self):
         return type(self)(self.queue, self.allocator)
 
@@ -1115,6 +1163,24 @@ def preprocess_arg(name: str | None, arg: Array):
             for name, arg in zip(arg_names, args, strict=True)
             ]).tagged(_preprocess_array_tags(tagged)))
 
+    # FIXME: Not sure what type annotations to use for shape
+    @override
+    def make_csr_matrix(
+            self,
+            shape,
+            elem_values: Array,
+            elem_col_indices: Array,
+            row_starts: Array,
+            *,
+            tags: ToTagSetConvertible = _EMPTY_TAG_SET,
+            axes: tuple[ToTagSetConvertible, ...] | None = None) -> CSRMatrix:
+        raise NotImplementedError("Sparse matrices aren't yet supported with JAX.")
+
+    @override
+    def sparse_matmul(
+            self, x1: SparseMatrix, x2: ArrayOrContainer) -> ArrayOrContainer:
+        raise NotImplementedError("Sparse matrices aren't yet supported with JAX.")
+
     @override
     def clone(self):
         return type(self)()
diff --git a/requirements.txt b/requirements.txt
index a4cb4025..54a2a5ec 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,4 +6,4 @@
 git+https://github.com/inducer/pyopencl.git#egg=pyopencl
 git+https://github.com/inducer/islpy.git#egg=islpy
 git+https://github.com/inducer/loopy.git#egg=loopy
-git+https://github.com/inducer/pytato.git#egg=pytato
+git+https://github.com/majosm/pytato.git@sparse-matrix#egg=pytato
diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py
index 6d7a38a4..bae0f12e 100644
--- a/test/test_arraycontext.py
+++ b/test/test_arraycontext.py
@@ -37,10 +37,12 @@
 from arraycontext import (
     ArrayContextFactory,
+    ArrayOrScalar,
     BcastUntilActxArray,
     EagerJAXArrayContext,
     NumpyArrayContext,
     PyOpenCLArrayContext,
+    PytatoJAXArrayContext,
     PytatoPyOpenCLArrayContext,
     dataclass_array_container,
     pytest_generate_tests_for_array_contexts,
@@ -656,6 +658,71 @@ def test_array_context_einsum_array_tripleprod(actx_factory: ArrayContextFactory
 # }}}
 
 
+def test_array_context_csr_matmul(actx_factory: ArrayContextFactory):
+    actx = actx_factory()
+
+    if isinstance(actx, (EagerJAXArrayContext, PytatoJAXArrayContext)):
+        pytest.skip(f"not implemented for '{type(actx).__name__}'")
+
+    n = 100
+
+    x = actx.from_numpy(np.arange(n, dtype=np.float64))
+    ary_of_x = obj_array.new_1d([x] * 3)
+    dc_of_x = MyContainer(
+        name="container",
+        mass=x,
+        momentum=obj_array.new_1d([x] * 3),
+        enthalpy=x)
+
+    elem_values = actx.zeros((n//2,), dtype=np.float64) + 1.
+    elem_col_indices = actx.from_numpy(2*np.arange(n//2, dtype=np.int32))
+    row_starts = actx.from_numpy(np.arange(n//2 + 1, dtype=np.int32))
+
+    mat = actx.make_csr_matrix(
+        shape=(n//2, n),
+        elem_values=elem_values,
+        elem_col_indices=elem_col_indices,
+        row_starts=row_starts)
+
+    expected_mat_x = actx.from_numpy(2 * np.arange(n//2, dtype=np.float64))
+
+    def _check_allclose(
+            arg1: ArrayOrScalar, arg2: ArrayOrScalar, atol: float = 1.0e-14):
+        from arraycontext import NotAnArrayContainerError
+        try:
+            arg1_iterable = serialize_container(arg1)
+            arg2_iterable = serialize_container(arg2)
+        except NotAnArrayContainerError:
+            assert np.linalg.norm(actx.to_numpy(arg1 - arg2)) < atol
+        else:
+            arg1_subarrays = [
+                subarray for _, subarray in arg1_iterable]
+            arg2_subarrays = [
+                subarray for _, subarray in arg2_iterable]
+            for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays,
+                    strict=True):
+                _check_allclose(subarray1, subarray2)
+
+    # single array
+    res = mat @ x
+    expected_res = expected_mat_x
+    _check_allclose(res, expected_res)
+
+    # array of arrays
+    res = mat @ ary_of_x
+    expected_res = obj_array.new_1d([expected_mat_x] * 3)
+    _check_allclose(res, expected_res)
+
+    # container of arrays
+    res = mat @ dc_of_x
+    expected_res = MyContainer(
+        name="container",
+        mass=expected_mat_x,
+        momentum=obj_array.new_1d([expected_mat_x] * 3),
+        enthalpy=expected_mat_x)
+    _check_allclose(res, expected_res)
+
+
 # {{{ array container classes for test
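A closing note on the test above (illustrative only, not part of the test file): the matrix it builds stores a single 1 in row i at column 2*i, which is where `expected_mat_x = 2 * np.arange(n//2)` comes from. The same construction can be cross-checked against scipy:

    import numpy as np
    import scipy.sparse as sp

    n = 100
    mat_np = sp.csr_matrix(
        (np.ones(n//2),                          # elem_values
         2*np.arange(n//2, dtype=np.int32),      # elem_col_indices
         np.arange(n//2 + 1, dtype=np.int32)),   # row_starts
        shape=(n//2, n))
    # Row i picks out entry 2*i of the input, so the result is [0, 2, 4, ...].
    assert np.array_equal(
        mat_np @ np.arange(n, dtype=np.float64),
        2*np.arange(n//2, dtype=np.float64))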