Skip to content

Commit 8710385

Browse files
simonwclaude
andauthored
Type fixes now enforced by ty
* Fix type warning for pipe.stdout possibly being None Add conditional check before calling .read() on pipe.stdout since Popen can return None for stdout. * Use ctx.meta instead of dynamic attribute for database cleanup Click's Context.meta dictionary is the proper way to store arbitrary data on the context object, avoiding type checker warnings about dynamic attribute assignment. * Add assert for tables.callback before calling Click's callback attribute is typed as Optional[Callable], so add assert to satisfy type checker that it's not None. * Fix type errors in cli.py and db.py - Add type annotation for Database.conn to fix context manager errors - Convert exception objects to str() when raising ClickException - Handle None return from find_spatialite() with proper error message * Fix remaining type errors in cli.py - Add typing import and type annotations for dict kwargs - Use db.table() instead of db[] for extract command - Fix missing str() conversion for exception * Fix type errors in db.py - Add type annotation for Database.conn - Add type: ignore for optional sqlite_dump import - Update execute/query parameter types to Sequence|Dict for sqlite3 compatibility - Use getattr for fn.__name__ access to handle callables without __name__ - Handle None return from find_spatialite() with OSError - Fix pk_values assignment to use local variable * Add type: ignore for optional pysqlite3 and sqlean imports These are alternative sqlite3 implementations that may not be installed. * Fix type errors in tests and plugins - Add type: ignore for monkey-patching Database.__init__ in conftest - Fix CLI test to pass string "2" instead of integer to Click invoke - Add type: ignore for optional sqlean import - Fix add_geometry_column test to use "XY" instead of integer 2 - Add type: ignore for click.Context as context manager - Add type: ignore for enable_fts test that intentionally omits argument - Add type: ignore for sys._called_from_test dynamic attribute - Fix rows_from_file test type error for intentional wrong argument - Handle None from pm.get_hookcallers in plugins.py * Use db.table() instead of db[] for Table-specific operations Changes db[table] to db.table(table) in CLI commands where we know we're working with tables, not views. This resolves most of the Table | View disambiguation type warnings since db.table() returns Table directly rather than Table | View. * Fix remaining type warnings in sqlite_utils package - Add assert for sniff_buffer not being None - Handle cursor.fetchone() potentially returning None - Use db.table() for counts_table and index_foreign_keys - Add type: ignore for cursor union type in raw mode * Ran Black * Run ty in CI * ty check sqlite_utils * Skip running ty on Windows --------- Co-authored-by: Claude Opus 4.5 <[email protected]>
1 parent fd5b09f commit 8710385

File tree

13 files changed

+89
-68
lines changed

13 files changed

+89
-68
lines changed

.github/workflows/test.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ jobs:
4545
run: mypy sqlite_utils tests
4646
- name: run flake8
4747
run: flake8
48+
- name: run ty
49+
if: matrix.os != 'windows-latest'
50+
run: |
51+
pip install uv
52+
uv run ty check sqlite_utils
4853
- name: Check formatting
4954
run: black . --check
5055
- name: Check if cog needs to be run

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def linkcode_resolve(domain, info):
7979
#
8080
# The short X.Y version.
8181
pipe = Popen("git describe --tags --always", stdout=PIPE, shell=True)
82-
git_version = pipe.stdout.read().decode("utf8")
82+
git_version = pipe.stdout.read().decode("utf8") if pipe.stdout else ""
8383

8484
if git_version:
8585
version = git_version.rsplit("-", 1)[0]

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ dev = [
4747
# flake8
4848
"flake8",
4949
"flake8-pyproject",
50+
"ty",
5051
]
5152
docs = [
5253
"beanbag-docutils>=2.0",

sqlite_utils/cli.py

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import base64
2+
from typing import Any
23
import click
34
from click_default_group import DefaultGroup # type: ignore
45
from datetime import datetime, timezone
@@ -50,20 +51,19 @@ def _register_db_for_cleanup(db):
5051
ctx = click.get_current_context(silent=True)
5152
if ctx is None:
5253
return
53-
if not hasattr(ctx, "_databases_to_close"):
54-
ctx._databases_to_close = []
54+
if "_databases_to_close" not in ctx.meta:
55+
ctx.meta["_databases_to_close"] = []
5556
ctx.call_on_close(lambda: _close_databases(ctx))
56-
ctx._databases_to_close.append(db)
57+
ctx.meta["_databases_to_close"].append(db)
5758

5859

5960
def _close_databases(ctx):
6061
"""Close all databases registered for cleanup."""
61-
if hasattr(ctx, "_databases_to_close"):
62-
for db in ctx._databases_to_close:
63-
try:
64-
db.close()
65-
except Exception:
66-
pass
62+
for db in ctx.meta.get("_databases_to_close", []):
63+
try:
64+
db.close()
65+
except Exception:
66+
pass
6767

6868

6969
VALID_COLUMN_TYPES = ("INTEGER", "TEXT", "FLOAT", "REAL", "BLOB")
@@ -294,6 +294,7 @@ def views(
294294
\b
295295
sqlite-utils views trees.db
296296
"""
297+
assert tables.callback is not None
297298
tables.callback(
298299
path=path,
299300
fts4=False,
@@ -338,7 +339,7 @@ def optimize(path, tables, no_vacuum, load_extension):
338339
tables = db.table_names(fts4=True) + db.table_names(fts5=True)
339340
with db.conn:
340341
for table in tables:
341-
db[table].optimize()
342+
db.table(table).optimize()
342343
if not no_vacuum:
343344
db.vacuum()
344345

@@ -366,7 +367,7 @@ def rebuild_fts(path, tables, load_extension):
366367
tables = db.table_names(fts4=True) + db.table_names(fts5=True)
367368
with db.conn:
368369
for table in tables:
369-
db[table].rebuild_fts()
370+
db.table(table).rebuild_fts()
370371

371372

372373
@cli.command()
@@ -393,7 +394,7 @@ def analyze(path, names):
393394
else:
394395
db.analyze()
395396
except OperationalError as e:
396-
raise click.ClickException(e)
397+
raise click.ClickException(str(e))
397398

398399

399400
@cli.command()
@@ -496,7 +497,7 @@ def add_column(
496497
_register_db_for_cleanup(db)
497498
_load_extensions(db, load_extension)
498499
try:
499-
db[table].add_column(
500+
db.table(table).add_column(
500501
col_name, col_type, fk=fk, fk_col=fk_col, not_null_default=not_null_default
501502
)
502503
except OperationalError as ex:
@@ -534,9 +535,11 @@ def add_foreign_key(
534535
_register_db_for_cleanup(db)
535536
_load_extensions(db, load_extension)
536537
try:
537-
db[table].add_foreign_key(column, other_table, other_column, ignore=ignore)
538+
db.table(table).add_foreign_key(
539+
column, other_table, other_column, ignore=ignore
540+
)
538541
except AlterError as e:
539-
raise click.ClickException(e)
542+
raise click.ClickException(str(e))
540543

541544

542545
@cli.command(name="add-foreign-keys")
@@ -571,7 +574,7 @@ def add_foreign_keys(path, foreign_key, load_extension):
571574
try:
572575
db.add_foreign_keys(tuples)
573576
except AlterError as e:
574-
raise click.ClickException(e)
577+
raise click.ClickException(str(e))
575578

576579

577580
@cli.command(name="index-foreign-keys")
@@ -644,7 +647,7 @@ def create_index(
644647
if col.startswith("-"):
645648
col = DescIndex(col[1:])
646649
columns.append(col)
647-
db[table].create_index(
650+
db.table(table).create_index(
648651
columns,
649652
index_name=name,
650653
unique=unique,
@@ -705,7 +708,7 @@ def enable_fts(
705708
replace=replace,
706709
)
707710
except OperationalError as ex:
708-
raise click.ClickException(ex)
711+
raise click.ClickException(str(ex))
709712

710713

711714
@cli.command(name="populate-fts")
@@ -728,7 +731,7 @@ def populate_fts(path, table, column, load_extension):
728731
db = sqlite_utils.Database(path)
729732
_register_db_for_cleanup(db)
730733
_load_extensions(db, load_extension)
731-
db[table].populate_fts(column)
734+
db.table(table).populate_fts(column)
732735

733736

734737
@cli.command(name="disable-fts")
@@ -750,7 +753,7 @@ def disable_fts(path, table, load_extension):
750753
db = sqlite_utils.Database(path)
751754
_register_db_for_cleanup(db)
752755
_load_extensions(db, load_extension)
753-
db[table].disable_fts()
756+
db.table(table).disable_fts()
754757

755758

756759
@cli.command(name="enable-wal")
@@ -826,7 +829,7 @@ def enable_counts(path, tables, load_extension):
826829
if bad_tables:
827830
raise click.ClickException("Invalid tables: {}".format(bad_tables))
828831
for table in tables:
829-
db[table].enable_counts()
832+
db.table(table).enable_counts()
830833

831834

832835
@cli.command(name="reset-counts")
@@ -1036,13 +1039,14 @@ def insert_upsert_implementation(
10361039
if csv or tsv:
10371040
if sniff:
10381041
# Read first 2048 bytes and use that to detect
1042+
assert sniff_buffer is not None
10391043
first_bytes = sniff_buffer.peek(2048)
10401044
dialect = csv_std.Sniffer().sniff(
10411045
first_bytes.decode(encoding, "ignore")
10421046
)
10431047
else:
10441048
dialect = "excel-tab" if tsv else "excel"
1045-
csv_reader_args = {"dialect": dialect}
1049+
csv_reader_args: dict[str, Any] = {"dialect": dialect}
10461050
if delimiter:
10471051
csv_reader_args["delimiter"] = delimiter
10481052
if quotechar:
@@ -1146,7 +1150,7 @@ def insert_upsert_implementation(
11461150
return
11471151

11481152
try:
1149-
db[table].insert_all(
1153+
db.table(table).insert_all(
11501154
docs, pk=pk, batch_size=batch_size, alter=alter, **extra_kwargs
11511155
)
11521156
except Exception as e:
@@ -1173,7 +1177,7 @@ def insert_upsert_implementation(
11731177
else:
11741178
raise
11751179
if tracker is not None:
1176-
db[table].transform(types=tracker.types)
1180+
db.table(table).transform(types=tracker.types)
11771181

11781182
# Clean up open file-like objects
11791183
if sniff_buffer:
@@ -1636,7 +1640,7 @@ def create_table(
16361640
table
16371641
)
16381642
)
1639-
db[table].create(
1643+
db.table(table).create(
16401644
coltypes,
16411645
pk=pks[0] if len(pks) == 1 else pks,
16421646
not_null=not_null,
@@ -1667,7 +1671,7 @@ def duplicate(path, table, new_table, ignore, load_extension):
16671671
_register_db_for_cleanup(db)
16681672
_load_extensions(db, load_extension)
16691673
try:
1670-
db[table].duplicate(new_table)
1674+
db.table(table).duplicate(new_table)
16711675
except NoTable:
16721676
if not ignore:
16731677
raise click.ClickException('Table "{}" does not exist'.format(table))
@@ -2028,9 +2032,9 @@ def memory(
20282032
if flatten:
20292033
rows = (_flatten(row) for row in rows)
20302034

2031-
db[file_table].insert_all(rows, alter=True)
2035+
db.table(file_table).insert_all(rows, alter=True)
20322036
if tracker is not None:
2033-
db[file_table].transform(types=tracker.types)
2037+
db.table(file_table).transform(types=tracker.types)
20342038
# Add convenient t / t1 / t2 views
20352039
view_names = ["t{}".format(i + 1)]
20362040
if i == 0:
@@ -2119,7 +2123,8 @@ def _execute_query(
21192123
else:
21202124
headers = [c[0] for c in cursor.description]
21212125
if raw:
2122-
data = cursor.fetchone()[0]
2126+
row = cursor.fetchone() # type: ignore[union-attr]
2127+
data = row[0] if row else None
21232128
if isinstance(data, bytes):
21242129
sys.stdout.buffer.write(data)
21252130
else:
@@ -2200,7 +2205,7 @@ def search(
22002205
_register_db_for_cleanup(db)
22012206
_load_extensions(db, load_extension)
22022207
# Check table exists
2203-
table_obj = db[dbtable]
2208+
table_obj = db.table(dbtable)
22042209
if not table_obj.exists():
22052210
raise click.ClickException("Table '{}' does not exist".format(dbtable))
22062211
if not table_obj.detect_fts():
@@ -2612,10 +2617,10 @@ def transform(
26122617
kwargs["add_foreign_keys"] = add_foreign_keys
26132618

26142619
if sql:
2615-
for line in db[table].transform_sql(**kwargs):
2620+
for line in db.table(table).transform_sql(**kwargs):
26162621
click.echo(line)
26172622
else:
2618-
db[table].transform(**kwargs)
2623+
db.table(table).transform(**kwargs)
26192624

26202625

26212626
@cli.command()
@@ -2656,13 +2661,13 @@ def extract(
26562661
db = sqlite_utils.Database(path)
26572662
_register_db_for_cleanup(db)
26582663
_load_extensions(db, load_extension)
2659-
kwargs = dict(
2664+
kwargs: dict[str, Any] = dict(
26602665
columns=columns,
26612666
table=other_table,
26622667
fk_column=fk_column,
26632668
rename=dict(rename),
26642669
)
2665-
db[table].extract(**kwargs)
2670+
db.table(table).extract(**kwargs)
26662671

26672672

26682673
@cli.command(name="insert-files")
@@ -2803,7 +2808,7 @@ def _content_text(p):
28032808
_load_extensions(db, load_extension)
28042809
try:
28052810
with db.conn:
2806-
db[table].insert_all(
2811+
db.table(table).insert_all(
28072812
to_insert(),
28082813
pk=pks[0] if len(pks) == 1 else pks,
28092814
alter=alter,
@@ -3122,7 +3127,7 @@ def wrapped_fn(value):
31223127

31233128
fn = wrapped_fn
31243129
try:
3125-
db[table].convert(
3130+
db.table(table).convert(
31263131
columns,
31273132
fn,
31283133
where=where,
@@ -3212,7 +3217,7 @@ def add_geometry_column(
32123217
_load_extensions(db, load_extension)
32133218
db.init_spatialite()
32143219

3215-
if db[table].add_geometry_column(
3220+
if db.table(table).add_geometry_column(
32163221
column_name, geometry_type, srid, coord_dimension, not_null
32173222
):
32183223
click.echo(f"Added {geometry_type} column {column_name} to {table}")
@@ -3250,7 +3255,7 @@ def create_spatial_index(db_path, table, column_name, load_extension):
32503255
"You must add a geometry column before creating a spatial index"
32513256
)
32523257

3253-
db[table].create_spatial_index(column_name)
3258+
db.table(table).create_spatial_index(column_name)
32543259

32553260

32563261
@cli.command(name="plugins")
@@ -3361,7 +3366,10 @@ def _load_extensions(db, load_extension):
33613366
db.conn.enable_load_extension(True)
33623367
for ext in load_extension:
33633368
if ext == "spatialite" and not os.path.exists(ext):
3364-
ext = find_spatialite()
3369+
found = find_spatialite()
3370+
if found is None:
3371+
raise click.ClickException("Could not find SpatiaLite extension")
3372+
ext = found
33653373
if ":" in ext:
33663374
path, _, entrypoint = ext.partition(":")
33673375
db.conn.execute("SELECT load_extension(?, ?)", [path, entrypoint])

0 commit comments

Comments
 (0)