Skip to content

Commit cc0a602

Browse files
committed
Accept start point
1 parent 903a743 commit cc0a602

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

src/shmem4py/shmem.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -872,9 +872,30 @@ def _ceildiv(p: int, q: int) -> int:
872872
# ---
873873

874874

875+
def _parse_stride(st):
876+
if isinstance(st, int):
877+
stride = st
878+
start = [0]
879+
elif isinstance(st, tuple):
880+
stride, start = st[1], list(st[0])
881+
882+
return stride, start
883+
875884
def _parse_rma(target, source, size=None, tst=1, sst=1):
876-
tdata, tlen, ttype = _getbuffer(target, readonly=False)
877-
sdata, slen, stype = _getbuffer(source, readonly=True)
885+
if isinstance(tst, tuple): assert target.ndim == len(tst[0])
886+
if isinstance(sst, tuple): assert source.ndim == len(sst[0])
887+
tst, tstart = _parse_stride(tst)
888+
sst, sstart = _parse_stride(sst)
889+
890+
if tstart != [0]:
891+
tdata, tlen, ttype = _getbuffer(target[*tstart[:-1],tstart[-1]:], readonly=False)
892+
else:
893+
tdata, tlen, ttype = _getbuffer(target, readonly=False)
894+
895+
if sstart != [0]:
896+
sdata, slen, stype = _getbuffer(source[*sstart[:-1],sstart[-1]:], readonly=True)
897+
else:
898+
sdata, slen, stype = _getbuffer(source, readonly=True)
878899

879900
assert ttype == stype
880901
ctype = ttype
@@ -884,8 +905,7 @@ def _parse_rma(target, source, size=None, tst=1, sst=1):
884905
if size is None:
885906
size = min(tsize, ssize)
886907
else:
887-
assert size <= tsize
888-
assert size <= ssize
908+
assert size >= 0
889909

890910
return (ctype, tdata, sdata, size)
891911

@@ -901,6 +921,8 @@ def _shmem_rma(ctx, name, target, source, size, pe):
901921

902922
def _shmem_irma(ctx, name, target, source, tst, sst, size, pe):
903923
ctype, target, source, size = _parse_rma(target, source, size, tst, sst)
924+
tst, _ = _parse_stride(tst)
925+
sst, _ = _parse_stride(sst)
904926
return _shmem(ctx, ctype, f'i{name}')(target, source, tst, sst, size, pe)
905927

906928

0 commit comments

Comments
 (0)