@@ -42,91 +42,6 @@ def is_scalar(x):
42
42
return False
43
43
return False
44
44
45
-
46
- class _FlattenIndexMapping (object ):
47
- def __init__ (self , stride = 1 , reverse = False ):
48
- self ._stride = stride
49
- self .reverse = reverse
50
-
51
- def __call__ (self , idxs : _HybridIndex ):
52
- new_idxs = []
53
-
54
- if self .reverse == True :
55
- for i in idxs :
56
- new_idxs .append ( _HybridIndex ( idx = (i .idx // self ._stride ), root_idx = i .root_idx ) )
57
- new_idxs = list (set (new_idxs ))
58
- else :
59
- for i in idxs :
60
- new_idxs .extend (
61
- [ _HybridIndex (idx = k , root_idx = i .root_idx ) for k in range (i .idx * self ._stride , (i .idx + 1 ) * self ._stride ) ]
62
- )
63
- return new_idxs
64
-
65
-
66
- class _ConcatIndexMapping (object ):
67
- def __init__ (self , offset , reverse = False ):
68
- self .offset = offset
69
- self .reverse = reverse
70
-
71
- def __call__ (self , idxs : _HybridIndex ):
72
- if self .reverse == True :
73
- new_idxs = [
74
- _HybridIndex (idx = i .idx - self .offset [0 ], root_idx = i .root_idx )
75
- for i in idxs
76
- if (i .idx >= self .offset [0 ] and i .idx < self .offset [1 ])
77
- ]
78
- else :
79
- new_idxs = [ _HybridIndex (idx = i .idx + self .offset [0 ], root_idx = i .root_idx ) for i in idxs ]
80
- return new_idxs
81
-
82
- class _GQAIndexMapping (object ):
83
- def __init__ (self , repeat , head_dim , reverse = False ):
84
- self .repeat = repeat
85
- self .reverse = reverse
86
- self .head_dim = head_dim
87
-
88
- def __call__ (self , idxs : _HybridIndex ):
89
- head_dim = self .head_dim
90
- repeat = self .repeat
91
- if self .reverse == True :
92
- new_idxs = [ _HybridIndex (idx = ( i .idx - i .idx // (head_dim * repeat ) * head_dim * (repeat - 1 ) - i .idx // head_dim % repeat * head_dim ), root_idx = None ) for i in idxs ]
93
- else :
94
- new_idxs = []
95
-
96
- return new_idxs
97
-
98
- class _SliceIndexMapping (object ):
99
- def __init__ (self , dim , start , step , end , reverse = False ):
100
- self .start = start
101
- self .step = step
102
- self .end = end
103
- self .reverse = reverse
104
- self .dim = dim
105
-
106
- def __call__ (self , idxs : _HybridIndex ):
107
-
108
- if self .reverse == True :
109
- new_idxs = [ _HybridIndex (idx = i .idx * self .step + self .start , root_idx = i .root_idx ) for i in idxs ]
110
- else :
111
- new_idxs = [ _HybridIndex (idx = (i .idx - self .start ) // self .step , root_idx = i .root_idx ) for i in idxs if (i .idx >= self .start and i .idx < self .end and (i .idx - self .start )% self .step == 0 ) ]
112
- return new_idxs
113
-
114
- class _SplitIndexMapping (object ):
115
- def __init__ (self , offset , reverse = False ):
116
- self .offset = offset
117
- self .reverse = reverse
118
-
119
- def __call__ (self , idxs : _HybridIndex ):
120
- if self .reverse == True :
121
- new_idxs = [ _HybridIndex (idx = i .idx + self .offset [0 ], root_idx = i .root_idx ) for i in idxs ]
122
- else :
123
- new_idxs = [
124
- _HybridIndex (idx = i .idx - self .offset [0 ], root_idx = i .root_idx )
125
- for i in idxs
126
- if (i .idx >= self .offset [0 ] and i .idx < self .offset [1 ])
127
- ]
128
- return new_idxs
129
-
130
45
class ScalarSum :
131
46
def __init__ (self ):
132
47
self ._results = {}
0 commit comments