|
15 | 15 | from polars.testing import assert_frame_equal, assert_series_equal
|
16 | 16 | from polars.utils.udfs import _NUMPY_FUNCTIONS, BytecodeParser
|
17 | 17 | from polars.utils.various import in_terminal_that_supports_colour
|
18 |
| -from tests.test_udfs import MY_CONSTANT, MY_DICT, MY_LIST, NOOP_TEST_CASES, TEST_CASES |
| 18 | + |
| 19 | +MY_CONSTANT = 3 |
| 20 | +MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"} |
| 21 | +MY_LIST = [1, 2, 3] |
| 22 | + |
| 23 | + |
| 24 | +# column_name, function, expected_suggestion |
| 25 | +TEST_CASES = [ |
| 26 | + # --------------------------------------------- |
| 27 | + # numeric expr: math, comparison, logic ops |
| 28 | + # --------------------------------------------- |
| 29 | + ("a", "lambda x: x + 1 - (2 / 3)", '(pl.col("a") + 1) - 0.6666666666666666'), |
| 30 | + ("a", "lambda x: x // 1 % 2", '(pl.col("a") // 1) % 2'), |
| 31 | + ("a", "lambda x: x & True", 'pl.col("a") & True'), |
| 32 | + ("a", "lambda x: x | False", 'pl.col("a") | False'), |
| 33 | + ("a", "lambda x: abs(x) != 3", 'pl.col("a").abs() != 3'), |
| 34 | + ("a", "lambda x: int(x) > 1", 'pl.col("a").cast(pl.Int64) > 1'), |
| 35 | + ("a", "lambda x: not (x > 1) or x == 2", '~(pl.col("a") > 1) | (pl.col("a") == 2)'), |
| 36 | + ("a", "lambda x: x is None", 'pl.col("a") is None'), |
| 37 | + ("a", "lambda x: x is not None", 'pl.col("a") is not None'), |
| 38 | + ( |
| 39 | + "a", |
| 40 | + "lambda x: ((x * -x) ** x) * 1.0", |
| 41 | + '((pl.col("a") * -pl.col("a")) ** pl.col("a")) * 1.0', |
| 42 | + ), |
| 43 | + ( |
| 44 | + "a", |
| 45 | + "lambda x: 1.0 * (x * (x**x))", |
| 46 | + '1.0 * (pl.col("a") * (pl.col("a") ** pl.col("a")))', |
| 47 | + ), |
| 48 | + ( |
| 49 | + "a", |
| 50 | + "lambda x: (x / x) + ((x * x) - x)", |
| 51 | + '(pl.col("a") / pl.col("a")) + ((pl.col("a") * pl.col("a")) - pl.col("a"))', |
| 52 | + ), |
| 53 | + ( |
| 54 | + "a", |
| 55 | + "lambda x: (10 - x) / (((x * 4) - x) // (2 + (x * (x - 1))))", |
| 56 | + '(10 - pl.col("a")) / (((pl.col("a") * 4) - pl.col("a")) // (2 + (pl.col("a") * (pl.col("a") - 1))))', |
| 57 | + ), |
| 58 | + ("a", "lambda x: x in (2, 3, 4)", 'pl.col("a").is_in((2, 3, 4))'), |
| 59 | + ("a", "lambda x: x not in (2, 3, 4)", '~pl.col("a").is_in((2, 3, 4))'), |
| 60 | + ( |
| 61 | + "a", |
| 62 | + "lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0", |
| 63 | + 'pl.col("a").is_in((1, 2, 3, 4, 3)) & ((pl.col("a") % 2) == 0) & (pl.col("a") > 0)', |
| 64 | + ), |
| 65 | + ("a", "lambda x: MY_CONSTANT + x", 'MY_CONSTANT + pl.col("a")'), |
| 66 | + ("a", "lambda x: 0 + numpy.cbrt(x)", '0 + pl.col("a").cbrt()'), |
| 67 | + ("a", "lambda x: np.sin(x) + 1", 'pl.col("a").sin() + 1'), |
| 68 | + ( |
| 69 | + "a", # note: functions operate on consts |
| 70 | + "lambda x: np.sin(3.14159265358979) + (x - 1) + abs(-3)", |
| 71 | + '(np.sin(3.14159265358979) + (pl.col("a") - 1)) + abs(-3)', |
| 72 | + ), |
| 73 | + ( |
| 74 | + "a", |
| 75 | + "lambda x: (float(x) * int(x)) // 2", |
| 76 | + '(pl.col("a").cast(pl.Float64) * pl.col("a").cast(pl.Int64)) // 2', |
| 77 | + ), |
| 78 | + # --------------------------------------------- |
| 79 | + # logical 'and/or' (validate nesting levels) |
| 80 | + # --------------------------------------------- |
| 81 | + ( |
| 82 | + "a", |
| 83 | + "lambda x: x > 1 or (x == 1 and x == 2)", |
| 84 | + '(pl.col("a") > 1) | (pl.col("a") == 1) & (pl.col("a") == 2)', |
| 85 | + ), |
| 86 | + ( |
| 87 | + "a", |
| 88 | + "lambda x: (x > 1 or x == 1) and x == 2", |
| 89 | + '((pl.col("a") > 1) | (pl.col("a") == 1)) & (pl.col("a") == 2)', |
| 90 | + ), |
| 91 | + ( |
| 92 | + "a", |
| 93 | + "lambda x: x > 2 or x != 3 and x not in (0, 1, 4)", |
| 94 | + '(pl.col("a") > 2) | (pl.col("a") != 3) & ~pl.col("a").is_in((0, 1, 4))', |
| 95 | + ), |
| 96 | + ( |
| 97 | + "a", |
| 98 | + "lambda x: x > 1 and x != 2 or x % 2 == 0 and x < 3", |
| 99 | + '(pl.col("a") > 1) & (pl.col("a") != 2) | ((pl.col("a") % 2) == 0) & (pl.col("a") < 3)', |
| 100 | + ), |
| 101 | + ( |
| 102 | + "a", |
| 103 | + "lambda x: x > 1 and (x != 2 or x % 2 == 0) and x < 3", |
| 104 | + '(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)', |
| 105 | + ), |
| 106 | + # --------------------------------------------- |
| 107 | + # string expr: case/cast ops |
| 108 | + # --------------------------------------------- |
| 109 | + ("b", "lambda x: str(x).title()", 'pl.col("b").cast(pl.Utf8).str.to_titlecase()'), |
| 110 | + ( |
| 111 | + "b", |
| 112 | + 'lambda x: x.lower() + ":" + x.upper() + ":" + x.title()', |
| 113 | + '(((pl.col("b").str.to_lowercase() + \':\') + pl.col("b").str.to_uppercase()) + \':\') + pl.col("b").str.to_titlecase()', |
| 114 | + ), |
| 115 | + # --------------------------------------------- |
| 116 | + # json expr: load/extract |
| 117 | + # --------------------------------------------- |
| 118 | + ("c", "lambda x: json.loads(x)", 'pl.col("c").str.json_extract()'), |
| 119 | + # --------------------------------------------- |
| 120 | + # map_dict |
| 121 | + # --------------------------------------------- |
| 122 | + ("a", "lambda x: MY_DICT[x]", 'pl.col("a").map_dict(MY_DICT)'), |
| 123 | + ( |
| 124 | + "a", |
| 125 | + "lambda x: MY_DICT[x - 1] + MY_DICT[1 + x]", |
| 126 | + '(pl.col("a") - 1).map_dict(MY_DICT) + (1 + pl.col("a")).map_dict(MY_DICT)', |
| 127 | + ), |
| 128 | + # --------------------------------------------- |
| 129 | + # standard library datetime parsing |
| 130 | + # --------------------------------------------- |
| 131 | + ( |
| 132 | + "d", |
| 133 | + 'lambda x: datetime.strptime(x, "%Y-%m-%d")', |
| 134 | + 'pl.col("d").str.to_datetime(format="%Y-%m-%d")', |
| 135 | + ), |
| 136 | + ( |
| 137 | + "d", |
| 138 | + 'lambda x: dt.datetime.strptime(x, "%Y-%m-%d")', |
| 139 | + 'pl.col("d").str.to_datetime(format="%Y-%m-%d")', |
| 140 | + ), |
| 141 | +] |
| 142 | + |
| 143 | +NOOP_TEST_CASES = [ |
| 144 | + "lambda x: x", |
| 145 | + "lambda x, y: x + y", |
| 146 | + "lambda x: x[0] + 1", |
| 147 | + "lambda x: MY_LIST[x]", |
| 148 | + "lambda x: MY_DICT[1]", |
| 149 | + 'lambda x: "first" if x == 1 else "not first"', |
| 150 | +] |
19 | 151 |
|
20 | 152 | EVAL_ENVIRONMENT = {
|
21 | 153 | "np": numpy,
|
|
0 commit comments