@@ -189,8 +189,53 @@ def check_group_idx(group_idx, a=None, check_min=True):
189
189
raise ValueError ("group_idx contains negative indices" )
190
190
191
191
192
+ def _ravel_group_idx (group_idx , a , axis , size , order , method = "ravel" ):
193
+ ndim_a = a .ndim
194
+ # Create the broadcast-ready multidimensional indexing.
195
+ # Note the user could do this themselves, so this is
196
+ # very much just a convenience.
197
+ size_in = int (np .max (group_idx )) + 1 if size is None else size
198
+ group_idx_in = group_idx
199
+ group_idx = []
200
+ size = []
201
+ for ii , s in enumerate (a .shape ):
202
+ if method == "ravel" :
203
+ ii_idx = group_idx_in if ii == axis else np .arange (s )
204
+ ii_shape = [1 ] * ndim_a
205
+ ii_shape [ii ] = s
206
+ group_idx .append (ii_idx .reshape (ii_shape ))
207
+ size .append (size_in if ii == axis else s )
208
+ # Use the indexing, and return. It's a bit simpler than
209
+ # using trying to keep all the logic below happy
210
+ if method == "ravel" :
211
+ group_idx = np .ravel_multi_index (group_idx , size , order = order ,
212
+ mode = 'raise' )
213
+ elif method == "offset" :
214
+ group_idx = offset_labels (group_idx_in , a .shape , axis , order , size_in )
215
+ return group_idx , size
216
+
217
+ def offset_labels (group_idx , inshape , axis , order , size ):
218
+ """
219
+ Offset group labels by dimension. This is used when we
220
+ reduce over a subset of the dimensions of by. It assumes that the reductions
221
+ dimensions have been flattened in the last dimension
222
+ Copied from
223
+ https://stackoverflow.com/questions/46256279/bin-elements-per-row-vectorized-2d-bincount-for-numpy
224
+ """
225
+ if axis not in (- 1 , len (inshape ) - 1 ):
226
+ newshape = (s for idx , s in enumerate (inshape ) if idx != axis ) + (inshape [axis ],)
227
+ else :
228
+ newshape = inshape
229
+ group_idx = np .broadcast_to (group_idx , newshape )
230
+ group_idx : np .ndarray = (
231
+ group_idx
232
+ + np .arange (np .prod (group_idx .shape [:- 1 ]), dtype = int ).reshape ((* group_idx .shape [:- 1 ], - 1 ))
233
+ * size
234
+ )
235
+ return group_idx .reshape (inshape ).ravel ()
236
+
192
237
def input_validation (group_idx , a , size = None , order = 'C' , axis = None ,
193
- ravel_group_idx = True , check_bounds = True ):
238
+ ravel_group_idx = True , check_bounds = True , method = "ravel" ):
194
239
""" Do some fairly extensive checking of group_idx and a, trying to
195
240
give the user as much help as possible with what is wrong. Also,
196
241
convert ndim-indexing to 1d indexing.
@@ -230,23 +275,7 @@ def input_validation(group_idx, a, size=None, order='C', axis=None,
230
275
raise NotImplementedError ("when using axis arg, size must be"
231
276
"None or scalar." )
232
277
else :
233
- # Create the broadcast-ready multidimensional indexing.
234
- # Note the user could do this themselves, so this is
235
- # very much just a convenience.
236
- size_in = int (np .max (group_idx )) + 1 if size is None else size
237
- group_idx_in = group_idx
238
- group_idx = []
239
- size = []
240
- for ii , s in enumerate (a .shape ):
241
- ii_idx = group_idx_in if ii == axis else np .arange (s )
242
- ii_shape = [1 ] * ndim_a
243
- ii_shape [ii ] = s
244
- group_idx .append (ii_idx .reshape (ii_shape ))
245
- size .append (size_in if ii == axis else s )
246
- # Use the indexing, and return. It's a bit simpler than
247
- # using trying to keep all the logic below happy
248
- group_idx = np .ravel_multi_index (group_idx , size , order = order ,
249
- mode = 'raise' )
278
+ group_idx , size = _ravel_group_idx (group_idx , a , axis , size , order , method = method )
250
279
flat_size = np .prod (size )
251
280
ndim_idx = ndim_a
252
281
return group_idx .ravel (), a .ravel (), flat_size , ndim_idx , size
0 commit comments