Skip to content

Commit 421972a

Browse files
authored
refactor(duckdb): simplify loading and installation of extensions (#10900)
1 parent 8b9b143 commit 421972a

File tree

1 file changed

+26
-9
lines changed

1 file changed

+26
-9
lines changed

ibis/backends/duckdb/__init__.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import contextlib
77
import urllib
88
import warnings
9-
from operator import itemgetter
109
from pathlib import Path
1110
from typing import TYPE_CHECKING, Any, Literal
1211

@@ -450,17 +449,35 @@ def _load_extensions(
450449
) -> None:
451450
f = self.compiler.f
452451
query = (
453-
sg.select(f.anon.unnest(f.list_append(C.aliases, C.extension_name)))
452+
sg.select(
453+
f.anon.unnest(
454+
f.list_intersect(
455+
f.list_append(C.aliases, C.extension_name),
456+
f.list_value(*extensions),
457+
)
458+
),
459+
C.installed,
460+
C.loaded,
461+
)
454462
.from_(f.duckdb_extensions())
455-
.where(C.installed, C.loaded)
463+
.where(sg.not_(C.installed & C.loaded))
456464
)
457465
with self._safe_raw_sql(query) as cur:
458-
installed = map(itemgetter(0), cur.fetchall())
459-
# Install and load all other extensions
460-
todo = frozenset(extensions).difference(installed)
461-
for extension in todo:
462-
cur.install_extension(extension, force_install=force_install)
463-
cur.load_extension(extension)
466+
if not (not_installed_or_loaded := cur.fetchall()):
467+
return
468+
469+
commands = [
470+
"FORCE " * force_install + f"INSTALL '{extension}'"
471+
for extension, installed, _ in not_installed_or_loaded
472+
if not installed
473+
]
474+
commands.extend(
475+
f"LOAD '{extension}'"
476+
for extension, _, loaded in not_installed_or_loaded
477+
if not loaded
478+
)
479+
command = ";".join(commands)
480+
cur.execute(command)
464481

465482
def load_extension(self, extension: str, force_install: bool = False) -> None:
466483
"""Install and load a duckdb extension by name or path.

0 commit comments

Comments
 (0)