Skip to content

Commit 59ed451

Browse files
committed
[FIX] intp -> uintp for cupy
This will need to handle -ve fill value for count
1 parent 096c080 commit 59ed451

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

flox/aggregations.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,8 @@ def __repr__(self) -> str:
252252
combine="sum",
253253
fill_value=0,
254254
final_fill_value=0,
255-
dtypes=np.intp,
256-
final_dtype=np.intp,
255+
dtypes=np.uintp,
256+
final_dtype=np.uintp,
257257
)
258258

259259
# note that the fill values are the result of np.func([np.nan, np.nan])
@@ -281,7 +281,7 @@ def _mean_finalize(sum_, count):
281281
combine=("sum", "sum"),
282282
finalize=_mean_finalize,
283283
fill_value=(0, 0),
284-
dtypes=(None, np.intp),
284+
dtypes=(None, np.uintp),
285285
final_dtype=np.floating,
286286
)
287287
nanmean = Aggregation(
@@ -290,7 +290,7 @@ def _mean_finalize(sum_, count):
290290
combine=("sum", "sum"),
291291
finalize=_mean_finalize,
292292
fill_value=(0, 0),
293-
dtypes=(None, np.intp),
293+
dtypes=(None, np.uintp),
294294
final_dtype=np.floating,
295295
)
296296

@@ -315,7 +315,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
315315
finalize=_var_finalize,
316316
fill_value=0,
317317
final_fill_value=np.nan,
318-
dtypes=(None, None, np.intp),
318+
dtypes=(None, None, np.uintp),
319319
final_dtype=np.floating,
320320
)
321321
nanvar = Aggregation(
@@ -325,7 +325,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
325325
finalize=_var_finalize,
326326
fill_value=0,
327327
final_fill_value=np.nan,
328-
dtypes=(None, None, np.intp),
328+
dtypes=(None, None, np.uintp),
329329
final_dtype=np.floating,
330330
)
331331
std = Aggregation(
@@ -335,7 +335,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
335335
finalize=_std_finalize,
336336
fill_value=0,
337337
final_fill_value=np.nan,
338-
dtypes=(None, None, np.intp),
338+
dtypes=(None, None, np.uintp),
339339
final_dtype=np.floating,
340340
)
341341
nanstd = Aggregation(
@@ -345,7 +345,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
345345
finalize=_std_finalize,
346346
fill_value=0,
347347
final_fill_value=np.nan,
348-
dtypes=(None, None, np.intp),
348+
dtypes=(None, None, np.uintp),
349349
final_dtype=np.floating,
350350
)
351351

@@ -368,7 +368,7 @@ def argreduce_preprocess(array, axis):
368368
assert len(axis) == 1
369369
axis = axis[0]
370370

371-
idx = dask.array.arange(array.shape[axis], chunks=array.chunks[axis], dtype=np.intp)
371+
idx = dask.array.arange(array.shape[axis], chunks=array.chunks[axis], dtype=np.uintp)
372372
# broadcast (TODO: is this needed?)
373373
idx = idx[tuple(slice(None) if i == axis else np.newaxis for i in range(array.ndim))]
374374

@@ -398,8 +398,8 @@ def _pick_second(*x):
398398
fill_value=(dtypes.NINF, 0),
399399
final_fill_value=-1,
400400
finalize=_pick_second,
401-
dtypes=(None, np.intp),
402-
final_dtype=np.intp,
401+
dtypes=(None, np.uintp),
402+
final_dtype=np.uintp,
403403
)
404404

405405
argmin = Aggregation(
@@ -411,8 +411,8 @@ def _pick_second(*x):
411411
fill_value=(dtypes.INF, 0),
412412
final_fill_value=-1,
413413
finalize=_pick_second,
414-
dtypes=(None, np.intp),
415-
final_dtype=np.intp,
414+
dtypes=(None, np.uintp),
415+
final_dtype=np.uintp,
416416
)
417417

418418
nanargmax = Aggregation(
@@ -424,8 +424,8 @@ def _pick_second(*x):
424424
fill_value=(dtypes.NINF, 0),
425425
final_fill_value=-1,
426426
finalize=_pick_second,
427-
dtypes=(None, np.intp),
428-
final_dtype=np.intp,
427+
dtypes=(None, np.uintp),
428+
final_dtype=np.uintp,
429429
)
430430

431431
nanargmin = Aggregation(
@@ -437,8 +437,8 @@ def _pick_second(*x):
437437
fill_value=(dtypes.INF, 0),
438438
final_fill_value=-1,
439439
finalize=_pick_second,
440-
dtypes=(None, np.intp),
441-
final_dtype=np.intp,
440+
dtypes=(None, np.uintp),
441+
final_dtype=np.uintp,
442442
)
443443

444444
first = Aggregation("first", chunk=None, combine=None, fill_value=0)
@@ -574,8 +574,10 @@ def _initialize_aggregation(
574574
agg.combine += ("sum",)
575575
agg.fill_value["intermediate"] += (0,)
576576
agg.fill_value["numpy"] += (0,)
577-
agg.dtype["intermediate"] += (np.intp,)
578-
agg.dtype["numpy"] += (np.intp,)
577+
# uintp is supported by cupy, intp is not
578+
# Also count is >=0, so uint should be fine.
579+
agg.dtype["intermediate"] += (np.uintp,)
580+
agg.dtype["numpy"] += (np.uintp,)
579581
else:
580582
agg.min_count = 0
581583

0 commit comments

Comments
 (0)