
Commit 95a915b

support slicing/indexing of shared/global tensors
Signed-off-by: Yaoyao Ding <[email protected]>
Parent commit: 10e41b1

16 files changed: +256 additions, -20 deletions


python/tilus/backends/codegen.py

Lines changed: 2 additions & 2 deletions
@@ -472,8 +472,8 @@ def visit_Function(self, func: Function) -> IRModule:
         if self.smem_workspace:
             self.free_shared_value(self.smem_workspace)
             self.smem_workspace = None
-        if self.smem_allocator.allocated != 0:
-            raise ValueError("Shared memory is not properly allocated/freed")
+        # if self.smem_allocator.allocated != 0:
+        #     raise ValueError("Shared memory is not properly allocated/freed")
         if self.smem_allocator.maximum_allocated > get_current_target().properties.shared_memory_per_block:
             raise CodeGenerationFailed(
                 "Request shared memory {} bytes, but the device only allows {} bytes.".format(

python/tilus/backends/emitters/gmem.py

Lines changed: 22 additions & 1 deletion
@@ -15,7 +15,8 @@
 from hidet.ir.expr import Expr

 from tilus.backends.codegen import BaseInstEmitter, register_emitter
-from tilus.ir.instructions import AllocateGlobalInst, GlobalViewInst
+from tilus.ir import GlobalTensor
+from tilus.ir.instructions import AllocateGlobalInst, GlobalIndexInst, GlobalSliceInst, GlobalViewInst
 from tilus.utils import cdiv


@@ -34,3 +35,23 @@ def emit(self, inst: AllocateGlobalInst) -> None:
         )
         var = self.get_or_allocate_var(tensor)
         self.assign(var, ptr)
+
+
+@register_emitter(GlobalIndexInst)
+class GlobalIndexInstEmitter(BaseInstEmitter):
+    def emit(self, inst: GlobalIndexInst) -> None:
+        dst = inst.dst
+        tensor = inst.inputs[0].as_global_tensor()
+        var = self.get_or_allocate_var(tensor)
+        offset = tensor.layout(*inst.indices)
+        self.assign(dst, value=var[offset])
+
+
+@register_emitter(GlobalSliceInst)
+class GlobalSliceInstEmitter(BaseInstEmitter):
+    def emit(self, inst: GlobalSliceInst) -> None:
+        input_tensor: GlobalTensor = inst.global_input
+        output_tensor: GlobalTensor = inst.global_output
+        slice_offset = input_tensor.layout(*inst.offsets)
+        output_var = self.get_or_allocate_var(output_tensor)
+        self.assign(output_var, ~self.tensor2var[input_tensor][slice_offset])
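
In plain terms, GlobalIndexInst reads one element at layout(*indices), while GlobalSliceInst rebinds the output tensor's variable to the address of input[layout(*offsets)]. A minimal, self-contained sketch of that semantics in plain Python (a list stands in for global memory and a lambda for the layout; none of the names below are the tilus API):

    # Illustrative only: model a 4x8 row-major global tensor as a flat Python list.
    buffer = list(range(32))
    layout = lambda i, j: i * 8 + j   # offset function, analogous to tensor.layout(*indices)

    # GlobalIndexInst: dst receives the element at layout(*indices).
    dst = buffer[layout(2, 3)]        # element (2, 3) -> 19

    # GlobalSliceInst: the output variable points at base + layout(*offsets);
    # here the "pointer" is modeled as the starting offset of the slice.
    slice_start = layout(1, 0)        # slice selecting row 1
    row1 = buffer[slice_start:slice_start + 8]

    assert dst == 19 and row1 == list(range(8, 16))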

python/tilus/backends/emitters/smem.py

Lines changed: 11 additions & 1 deletion
@@ -18,7 +18,7 @@
 from hidet.ir.type import tensor_pointer_type

 from tilus.backends.codegen import BaseInstEmitter, register_emitter
-from tilus.ir.instructions import AllocateSharedInst, FreeSharedInst, SharedSliceInst
+from tilus.ir.instructions import AllocateSharedInst, FreeSharedInst, SharedIndexInst, SharedSliceInst
 from tilus.ir.tensor import SharedTensor


@@ -62,3 +62,13 @@ def emit(self, inst: SharedSliceInst) -> None:
             tp=int32,
             init=self.shared_tensor_shared_space_addr[shared_input] + slice_offset * shared_input.dtype.nbytes,
         )
+
+
+@register_emitter(SharedIndexInst)
+class SharedIndexInstEmitter(BaseInstEmitter):
+    def emit(self, inst: SharedIndexInst) -> None:
+        dst = inst.dst
+        tensor = inst.shared_input
+        var = self.get_or_allocate_var(tensor)
+        offset = tensor.layout(*inst.indices)
+        self.assign(dst, value=var[offset])

python/tilus/ir/builders/stmt_builder.py

Lines changed: 45 additions & 3 deletions
@@ -51,6 +51,8 @@
     ExitInst,
     FormatPrintInst,
     FreeSharedInst,
+    GlobalIndexInst,
+    GlobalSliceInst,
     GlobalViewInst,
     LoadGlobalGenericInst,
     LoadGlobalInst,
@@ -62,6 +64,7 @@
     ReduceInst,
     RepeatInst,
     RepeatInterleaveInst,
+    SharedIndexInst,
     SharedSliceInst,
     SqueezeInst,
     StoreGlobalGenericInst,
@@ -285,8 +288,10 @@ def brk(self):
         stmt = BreakStmt()
         self._stack[-1].append(stmt)

-    def declare(self, type: BaseType, init: Optional[Expr | float | int] = None) -> Var:
-        var = Var("v", type=type)
+    def declare(self, type: BaseType, init: Optional[Expr | float | int] = None, hint: Optional[str] = None) -> Var:
+        if hint is None:
+            hint = "v"
+        var = Var(hint, type=type)
         self.append(DeclareStmt(var, as_expr(init) if init is not None else None))
         return var

@@ -364,6 +369,33 @@ def allocate_global(
         self.append(inst)
         return inst.global_output

+    def slice_global(
+        self,
+        tensor: GlobalTensor,
+        offsets: Sequence[Expr | int],
+        slice_dims: Sequence[int],
+        slice_shape: Sequence[Expr | int],
+    ) -> GlobalTensor:
+        offsets_ = [as_expr(offset) for offset in offsets]
+        inst = GlobalSliceInst.create(
+            tensor=tensor,
+            offsets=offsets_,
+            dims=slice_dims,
+            shape=slice_shape,
+        )
+        self.append(inst)
+        return inst.global_output
+
+    def index_global(
+        self,
+        dst: Var,
+        tensor: GlobalTensor,
+        indices: Sequence[Expr | int],
+    ) -> None:
+        indices_ = [as_expr(index) for index in indices]
+        inst = GlobalIndexInst.create(dst=dst, tensor=tensor, indices=indices_)
+        self.append(inst)
+
     def assign_register(self, output: RegisterTensor, x: RegisterTensor) -> None:
         inst = AssignInst.create(output, x)
         self.append(inst)
@@ -722,7 +754,7 @@ def free_shared(self, shared_value: SharedTensor) -> None:
         inst = FreeSharedInst.create(shared_value)
         self.append(inst)

-    def shared_slice(
+    def slice_shared(
         self,
         tensor: SharedTensor,
         offsets: Sequence[Expr | int],
@@ -739,6 +771,16 @@ def shared_slice(
         self.append(inst)
         return inst.shared_output

+    def index_shared(
+        self,
+        dst: Var,
+        tensor: SharedTensor,
+        indices: Sequence[Expr | int],
+    ) -> None:
+        indices_ = [as_expr(index) for index in indices]
+        inst = SharedIndexInst.create(dst=dst, tensor=tensor, indices=indices_)
+        self.append(inst)
+
     def load_shared(
         self,
         src: SharedTensor,
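
The two entry points follow different conventions: slice_global/slice_shared return a new tensor view, while index_global/index_shared write the loaded scalar into a caller-supplied dst variable (typically created with declare). The calling pattern is sketched below with a mock builder standing in for the real statement builder; nothing in this snippet is the tilus API, it only records instructions to show the shape of the interaction:

    from dataclasses import dataclass, field


    @dataclass
    class MockBuilder:
        """Stand-in for the statement builder; records instructions instead of emitting IR."""

        program: list = field(default_factory=list)

        def slice_global(self, tensor, offsets, slice_dims, slice_shape):
            view = f"{tensor}[{offsets}; dims={slice_dims}; shape={slice_shape}]"
            self.program.append(("GlobalSliceInst", tensor, view))
            return view  # slicing returns the new tensor view

        def index_global(self, dst, tensor, indices):
            self.program.append(("GlobalIndexInst", tensor, tuple(indices), dst))
            # indexing returns nothing: the scalar is assigned to `dst`


    sb = MockBuilder()
    row = sb.slice_global("g", offsets=[2, 0], slice_dims=[1], slice_shape=[128])
    sb.index_global(dst="v0", tensor="g", indices=[2, 5])
    for inst in sb.program:
        print(inst)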

python/tilus/ir/inst.py

Lines changed: 7 additions & 0 deletions
@@ -60,6 +60,13 @@ def shared_input(self) -> SharedTensor:
         assert isinstance(x, SharedTensor)
         return x

+    @property
+    def global_input(self) -> GlobalTensor:
+        assert len(self.inputs) == 1
+        x = self.inputs[0]
+        assert isinstance(x, GlobalTensor)
+        return x
+
     @property
     def attributes(self) -> dict[str, Any]:
         attrs = {}

python/tilus/ir/instructions/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -43,6 +43,8 @@
     ExitInst,
     FormatPrintInst,
     FreeSharedInst,
+    GlobalIndexInst,
+    GlobalSliceInst,
     GlobalViewInst,
     LoadGlobalGenericInst,
     LoadGlobalInst,
@@ -54,6 +56,7 @@
     ReduceInst,
     RepeatInst,
     RepeatInterleaveInst,
+    SharedIndexInst,
     SharedSliceInst,
     ShuffleDownInst,
     ShuffleUpInst,

python/tilus/ir/instructions/generic.py

Lines changed: 62 additions & 1 deletion
@@ -72,6 +72,48 @@ def create(dst: GlobalTensor, x: RegisterTensor, offsets: Sequence[Expr], dims:
         return StoreGlobalInst(output=None, inputs=(dst, x), offsets=tuple(offsets), dims=tuple(dims))


+@dataclass(frozen=True, eq=False)
+class GlobalSliceInst(Instruction):
+    offsets: tuple[Expr, ...]
+    dims: Optional[tuple[int, ...]]
+
+    @staticmethod
+    def create(
+        tensor: GlobalTensor,
+        offsets: Sequence[Expr],
+        dims: Sequence[int],
+        shape: Sequence[Expr | int],
+    ) -> GlobalSliceInst:
+        from tilus.ir.layout.global_layout import global_slice
+
+        output = GlobalTensor.create(dtype=tensor.dtype, layout=global_slice(tensor.layout, offsets, dims, shape))
+        return GlobalSliceInst(
+            output=output,
+            inputs=(tensor,),
+            offsets=tuple(offsets),
+            dims=tuple(dims) if len(dims) < len(tensor.shape) else None,
+        )
+
+
+@dataclass(frozen=True, eq=False)
+class GlobalIndexInst(Instruction):
+    dst: Var
+    indices: tuple[Expr, ...]
+
+    @staticmethod
+    def create(
+        dst: Var,
+        tensor: GlobalTensor,
+        indices: Sequence[Expr],
+    ) -> GlobalIndexInst:
+        return GlobalIndexInst(
+            output=None,
+            inputs=(tensor,),
+            dst=dst,
+            indices=tuple(indices),
+        )
+
+
 @dataclass(frozen=True, eq=False)
 class LoadSharedInst(Instruction):
     @staticmethod
@@ -103,7 +145,26 @@ def create(
             output=output,
             inputs=(tensor,),
             offsets=tuple(offsets),
-            dims=tuple(dims) if len(dims) < len(tensor.shape) else None,
+            dims=tuple(dims) if len(dims) < len(tensor.shape) else tuple(range(len(tensor.shape))),
+        )
+
+
+@dataclass(frozen=True, eq=False)
+class SharedIndexInst(Instruction):
+    dst: Var
+    indices: tuple[Expr, ...]
+
+    @staticmethod
+    def create(
+        dst: Var,
+        tensor: SharedTensor,
+        indices: Sequence[Expr],
+    ) -> SharedIndexInst:
+        return SharedIndexInst(
+            output=None,
+            inputs=(tensor,),
+            dst=dst,
+            indices=tuple(indices),
         )


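A small behavioral change rides along in SharedSliceInst.create: when every dimension is sliced, dims is now stored as the full index range instead of None. The new normalization, isolated in plain Python (not the tilus types):

    def normalize_dims(dims, rank):
        # Mirrors the updated SharedSliceInst.create: always a tuple, never None.
        return tuple(dims) if len(dims) < rank else tuple(range(rank))


    assert normalize_dims([1], rank=2) == (1,)       # partial slice: unchanged behavior
    assert normalize_dims([0, 1], rank=2) == (0, 1)  # full slice: previously stored as None
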
python/tilus/ir/layout/global_layout.py

Lines changed: 40 additions & 0 deletions
@@ -218,3 +218,43 @@ def f_offset(axes: Sequence[Var]) -> Expr:
         return sum([axes[i] * strides[i] for i in range(len(shape))], start=int32.zero)

     return GlobalLayout.create(shape=shape, size=prod(shape), f_offset=f_offset)
+
+
+def global_slice(
+    layout: GlobalLayout, offsets: Sequence[Expr | int], dims: Sequence[int], shape: Sequence[Expr | int]
+) -> GlobalLayout:
+    """Create a sliced global layout from an existing layout.
+
+    This function creates a new global layout by slicing an existing global layout. The slicing is defined by the
+    specified offsets, dimensions to slice, and the shape of the resulting layout. The new layout retains the mapping
+    function of the original layout, adjusted for the specified offsets and dimensions.
+
+    Parameters
+    ----------
+    layout: GlobalLayout
+        The original global layout to be sliced.
+    offsets: Sequence[Expr | int]
+        The offsets for each dimension of the original layout. It should have the same length as the original layout's
+        shape.
+    dims: Sequence[int]
+        The dimensions to be sliced from the original layout. Each dimension should be a valid index in the original
+        layout's shape.
+    shape: Sequence[Expr | int]
+        The shape of the resulting sliced global layout. It should have the same length as the number of dimensions
+        specified in `dims`.
+
+    Returns
+    -------
+    ret: GlobalLayout
+        A new global layout that represents the sliced version of the original layout, with the specified shape and
+        adjusted mapping function.
+    """
+    assert len(dims) == len(shape) <= len(layout.shape) == len(offsets)
+
+    def f_offset(axes: Sequence[Var]) -> Expr:
+        indices = list(offsets)
+        for dim, axis in zip(dims, axes):
+            indices[dim] = axis + offsets[dim]
+        return layout(*indices) - layout(*offsets)  # type: ignore[arg-type]
+
+    return GlobalLayout.create(shape=shape, size=prod(shape), f_offset=f_offset)
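
For intuition, the offset arithmetic that global_slice performs can be sketched with plain Python functions standing in for GlobalLayout (the helper names below are illustrative, not part of tilus):

    def row_major_offset(shape):
        # Offset function of a row-major layout: f(i0, i1, ...) -> linear offset.
        def f(*indices):
            offset = 0
            for extent, index in zip(shape, indices):
                offset = offset * extent + index
            return offset
        return f


    def sliced_offset(parent, offsets, dims):
        # Mirror of global_slice's f_offset: shift the sliced dims by `offsets`
        # and make the slice origin map to offset zero.
        def f(*axes):
            indices = list(offsets)
            for dim, axis in zip(dims, axes):
                indices[dim] = axis + offsets[dim]
            return parent(*indices) - parent(*offsets)
        return f


    parent = row_major_offset([4, 8])                       # a 4x8 row-major tensor
    view = sliced_offset(parent, offsets=[1, 2], dims=[1])  # 1-D slice: row 1, columns 2..7
    assert view(0) == 0   # element (1, 2) is the slice origin
    assert view(3) == 3   # element (1, 5) sits 3 elements past the origin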

python/tilus/ir/layout/inference/inference_rules/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from . import (
+    allocate_shared,
     assign,
     cp_async,
     elementwise_binary,
python/tilus/ir/layout/inference/inference_rules/allocate_shared.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from tilus.ir.instructions import AllocateSharedInst
+from tilus.ir.layout.inference.rule import LayoutInferenceContext, LayoutInferenceRule, register_rule
+from tilus.ir.layout.shared_layout import SharedLayout, shared_row_major
+from tilus.ir.tensor import SharedTensor
+
+
+@register_rule(AllocateSharedInst)
+class AllocateSharedRule(LayoutInferenceRule):
+    @staticmethod
+    def inference(ctx: LayoutInferenceContext, inst: AllocateSharedInst) -> dict[SharedTensor, SharedLayout]:
+        tensor = inst.shared_output
+
+        if tensor.optional_layout is not None:
+            return {}
+        else:
+            return {tensor: shared_row_major(*tensor.shape)}
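
The fallback this rule applies is shared_row_major(*tensor.shape). Assuming that names the usual row-major stride scheme, the arithmetic it implies is sketched below in plain Python (illustrative only; it is not the tilus implementation):

    def row_major_strides(shape):
        # Innermost dimension is contiguous; each outer stride is the product of the inner extents.
        strides = []
        running = 1
        for extent in reversed(shape):
            strides.append(running)
            running *= extent
        return list(reversed(strides))


    assert row_major_strides([2, 4, 8]) == [32, 8, 1]
    # offset of element (i, j, k) is then i * 32 + j * 8 + k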
