Skip to content

Commit 8c3cf70

Browse files
committed
refactor(postgres): use raw sql for get_schema invocation to simplify code
1 parent dfb818a commit 8c3cf70

File tree

1 file changed

+29
-51
lines changed

1 file changed

+29
-51
lines changed

ibis/backends/postgres/__init__.py

Lines changed: 29 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -506,69 +506,47 @@ def get_schema(
506506
a = ColGen(table="a")
507507
c = ColGen(table="c")
508508
n = ColGen(table="n")
509-
t = ColGen(table="t")
510-
e = ColGen(table="e")
511509

512510
format_type = self.compiler.f["pg_catalog.format_type"]
513511

514512
# If no database is specified, assume the current database
515-
db = database or self.current_database
516-
517-
dbs = [sge.convert(db)]
513+
dbs = [database or self.current_database]
518514

519515
# If a database isn't specified, then include temp tables in the
520516
# returned values
521517
if database is None and (temp_table_db := self._session_temp_db) is not None:
522-
dbs.append(sge.convert(temp_table_db))
523-
524-
type_info = (
525-
sg.select(
526-
a.attname.as_("column_name"),
527-
sg.case()
528-
.when(
529-
sge.Exists(
530-
this=sg.select(1)
531-
.from_(sg.table("pg_type", db="pg_catalog").as_("t"))
532-
.join(
533-
sg.table("pg_enum", db="pg_catalog").as_("e"),
534-
on=sg.and_(
535-
e.enumtypid.eq(t.oid),
536-
t.typname.eq(format_type(a.atttypid, a.atttypmod)),
537-
),
538-
)
539-
),
540-
sge.convert("enum"),
541-
)
542-
.else_(format_type(a.atttypid, a.atttypmod))
543-
.as_("data_type"),
544-
sg.not_(a.attnotnull).as_("nullable"),
545-
)
546-
.from_(sg.table("pg_attribute", db="pg_catalog").as_("a"))
547-
.join(
548-
sg.table("pg_class", db="pg_catalog").as_("c"),
549-
on=c.oid.eq(a.attrelid),
550-
join_type="INNER",
551-
)
552-
.join(
553-
sg.table("pg_namespace", db="pg_catalog").as_("n"),
554-
on=n.oid.eq(c.relnamespace),
555-
join_type="INNER",
556-
)
557-
.where(
558-
a.attnum > 0,
559-
sg.not_(a.attisdropped),
560-
n.nspname.isin(*dbs),
561-
c.relname.eq(sge.convert(name)),
562-
)
563-
.order_by(a.attnum)
564-
.sql(self.dialect)
565-
)
566-
518+
dbs.append(temp_table_db)
519+
520+
type_info = """\
521+
SELECT
522+
a.attname AS column_name,
523+
CASE
524+
WHEN EXISTS(
525+
SELECT 1
526+
FROM pg_catalog.pg_type t
527+
INNER JOIN pg_catalog.pg_enum e
528+
ON e.enumtypid = t.oid
529+
AND t.typname = pg_catalog.format_type(a.atttypid, a.atttypmod)
530+
) THEN 'enum'
531+
ELSE pg_catalog.format_type(a.atttypid, a.atttypmod)
532+
END AS data_type,
533+
NOT a.attnotnull AS nullable
534+
FROM pg_catalog.pg_attribute a
535+
INNER JOIN pg_catalog.pg_class c
536+
ON a.attrelid = c.oid
537+
INNER JOIN pg_catalog.pg_namespace n
538+
ON c.relnamespace = n.oid
539+
WHERE a.attnum > 0
540+
AND NOT a.attisdropped
541+
AND n.nspname = ANY(%(dbs)s)
542+
AND c.relname = %(name)s
543+
ORDER BY a.attnum ASC"""
567544
type_mapper = self.compiler.type_mapper
568545

569546
con = self.con
547+
params = {"dbs": dbs, "name": name}
570548
with con.cursor() as cursor, con.transaction():
571-
rows = cursor.execute(type_info).fetchall()
549+
rows = cursor.execute(type_info, params, prepare=True).fetchall()
572550

573551
if not rows:
574552
raise com.TableNotFound(name)

0 commit comments

Comments
 (0)