Skip to content

Commit f6641e7

Browse files
authored
feat(datatypes): add Schema.from_sqlglot method to produce an Ibis schema from SQLGlot (#11351)
1 parent bb50aea commit f6641e7

File tree

3 files changed

+224
-6
lines changed

3 files changed

+224
-6
lines changed

ibis/backends/sql/datatypes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,6 @@ def _from_sqlglot_ARRAY(
862862
length: sge.Literal | None = None,
863863
nullable: bool | None = None,
864864
) -> dt.Array:
865-
assert value_type is None
866865
return dt.Array(dt.json, nullable=nullable)
867866

868867
@classmethod

ibis/expr/schema.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,101 @@ def from_polars(cls, polars_schema) -> Self:
180180

181181
return PolarsSchema.to_ibis(polars_schema)
182182

183+
@classmethod
184+
def from_sqlglot(
185+
cls, schema: sge.Schema, dialect: str | sg.Dialect | None = None
186+
) -> Self:
187+
"""Construct an Ibis Schema from a SQLGlot Schema.
188+
189+
Parameters
190+
----------
191+
schema
192+
A SQLGlot Schema containing column definitions.
193+
dialect
194+
Optional dialect to use for type conversion.
195+
196+
Returns
197+
-------
198+
Schema
199+
An Ibis Schema.
200+
201+
Examples
202+
--------
203+
>>> import ibis
204+
>>> import sqlglot as sg
205+
>>> import sqlglot.expressions as sge
206+
>>> columns = [
207+
... sge.ColumnDef(
208+
... this=sg.to_identifier("a", quoted=True),
209+
... kind=sge.DataType(this=sge.DataType.Type.BIGINT),
210+
... ),
211+
... sge.ColumnDef(
212+
... this=sg.to_identifier("b", quoted=True),
213+
... kind=sge.DataType(this=sge.DataType.Type.VARCHAR),
214+
... constraints=[sge.ColumnConstraint(kind=sge.NotNullColumnConstraint())],
215+
... ),
216+
... ]
217+
>>> schema_expr = sge.Schema(expressions=columns)
218+
>>> sch = ibis.Schema.from_sqlglot(schema_expr)
219+
>>> sch
220+
ibis.Schema {
221+
a int64
222+
b !string
223+
}
224+
225+
Different source dialects are supported using the `dialect` keyword argument.
226+
227+
>>> columns = [
228+
... sge.ColumnDef(
229+
... this=sg.to_identifier("a", quoted=True),
230+
... kind=sge.DataType(
231+
... this=sge.DataType.Type.ARRAY,
232+
... expressions=[sge.DataType(this=sge.DataType.Type.BIGINT, nested=False)],
233+
... nested=True,
234+
... ),
235+
... )
236+
... ]
237+
>>> schema_expr = sge.Schema(expressions=columns)
238+
>>> snowflake_schema = ibis.Schema.from_sqlglot(schema_expr, dialect="snowflake")
239+
>>> snowflake_schema
240+
ibis.Schema {
241+
a array<json>
242+
}
243+
>>> bigquery_schema = ibis.Schema.from_sqlglot(schema_expr, dialect="bigquery")
244+
>>> bigquery_schema
245+
ibis.Schema {
246+
a array<int64>
247+
}
248+
"""
249+
import sqlglot.expressions as sge
250+
251+
from ibis.backends.sql.datatypes import TYPE_MAPPERS, SqlglotType
252+
253+
expressions = schema.expressions
254+
if not expressions:
255+
return cls({})
256+
257+
type_mapper_class = TYPE_MAPPERS.get(dialect, SqlglotType)
258+
type_mapper = type_mapper_class()
259+
fields = {}
260+
261+
for column in expressions:
262+
name = column.this.this
263+
264+
nullable = not any(
265+
isinstance(constraint.kind, sge.NotNullColumnConstraint)
266+
for constraint in (column.constraints or [])
267+
)
268+
269+
if column.kind:
270+
ibis_dtype = type_mapper.to_ibis(column.kind, nullable=nullable)
271+
else:
272+
ibis_dtype = dt.String(nullable=nullable)
273+
274+
fields[name] = ibis_dtype
275+
276+
return cls(fields)
277+
183278
def to_numpy(self) -> list[tuple[str, np.dtype]]:
184279
"""Return the equivalent numpy dtypes."""
185280
from ibis.formats.numpy import NumpySchema

ibis/expr/tests/test_schema.py

Lines changed: 129 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from typing import NamedTuple
55

66
import pytest
7+
import sqlglot as sg
8+
import sqlglot.expressions as sge
79

810
import ibis.expr.datatypes as dt
911
import ibis.expr.schema as sch
@@ -440,8 +442,6 @@ def test_null_fields():
440442

441443

442444
def test_to_sqlglot_column_defs():
443-
import sqlglot.expressions as sge
444-
445445
schema = sch.schema({"a": "int64", "b": "string", "c": "!string"})
446446
columns = schema.to_sqlglot_column_defs("duckdb")
447447

@@ -463,9 +463,6 @@ def test_to_sqlglot_column_defs_empty_schema():
463463

464464

465465
def test_to_sqlglot_column_defs_create_table_integration():
466-
import sqlglot as sg
467-
import sqlglot.expressions as sge
468-
469466
schema = sch.schema({"id": "!int64", "name": "string"})
470467
columns = schema.to_sqlglot_column_defs("duckdb")
471468

@@ -478,3 +475,130 @@ def test_to_sqlglot_column_defs_create_table_integration():
478475
sql = create_stmt.sql(dialect="duckdb")
479476
expected = 'CREATE TABLE "test_table" ("id" BIGINT NOT NULL, "name" TEXT)'
480477
assert sql == expected
478+
479+
480+
def test_schema_from_sqlglot():
481+
columns = [
482+
sge.ColumnDef(
483+
this=sg.to_identifier("bigint_col", quoted=True),
484+
kind=sge.DataType(this=sge.DataType.Type.BIGINT),
485+
),
486+
sge.ColumnDef(
487+
this=sg.to_identifier("int_col", quoted=True),
488+
kind=sge.DataType(this=sge.DataType.Type.INT),
489+
),
490+
sge.ColumnDef(
491+
this=sg.to_identifier("smallint_col", quoted=True),
492+
kind=sge.DataType(this=sge.DataType.Type.SMALLINT),
493+
),
494+
sge.ColumnDef(
495+
this=sg.to_identifier("tinyint_col", quoted=True),
496+
kind=sge.DataType(this=sge.DataType.Type.TINYINT),
497+
),
498+
sge.ColumnDef(
499+
this=sg.to_identifier("double_col", quoted=True),
500+
kind=sge.DataType(this=sge.DataType.Type.DOUBLE),
501+
),
502+
sge.ColumnDef(
503+
this=sg.to_identifier("float_col", quoted=True),
504+
kind=sge.DataType(this=sge.DataType.Type.FLOAT),
505+
),
506+
sge.ColumnDef(
507+
this=sg.to_identifier("varchar_col", quoted=True),
508+
kind=sge.DataType(this=sge.DataType.Type.VARCHAR),
509+
),
510+
sge.ColumnDef(
511+
this=sg.to_identifier("text_col", quoted=True),
512+
kind=sge.DataType(this=sge.DataType.Type.TEXT),
513+
),
514+
sge.ColumnDef(
515+
this=sg.to_identifier("boolean_col", quoted=True),
516+
kind=sge.DataType(this=sge.DataType.Type.BOOLEAN),
517+
),
518+
sge.ColumnDef(
519+
this=sg.to_identifier("date_col", quoted=True),
520+
kind=sge.DataType(this=sge.DataType.Type.DATE),
521+
),
522+
sge.ColumnDef(
523+
this=sg.to_identifier("timestamp_col", quoted=True),
524+
kind=sge.DataType(this=sge.DataType.Type.DATETIME),
525+
),
526+
sge.ColumnDef(
527+
this=sg.to_identifier("time_col", quoted=True),
528+
kind=sge.DataType(this=sge.DataType.Type.TIME),
529+
),
530+
sge.ColumnDef(
531+
this=sg.to_identifier("binary_col", quoted=True),
532+
kind=sge.DataType(this=sge.DataType.Type.BINARY),
533+
),
534+
sge.ColumnDef(
535+
this=sg.to_identifier("uuid_col", quoted=True),
536+
kind=sge.DataType(this=sge.DataType.Type.UUID),
537+
),
538+
sge.ColumnDef(
539+
this=sg.to_identifier("json_col", quoted=True),
540+
kind=sge.DataType(this=sge.DataType.Type.JSON),
541+
),
542+
sge.ColumnDef(
543+
this=sg.to_identifier("decimal_col", quoted=True),
544+
kind=sge.DataType(
545+
this=sge.DataType.Type.DECIMAL,
546+
expressions=[
547+
sge.DataTypeParam(this=sge.Literal.number(10)),
548+
sge.DataTypeParam(this=sge.Literal.number(2)),
549+
],
550+
),
551+
),
552+
sge.ColumnDef(
553+
this=sg.to_identifier("not_null_col", quoted=True),
554+
kind=sge.DataType(this=sge.DataType.Type.VARCHAR),
555+
constraints=[sge.ColumnConstraint(kind=sge.NotNullColumnConstraint())],
556+
),
557+
sge.ColumnDef(
558+
this=sg.to_identifier("array_col", quoted=True),
559+
kind=sge.DataType(
560+
this=sge.DataType.Type.ARRAY,
561+
expressions=[sge.DataType(this=sge.DataType.Type.VARCHAR)],
562+
nested=True,
563+
),
564+
),
565+
sge.ColumnDef(
566+
this=sg.to_identifier("map_col", quoted=True),
567+
kind=sge.DataType(
568+
this=sge.DataType.Type.MAP,
569+
expressions=[
570+
sge.DataType(this=sge.DataType.Type.VARCHAR),
571+
sge.DataType(this=sge.DataType.Type.INT),
572+
],
573+
nested=True,
574+
),
575+
),
576+
]
577+
578+
sqlglot_schema = sge.Schema(expressions=columns)
579+
ibis_schema = sch.Schema.from_sqlglot(sqlglot_schema)
580+
expected = sch.Schema(
581+
{
582+
"bigint_col": dt.int64,
583+
"int_col": dt.int32,
584+
"smallint_col": dt.int16,
585+
"tinyint_col": dt.int8,
586+
"double_col": dt.float64,
587+
"float_col": dt.float32,
588+
"varchar_col": dt.string,
589+
"text_col": dt.string,
590+
"boolean_col": dt.boolean,
591+
"date_col": dt.date,
592+
"timestamp_col": dt.timestamp,
593+
"time_col": dt.time,
594+
"binary_col": dt.binary,
595+
"uuid_col": dt.uuid,
596+
"json_col": dt.json,
597+
"decimal_col": dt.Decimal(10, 2),
598+
"not_null_col": dt.String(nullable=False),
599+
"array_col": dt.Array(dt.string),
600+
"map_col": dt.Map(dt.string, dt.int32),
601+
}
602+
)
603+
604+
assert ibis_schema == expected

0 commit comments

Comments
 (0)