@@ -72,6 +72,48 @@ def create(dst: GlobalTensor, x: RegisterTensor, offsets: Sequence[Expr], dims:
72
72
return StoreGlobalInst (output = None , inputs = (dst , x ), offsets = tuple (offsets ), dims = tuple (dims ))
73
73
74
74
75
+ @dataclass (frozen = True , eq = False )
76
+ class GlobalSliceInst (Instruction ):
77
+ offsets : tuple [Expr , ...]
78
+ dims : Optional [tuple [int , ...]]
79
+
80
+ @staticmethod
81
+ def create (
82
+ tensor : GlobalTensor ,
83
+ offsets : Sequence [Expr ],
84
+ dims : Sequence [int ],
85
+ shape : Sequence [Expr | int ],
86
+ ) -> SharedSliceInst :
87
+ from tilus .ir .layout .global_layout import global_slice
88
+
89
+ output = GlobalTensor .create (dtype = tensor .dtype , layout = global_slice (tensor .layout , offsets , dims , shape ))
90
+ return SharedSliceInst (
91
+ output = output ,
92
+ inputs = (tensor ,),
93
+ offsets = tuple (offsets ),
94
+ dims = tuple (dims ) if len (dims ) < len (tensor .shape ) else None ,
95
+ )
96
+
97
+
98
+ @dataclass (frozen = True , eq = False )
99
+ class GlobalIndexInst (Instruction ):
100
+ dst : Var
101
+ indices : tuple [Expr , ...]
102
+
103
+ @staticmethod
104
+ def create (
105
+ dst : Var ,
106
+ tensor : GlobalTensor ,
107
+ indices : Sequence [Expr ],
108
+ ) -> GlobalIndexInst :
109
+ return GlobalIndexInst (
110
+ output = None ,
111
+ inputs = (tensor ,),
112
+ dst = dst ,
113
+ indices = tuple (indices ),
114
+ )
115
+
116
+
75
117
@dataclass (frozen = True , eq = False )
76
118
class LoadSharedInst (Instruction ):
77
119
@staticmethod
@@ -103,7 +145,26 @@ def create(
103
145
output = output ,
104
146
inputs = (tensor ,),
105
147
offsets = tuple (offsets ),
106
- dims = tuple (dims ) if len (dims ) < len (tensor .shape ) else None ,
148
+ dims = tuple (dims ) if len (dims ) < len (tensor .shape ) else tuple (range (len (tensor .shape ))),
149
+ )
150
+
151
+
152
+ @dataclass (frozen = True , eq = False )
153
+ class SharedIndexInst (Instruction ):
154
+ dst : Var
155
+ indices : tuple [Expr , ...]
156
+
157
+ @staticmethod
158
+ def create (
159
+ dst : Var ,
160
+ tensor : SharedTensor ,
161
+ indices : Sequence [Expr ],
162
+ ) -> SharedIndexInst :
163
+ return SharedIndexInst (
164
+ output = None ,
165
+ inputs = (tensor ,),
166
+ dst = dst ,
167
+ indices = tuple (indices ),
107
168
)
108
169
109
170
0 commit comments