Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@
# SPDX-License-Identifier: Apache-2.0

from cpython.mem cimport PyMem_Malloc, PyMem_Free
from libc.stdint cimport (intptr_t,
from libc.stdint cimport (intptr_t, uintptr_t,
int8_t, int16_t, int32_t, int64_t,
uint8_t, uint16_t, uint32_t, uint64_t,)
from libcpp cimport bool as cpp_bool
from libcpp.complex cimport complex as cpp_complex
from libcpp cimport nullptr
from libcpp cimport vector

from cuda.bindings cimport cydriver
from cuda.core.experimental._memoryview cimport _MDSPAN

import ctypes

import numpy

from cuda.core.experimental._memory import Buffer
from cuda.core.experimental._utils.cuda_utils import driver
from cuda.bindings cimport cydriver


ctypedef cpp_complex.complex[float] cpp_single_complex
Expand Down Expand Up @@ -296,6 +298,12 @@ cdef class ParamHolder:
elif arg_type is complex:
prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
continue
elif arg_type is _MDSPAN:
# The mdspan struct is allocated on the host and owned by the CuPy mdspan object.
# We pass a pointer to the struct so the driver can copy it by value to the kernel.
# Access _ptr at C level to avoid creating a temporary Python object.
self.data_addresses[i] = <void*>((<_MDSPAN>arg)._ptr)
continue

not_prepared = prepare_numpy_arg(self.data, self.data_addresses, arg, i)
if not_prepared:
Expand Down
10 changes: 10 additions & 0 deletions cuda_core/cuda/core/experimental/_memoryview.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from libc.stdint cimport uintptr_t


cdef class _MDSPAN:
    # Address of the mdspan struct; it must point to a host-side mdspan object.
    cdef readonly uintptr_t _ptr
    # When the host mdspan was exported from a Python object, that object is
    # referenced here so it stays alive as long as this handle does.
    cdef readonly object _exporting_obj
41 changes: 40 additions & 1 deletion cuda_core/cuda/core/experimental/_memoryview.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
#
# SPDX-License-Identifier: Apache-2.0

from libc.stdint cimport uintptr_t

from ._dlpack cimport *
from cuda.core.experimental._utils cimport cuda_utils

import functools
import warnings
Expand All @@ -11,12 +14,26 @@ from typing import Optional
import numpy

from cuda.core.experimental._utils.cuda_utils import handle_return, driver
from cuda.core.experimental._utils cimport cuda_utils


# TODO(leofang): support NumPy structured dtypes


cdef class _MDSPAN:
    """Thin handle around the address of a host-side mdspan struct.

    The handle stores the raw address in ``_ptr`` and, optionally, a
    reference to the Python object that exported the struct in
    ``_exporting_obj`` so that the exporter outlives the handle.

    Parameters
    ----------
    ptr : uintptr_t
        Address of the host mdspan struct.
    obj : object, optional
        Object that owns the struct; a reference to it is kept. Defaults to
        ``None``.
    """

    def __cinit__(self):
        # Guarantee a well-defined (null) address before __init__ runs.
        self._ptr = 0

    def __init__(self, uintptr_t ptr, object obj=None):
        self._ptr = ptr
        self._exporting_obj = obj

    # NOTE: there is intentionally no __dealloc__. Cython drops the reference
    # held by the ``object`` attribute on its own when the instance is
    # destroyed, and Python-level attributes should not be touched from
    # __dealloc__ because the instance may already be partially torn down.
    # Resetting the integer ``_ptr`` of a dying instance has no effect either.



cdef class StridedMemoryView:
"""A dataclass holding metadata of a strided dense array/tensor.

Expand Down Expand Up @@ -98,6 +115,7 @@ cdef class StridedMemoryView:
# this flag helps prevent unnecessary recomputation of _strides
bint _strides_init
object _dtype
_MDSPAN _mdspan

def __init__(self, obj: object = None, stream_ptr: int | None = None) -> None:
cdef str clsname = self.__class__.__name__
Expand Down Expand Up @@ -224,6 +242,27 @@ cdef class StridedMemoryView:
self._dtype = numpy.dtype(self.metadata["typestr"])
return self._dtype

@property
def as_mdspan(self) -> _MDSPAN:
    """A C++ mdspan view of the tensor.

    The view is built on first access and cached on the instance; later
    accesses return the cached object. Only exporting objects whose class
    lives in the ``cupy`` top-level package are handled.

    Returns
    -------
    mdspan : _MDSPAN

    Raises
    ------
    NotImplementedError
        If the exporting object does not come from ``cupy``.
    """
    # Fast path: the handle was already created by an earlier access.
    if self._mdspan is not None:
        return self._mdspan

    exporter = self.exporting_obj
    # Top-level package name of the exporter's class, e.g. "cupy".
    module = exporter.__class__.__module__.split(".")[0]
    if module != "cupy":
        raise NotImplementedError(
            f"as_mdspan is not implemented for objects from module '{module}'"
        )

    host_mdspan = exporter.mdspan
    # Hand the mdspan object to the handle as well, so the struct behind
    # ``ptr`` stays alive for as long as the handle is cached.
    self._mdspan = _MDSPAN(<uintptr_t>(host_mdspan.ptr), host_mdspan)
    return self._mdspan

def __repr__(self):
return (f"StridedMemoryView(ptr={self.ptr},\n"
+ f" shape={self.shape},\n"
Expand Down
Loading