Skip to content

Commit 0039575

Browse files
fix: Fix to_datetime() fallible identification (#23735)
1 parent 226a8a2 commit 0039575

File tree

2 files changed

+111
-10
lines changed

2 files changed

+111
-10
lines changed

crates/polars-plan/src/plans/aexpr/properties.rs

Lines changed: 19 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -212,16 +212,25 @@ impl ExprPushdownGroup {
212212
} => {
213213
debug_assert!(input.len() <= 2);
214214

215-
// `ambiguous` parameter to `to_datetime()`. Should always be a literal.
216-
debug_assert!(matches!(
217-
input.get(1).map(|x| expr_arena.get(x.node())),
218-
Some(AExpr::Literal(_)) | None
219-
));
220-
221-
match input.first().map(|x| expr_arena.get(x.node())) {
222-
Some(AExpr::Literal(_)) | None => false,
223-
_ => strptime_options.strict,
224-
}
215+
strptime_options.strict
216+
|| input
217+
.get(1)
218+
.map(|x| expr_arena.get(x.node()))
219+
.is_some_and(|ae| match ae {
220+
AExpr::Literal(lv) => {
221+
lv.extract_str().is_some_and(|ambiguous| match ambiguous {
222+
"raise" => true,
223+
"earliest" | "latest" | "null" => false,
224+
v => {
225+
if cfg!(debug_assertions) {
226+
panic!("unhandled parameter to ambiguous: {v}")
227+
}
228+
true
229+
},
230+
})
231+
},
232+
_ => true,
233+
})
225234
},
226235
AExpr::Cast {
227236
expr,

py-polars/tests/unit/operations/namespaces/temporal/test_to_datetime.py

Lines changed: 92 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
import polars as pl
1212
from polars.exceptions import ComputeError, InvalidOperationError
1313
from polars.testing import assert_series_equal
14+
from polars.testing.asserts.frame import assert_frame_equal
1415

1516
if TYPE_CHECKING:
1617
from hypothesis.strategies import DrawFn
@@ -197,3 +198,94 @@ def test_to_datetime_two_digit_year_17213(
197198
) -> None:
198199
result = pl.Series([inputs]).str.to_date(format=format).item()
199200
assert result == expected
201+
202+
203+
def test_to_datetime_column_input_to_ambiguous() -> None:
204+
q = pl.LazyFrame(
205+
{
206+
"a": ["2020-01-01 01:00Z", "2020-01-01 02:00Z"],
207+
"b": ["raise", "earliest"],
208+
}
209+
).select(pl.col.a.str.to_datetime("%Y-%m-%d %H:%M%#z", ambiguous=pl.col.b))
210+
211+
expect = pl.DataFrame(
212+
[
213+
pl.Series(
214+
"a",
215+
[
216+
datetime(2020, 1, 1, 1, 0, tzinfo=ZoneInfo(key="UTC")),
217+
datetime(2020, 1, 1, 2, 0, tzinfo=ZoneInfo(key="UTC")),
218+
],
219+
dtype=pl.Datetime(time_unit="us", time_zone="UTC"),
220+
),
221+
]
222+
)
223+
224+
assert_frame_equal(q.collect(), expect)
225+
226+
227+
def test_to_datetime_fallible_predicate_pushdown() -> None:
228+
df = pl.DataFrame({"x": ["2020-10-25 01:00", "X"]})
229+
230+
c = pl.first()
231+
232+
expect_fail = [
233+
c.str.to_datetime(
234+
"%Y-%m-%d %H:%M",
235+
time_zone="Europe/London",
236+
ambiguous="raise",
237+
strict=False,
238+
),
239+
c.str.to_datetime(
240+
"%Y-%m-%d %H:%M",
241+
time_zone="Europe/London",
242+
ambiguous="null",
243+
strict=True,
244+
),
245+
c.str.to_datetime("%Y-%m-%d %H:%M", strict=True),
246+
]
247+
248+
expect_pass = [
249+
c.str.to_datetime(
250+
"%Y-%m-%d %H:%M",
251+
time_zone="Europe/London",
252+
ambiguous="null",
253+
strict=False,
254+
),
255+
c.str.to_datetime(
256+
"%Y-%m-%d %H:%M",
257+
time_zone="Europe/London",
258+
ambiguous="earliest",
259+
strict=False,
260+
),
261+
c.str.to_datetime(
262+
"%Y-%m-%d %H:%M",
263+
time_zone="Europe/London",
264+
ambiguous="latest",
265+
strict=False,
266+
),
267+
]
268+
269+
for expr in expect_fail:
270+
with pytest.raises(Exception): # noqa: B017
271+
df.select(expr)
272+
273+
for expr in expect_pass:
274+
df.select(expr)
275+
276+
lf = df.with_columns(false=False).lazy().filter("false")
277+
278+
for e in expect_pass:
279+
q = lf.filter(e.is_not_null())
280+
plan = q.explain()
281+
assert plan.count("FILTER") == 1
282+
assert_frame_equal(
283+
q.collect(), q.collect(optimizations=pl.QueryOptFlags.none())
284+
)
285+
286+
for e in expect_fail:
287+
q = lf.filter(e.is_not_null())
288+
plan = q.explain()
289+
assert_frame_equal(
290+
q.collect(), q.collect(optimizations=pl.QueryOptFlags.none())
291+
)

0 commit comments

Comments
 (0)