Skip to content

Commit e46539b

Browse files
committed
Fix imports
1 parent 92af19f commit e46539b

File tree

2 files changed

+134
-1
lines changed

2 files changed

+134
-1
lines changed

py-polars/tests/test_udfs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import pytest
2121

22+
# TODO: Import these from py-polars/tests/unit/operations/test_inefficient_apply.py
2223
MY_CONSTANT = 3
2324
MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}
2425
MY_LIST = [1, 2, 3]

py-polars/tests/unit/operations/test_inefficient_apply.py

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,139 @@
1515
from polars.testing import assert_frame_equal, assert_series_equal
1616
from polars.utils.udfs import _NUMPY_FUNCTIONS, BytecodeParser
1717
from polars.utils.various import in_terminal_that_supports_colour
18-
from tests.test_udfs import MY_CONSTANT, MY_DICT, MY_LIST, NOOP_TEST_CASES, TEST_CASES
18+
19+
MY_CONSTANT = 3
20+
MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}
21+
MY_LIST = [1, 2, 3]
22+
23+
24+
# column_name, function, expected_suggestion
25+
TEST_CASES = [
26+
# ---------------------------------------------
27+
# numeric expr: math, comparison, logic ops
28+
# ---------------------------------------------
29+
("a", "lambda x: x + 1 - (2 / 3)", '(pl.col("a") + 1) - 0.6666666666666666'),
30+
("a", "lambda x: x // 1 % 2", '(pl.col("a") // 1) % 2'),
31+
("a", "lambda x: x & True", 'pl.col("a") & True'),
32+
("a", "lambda x: x | False", 'pl.col("a") | False'),
33+
("a", "lambda x: abs(x) != 3", 'pl.col("a").abs() != 3'),
34+
("a", "lambda x: int(x) > 1", 'pl.col("a").cast(pl.Int64) > 1'),
35+
("a", "lambda x: not (x > 1) or x == 2", '~(pl.col("a") > 1) | (pl.col("a") == 2)'),
36+
("a", "lambda x: x is None", 'pl.col("a") is None'),
37+
("a", "lambda x: x is not None", 'pl.col("a") is not None'),
38+
(
39+
"a",
40+
"lambda x: ((x * -x) ** x) * 1.0",
41+
'((pl.col("a") * -pl.col("a")) ** pl.col("a")) * 1.0',
42+
),
43+
(
44+
"a",
45+
"lambda x: 1.0 * (x * (x**x))",
46+
'1.0 * (pl.col("a") * (pl.col("a") ** pl.col("a")))',
47+
),
48+
(
49+
"a",
50+
"lambda x: (x / x) + ((x * x) - x)",
51+
'(pl.col("a") / pl.col("a")) + ((pl.col("a") * pl.col("a")) - pl.col("a"))',
52+
),
53+
(
54+
"a",
55+
"lambda x: (10 - x) / (((x * 4) - x) // (2 + (x * (x - 1))))",
56+
'(10 - pl.col("a")) / (((pl.col("a") * 4) - pl.col("a")) // (2 + (pl.col("a") * (pl.col("a") - 1))))',
57+
),
58+
("a", "lambda x: x in (2, 3, 4)", 'pl.col("a").is_in((2, 3, 4))'),
59+
("a", "lambda x: x not in (2, 3, 4)", '~pl.col("a").is_in((2, 3, 4))'),
60+
(
61+
"a",
62+
"lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0",
63+
'pl.col("a").is_in((1, 2, 3, 4, 3)) & ((pl.col("a") % 2) == 0) & (pl.col("a") > 0)',
64+
),
65+
("a", "lambda x: MY_CONSTANT + x", 'MY_CONSTANT + pl.col("a")'),
66+
("a", "lambda x: 0 + numpy.cbrt(x)", '0 + pl.col("a").cbrt()'),
67+
("a", "lambda x: np.sin(x) + 1", 'pl.col("a").sin() + 1'),
68+
(
69+
"a", # note: functions operate on consts
70+
"lambda x: np.sin(3.14159265358979) + (x - 1) + abs(-3)",
71+
'(np.sin(3.14159265358979) + (pl.col("a") - 1)) + abs(-3)',
72+
),
73+
(
74+
"a",
75+
"lambda x: (float(x) * int(x)) // 2",
76+
'(pl.col("a").cast(pl.Float64) * pl.col("a").cast(pl.Int64)) // 2',
77+
),
78+
# ---------------------------------------------
79+
# logical 'and/or' (validate nesting levels)
80+
# ---------------------------------------------
81+
(
82+
"a",
83+
"lambda x: x > 1 or (x == 1 and x == 2)",
84+
'(pl.col("a") > 1) | (pl.col("a") == 1) & (pl.col("a") == 2)',
85+
),
86+
(
87+
"a",
88+
"lambda x: (x > 1 or x == 1) and x == 2",
89+
'((pl.col("a") > 1) | (pl.col("a") == 1)) & (pl.col("a") == 2)',
90+
),
91+
(
92+
"a",
93+
"lambda x: x > 2 or x != 3 and x not in (0, 1, 4)",
94+
'(pl.col("a") > 2) | (pl.col("a") != 3) & ~pl.col("a").is_in((0, 1, 4))',
95+
),
96+
(
97+
"a",
98+
"lambda x: x > 1 and x != 2 or x % 2 == 0 and x < 3",
99+
'(pl.col("a") > 1) & (pl.col("a") != 2) | ((pl.col("a") % 2) == 0) & (pl.col("a") < 3)',
100+
),
101+
(
102+
"a",
103+
"lambda x: x > 1 and (x != 2 or x % 2 == 0) and x < 3",
104+
'(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)',
105+
),
106+
# ---------------------------------------------
107+
# string expr: case/cast ops
108+
# ---------------------------------------------
109+
("b", "lambda x: str(x).title()", 'pl.col("b").cast(pl.Utf8).str.to_titlecase()'),
110+
(
111+
"b",
112+
'lambda x: x.lower() + ":" + x.upper() + ":" + x.title()',
113+
'(((pl.col("b").str.to_lowercase() + \':\') + pl.col("b").str.to_uppercase()) + \':\') + pl.col("b").str.to_titlecase()',
114+
),
115+
# ---------------------------------------------
116+
# json expr: load/extract
117+
# ---------------------------------------------
118+
("c", "lambda x: json.loads(x)", 'pl.col("c").str.json_extract()'),
119+
# ---------------------------------------------
120+
# map_dict
121+
# ---------------------------------------------
122+
("a", "lambda x: MY_DICT[x]", 'pl.col("a").map_dict(MY_DICT)'),
123+
(
124+
"a",
125+
"lambda x: MY_DICT[x - 1] + MY_DICT[1 + x]",
126+
'(pl.col("a") - 1).map_dict(MY_DICT) + (1 + pl.col("a")).map_dict(MY_DICT)',
127+
),
128+
# ---------------------------------------------
129+
# standard library datetime parsing
130+
# ---------------------------------------------
131+
(
132+
"d",
133+
'lambda x: datetime.strptime(x, "%Y-%m-%d")',
134+
'pl.col("d").str.to_datetime(format="%Y-%m-%d")',
135+
),
136+
(
137+
"d",
138+
'lambda x: dt.datetime.strptime(x, "%Y-%m-%d")',
139+
'pl.col("d").str.to_datetime(format="%Y-%m-%d")',
140+
),
141+
]
142+
143+
NOOP_TEST_CASES = [
144+
"lambda x: x",
145+
"lambda x, y: x + y",
146+
"lambda x: x[0] + 1",
147+
"lambda x: MY_LIST[x]",
148+
"lambda x: MY_DICT[1]",
149+
'lambda x: "first" if x == 1 else "not first"',
150+
]
19151

20152
EVAL_ENVIRONMENT = {
21153
"np": numpy,

0 commit comments

Comments
 (0)