Skip to content

Commit e2937a4

Browse files
authored
fix(pyarrow): properly support round tripping of fixed size list types (#11330)
1 parent cfa5768 commit e2937a4

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

ibis/formats/pyarrow.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,10 @@ def to_ibis(cls, typ: pa.DataType, nullable=True) -> dt.DataType:
9090
return dt.Interval(typ.unit, nullable=nullable)
9191
elif pa.types.is_interval(typ):
9292
raise ValueError("Arrow interval type is not supported")
93-
elif (
94-
pa.types.is_list(typ)
95-
or pa.types.is_large_list(typ)
96-
or pa.types.is_fixed_size_list(typ)
97-
):
93+
elif pa.types.is_fixed_size_list(typ):
94+
value_dtype = cls.to_ibis(typ.value_type, typ.value_field.nullable)
95+
return dt.Array(value_dtype, length=typ.list_size, nullable=nullable)
96+
elif pa.types.is_list(typ) or pa.types.is_large_list(typ):
9897
value_dtype = cls.to_ibis(typ.value_type, typ.value_field.nullable)
9998
return dt.Array(value_dtype, nullable=nullable)
10099
elif pa.types.is_struct(typ):
@@ -196,7 +195,10 @@ def from_ibis(cls, dtype: dt.DataType) -> pa.DataType:
196195
cls.from_ibis(dtype.value_type),
197196
nullable=dtype.value_type.nullable,
198197
)
199-
return pa.list_(value_field)
198+
if dtype.length is None:
199+
return pa.list_(value_field)
200+
else:
201+
return pa.list_(value_field, dtype.length)
200202
elif dtype.is_struct():
201203
fields = [
202204
pa.field(name, cls.from_ibis(dtype), nullable=dtype.nullable)

ibis/formats/tests/test_pyarrow.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ def assert_dtype_roundtrip(arrow_type, ibis_type=None, restored_type=None):
4141
| past.binary_type
4242
| past.timestamp_types
4343
| st.builds(pa.list_, roundtripable_types)
44+
| st.builds(
45+
pa.list_, roundtripable_types, st.integers(min_value=1, max_value=2**31 - 1)
46+
)
4447
| past.struct_types(roundtripable_types)
4548
| past.map_types(roundtripable_types, roundtripable_types)
4649
)
@@ -75,11 +78,6 @@ def test_roundtripable_types(arrow_type):
7578
dt.Array(dt.Int64(nullable=True), nullable=False),
7679
pa.list_(pa.int64()),
7780
),
78-
(
79-
pa.list_(pa.int64(), list_size=3),
80-
dt.Array(dt.Int64(nullable=True), nullable=False),
81-
pa.list_(pa.int64()),
82-
),
8381
],
8482
)
8583
def test_non_roundtripable_types(arrow_type, ibis_type, restored_type):

0 commit comments

Comments
 (0)