Skip to content

Commit ae49c48

Browse files
authored
fix: vectorizer_relationship with tuple table args (#805)
1 parent 1d23f92 commit ae49c48

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

projects/pgai/pgai/sqlalchemy/__init__.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Generic, TypeVar, overload
1+
from typing import Any, Generic, TypeVar, cast, overload
22

33
from pgvector.sqlalchemy import Vector # type: ignore
44
from sqlalchemy import ForeignKeyConstraint, Integer, Text, event, inspect
@@ -75,7 +75,18 @@ def _initialize_all(self):
7575
self.__get__(None, self.owner)
7676

7777
def set_schemas_correctly(self, owner: type[DeclarativeBase]) -> None:
78-
table_args_schema_name = getattr(owner, "__table_args__", {}).get("schema")
78+
table_args = getattr(owner, "__table_args__", {})
79+
table_args_schema_name: str | None = None
80+
81+
if isinstance(table_args, dict):
82+
table_args_schema_name = cast(str | None, table_args.get("schema")) # type: ignore
83+
elif (
84+
isinstance(table_args, tuple)
85+
and len(table_args) > 0 # type: ignore
86+
and isinstance(table_args[-1], dict)
87+
):
88+
table_args_schema_name = cast(str | None, table_args[-1].get("schema")) # type: ignore
89+
7990
self.target_schema = (
8091
self.target_schema
8192
or table_args_schema_name
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from sqlalchemy import Column, Integer, Text, UniqueConstraint
2+
from sqlalchemy.orm import DeclarativeBase
3+
4+
from pgai.sqlalchemy import vectorizer_relationship
5+
6+
7+
class Base(DeclarativeBase):
8+
pass
9+
10+
11+
class FeatureWithTupleTableArgs(Base):
12+
__tablename__ = "features_tuple"
13+
14+
id = Column(Integer, primary_key=True)
15+
name = Column(Text, nullable=False)
16+
tenant_id = Column(Integer, nullable=False)
17+
18+
__table_args__ = (UniqueConstraint("name", "tenant_id"),)
19+
20+
embeddings = vectorizer_relationship(
21+
dimensions=1536, target_table="features_embeddings"
22+
)
23+
24+
25+
def test_tuple_table_args():
26+
FeatureWithTupleTableArgs()
27+
embedding_class = FeatureWithTupleTableArgs.embeddings
28+
29+
assert embedding_class is not None

0 commit comments

Comments
 (0)