Skip to content

Commit 33e13b3

Browse files
committed
Rework property tests
1 parent f1bd894 commit 33e13b3

File tree

2 files changed

+26
-23
lines changed

2 files changed

+26
-23
lines changed

tests/strategies.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,22 @@
1313

1414
Chunks = tuple[tuple[int, ...], ...]
1515

16-
17-
def supported_dtypes() -> st.SearchStrategy[np.dtype]:
18-
return (
19-
npst.integer_dtypes(endianness="=")
20-
| npst.unsigned_integer_dtypes(endianness="=")
21-
| npst.floating_dtypes(endianness="=", sizes=(32, 64))
22-
| npst.complex_number_dtypes(endianness="=")
23-
| npst.datetime64_dtypes(endianness="=")
24-
| npst.timedelta64_dtypes(endianness="=")
25-
| npst.unicode_string_dtypes(endianness="=")
26-
)
27-
28-
16+
numeric_dtypes = (
17+
npst.integer_dtypes(endianness="=")
18+
| npst.unsigned_integer_dtypes(endianness="=")
19+
| npst.floating_dtypes(endianness="=", sizes=(32, 64))
20+
# TODO: add complex here not in supported_dtypes
21+
)
2922
# TODO: stop excluding everything but U
30-
array_dtypes = supported_dtypes().filter(lambda x: x.kind not in "cU")
31-
by_dtype_st = supported_dtypes()
23+
numeric_like_dtypes = (
24+
numeric_dtypes | npst.datetime64_dtypes(endianness="=") | npst.timedelta64_dtypes(endianness="=")
25+
)
26+
supported_dtypes = (
27+
numeric_like_dtypes
28+
| npst.unicode_string_dtypes(endianness="=")
29+
| npst.complex_number_dtypes(endianness="=")
30+
)
31+
by_dtype_st = supported_dtypes
3232

3333
NON_NUMPY_FUNCS = [
3434
"first",
@@ -43,12 +43,15 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
4343

4444
func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS])
4545
numeric_arrays = npst.arrays(
46-
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtypes
46+
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=numeric_dtypes
47+
)
48+
numeric_like_arrays = npst.arrays(
49+
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=numeric_like_dtypes
4750
)
4851
all_arrays = npst.arrays(
4952
elements={"allow_subnormal": False},
5053
shape=npst.array_shapes(),
51-
dtype=supported_dtypes(),
54+
dtype=supported_dtypes,
5255
)
5356

5457
calendars = st.sampled_from(

tests/test_properties.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from flox.xrutils import isnull, notnull
2121

2222
from . import assert_equal
23-
from .strategies import array_dtypes, by_arrays, chunked_arrays, func_st, numeric_arrays
23+
from .strategies import by_arrays, chunked_arrays, func_st, numeric_dtypes, numeric_like_arrays
2424
from .strategies import chunks as chunks_strategy
2525

2626
dask.config.set(scheduler="sync")
@@ -66,7 +66,7 @@ def not_overflowing_array(array: np.ndarray[Any, Any]) -> bool:
6666

6767
@given(
6868
data=st.data(),
69-
array=st.one_of(numeric_arrays, chunked_arrays(arrays=numeric_arrays)),
69+
array=st.one_of(numeric_like_arrays, chunked_arrays(arrays=numeric_like_arrays)),
7070
func=func_st,
7171
)
7272
def test_groupby_reduce(data, array, func: str) -> None:
@@ -136,7 +136,7 @@ def test_groupby_reduce(data, array, func: str) -> None:
136136

137137
@given(
138138
data=st.data(),
139-
array=chunked_arrays(arrays=numeric_arrays),
139+
array=chunked_arrays(arrays=numeric_like_arrays),
140140
func=func_st,
141141
)
142142
def test_groupby_reduce_numpy_vs_dask(data, array, func: str) -> None:
@@ -170,7 +170,7 @@ def test_groupby_reduce_numpy_vs_dask(data, array, func: str) -> None:
170170
@settings(report_multiple_bugs=False)
171171
@given(
172172
data=st.data(),
173-
array=chunked_arrays(arrays=numeric_arrays),
173+
array=chunked_arrays(arrays=numeric_like_arrays),
174174
func=st.sampled_from(tuple(NUMPY_SCAN_FUNCS)),
175175
)
176176
def test_scans(data, array: dask.array.Array, func: str) -> None:
@@ -297,8 +297,8 @@ def test_first_last_useless(data, func):
297297
@given(
298298
func=st.sampled_from(["sum", "prod", "nansum", "nanprod"]),
299299
engine=st.sampled_from(["numpy", "flox"]),
300-
array_dtype=st.none() | array_dtypes,
301-
dtype=st.none() | array_dtypes,
300+
array_dtype=st.none() | numeric_dtypes,
301+
dtype=st.none() | numeric_dtypes,
302302
)
303303
def test_agg_dtype_specified(func, array_dtype, dtype, engine):
304304
# regression test for GH388

0 commit comments

Comments
 (0)