Skip to content

Commit a544b75

Browse files
srittauAvasam
andauthored
[SQLAlchemy] Annotate row classes (#9568)
Co-authored-by: Avasam <[email protected]>
1 parent 1d15121 commit a544b75

File tree

3 files changed

+62
-40
lines changed

3 files changed

+62
-40
lines changed

stubs/SQLAlchemy/@tests/stubtest_allowlist.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ sqlalchemy.testing.provision.stop_test_class_outside_fixtures
5959
sqlalchemy.testing.provision.temp_table_keyword_args
6060
sqlalchemy.testing.provision.update_db_opts
6161

62+
# potentially replaced at runtime
63+
sqlalchemy.engine.Row.count
64+
sqlalchemy.engine.Row.index
65+
sqlalchemy.engine.row.Row.count
66+
sqlalchemy.engine.row.Row.index
67+
6268
# KeyError/AttributeError on import due to dynamic initialization from a different module
6369
sqlalchemy.testing.fixtures
6470
sqlalchemy.testing.pickleable
Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
1-
from typing import Any
1+
from _typeshed import Incomplete
2+
from collections.abc import Callable, Iterable, Iterator
3+
from typing import Any, overload
24

35
class BaseRow:
4-
def __init__(self, parent, processors, keymap, key_style, data) -> None: ...
5-
def __reduce__(self): ...
6-
def __iter__(self): ...
6+
def __init__(
7+
self,
8+
__parent,
9+
__processors: Iterable[Callable[[Any], Any]] | None,
10+
__keymap: dict[Incomplete, Incomplete],
11+
__key_style: int,
12+
__row: Iterable[Any],
13+
) -> None: ...
14+
def __reduce__(self) -> tuple[Incomplete, tuple[Incomplete, Incomplete]]: ...
15+
def __iter__(self) -> Iterator[Any]: ...
716
def __len__(self) -> int: ...
817
def __hash__(self) -> int: ...
9-
__getitem__: Any
18+
@overload
19+
def __getitem__(self, __key: str | int) -> tuple[Any, ...]: ...
20+
@overload
21+
def __getitem__(self, __key: slice) -> tuple[tuple[Any, ...]]: ...
1022

1123
def safe_rowproxy_reconstructor(__cls, __state): ...
Lines changed: 39 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
import abc
2-
from collections.abc import ItemsView, KeysView, Mapping, Sequence, ValuesView
3-
from typing import Any
1+
from collections.abc import ItemsView, Iterator, KeysView, Mapping, Sequence, ValuesView
2+
from typing import Any, Generic, TypeVar
43

54
from ..cresultproxy import BaseRow as BaseRow
65

6+
_VT_co = TypeVar("_VT_co", covariant=True)
7+
78
MD_INDEX: int
89

910
def rowproxy_reconstructor(cls, state): ...
@@ -13,45 +14,48 @@ KEY_OBJECTS_ONLY: int
1314
KEY_OBJECTS_BUT_WARN: int
1415
KEY_OBJECTS_NO_WARN: int
1516

16-
class Row(BaseRow, Sequence[Any], metaclass=abc.ABCMeta):
17+
class Row(BaseRow, Sequence[Any]):
18+
# The count and index methods are inherited from Sequence.
19+
# If the result set contains columns with the same names, these
20+
# fields contains their respective values, instead. We don't reflect
21+
# this in the stubs.
22+
__hash__ = BaseRow.__hash__ # type: ignore[assignment]
23+
def __lt__(self, other: Row | tuple[Any, ...]) -> bool: ...
24+
def __le__(self, other: Row | tuple[Any, ...]) -> bool: ...
25+
def __ge__(self, other: Row | tuple[Any, ...]) -> bool: ...
26+
def __gt__(self, other: Row | tuple[Any, ...]) -> bool: ...
27+
def __eq__(self, other: object) -> bool: ...
28+
def __ne__(self, other: object) -> bool: ...
29+
def keys(self) -> list[str]: ...
30+
# The following methods are public, but have a leading underscore
31+
# to prevent conflicts with column names.
1732
@property
18-
def count(self): ...
33+
def _mapping(self) -> RowMapping: ...
1934
@property
20-
def index(self): ...
21-
def __contains__(self, key): ...
22-
__hash__ = BaseRow.__hash__
23-
def __lt__(self, other): ...
24-
def __le__(self, other): ...
25-
def __ge__(self, other): ...
26-
def __gt__(self, other): ...
27-
def __eq__(self, other): ...
28-
def __ne__(self, other): ...
29-
def keys(self): ...
30-
31-
class LegacyRow(Row, metaclass=abc.ABCMeta):
32-
def __contains__(self, key): ...
33-
def has_key(self, key): ...
34-
def items(self): ...
35-
def iterkeys(self): ...
36-
def itervalues(self): ...
37-
def values(self): ...
35+
def _fields(self) -> tuple[str, ...]: ...
36+
def _asdict(self) -> dict[str, Any]: ...
37+
38+
class LegacyRow(Row):
39+
def has_key(self, key: str) -> bool: ...
40+
def items(self) -> list[tuple[str, Any]]: ...
41+
def iterkeys(self) -> Iterator[str]: ...
42+
def itervalues(self) -> Iterator[Any]: ...
43+
def values(self) -> list[Any]: ...
3844

3945
BaseRowProxy = BaseRow
4046
RowProxy = Row
4147

42-
class ROMappingView(KeysView[Any], ValuesView[Any], ItemsView[Any, Any]):
43-
def __init__(self, mapping, items) -> None: ...
48+
class ROMappingView(KeysView[str], ValuesView[_VT_co], ItemsView[str, _VT_co], Generic[_VT_co]): # type: ignore[misc]
49+
def __init__(self, mapping: RowMapping, items: list[_VT_co]) -> None: ...
4450
def __len__(self) -> int: ...
45-
def __iter__(self): ...
46-
def __contains__(self, item): ...
47-
def __eq__(self, other): ...
48-
def __ne__(self, other): ...
51+
def __iter__(self) -> Iterator[_VT_co]: ... # type: ignore[override]
52+
def __eq__(self, other: ROMappingView[_VT_co]) -> bool: ... # type: ignore[override]
53+
def __ne__(self, other: ROMappingView[_VT_co]) -> bool: ... # type: ignore[override]
4954

50-
class RowMapping(BaseRow, Mapping[Any, Any]):
55+
class RowMapping(BaseRow, Mapping[str, Row]):
5156
__getitem__: Any
52-
def __iter__(self): ...
57+
def __iter__(self) -> Iterator[str]: ...
5358
def __len__(self) -> int: ...
54-
def __contains__(self, key): ...
55-
def items(self): ...
56-
def keys(self): ...
57-
def values(self): ...
59+
def items(self) -> ROMappingView[tuple[str, Any]]: ... # type: ignore[override]
60+
def keys(self) -> list[str]: ... # type: ignore[override]
61+
def values(self) -> ROMappingView[Any]: ... # type: ignore[override]

0 commit comments

Comments
 (0)