diff --git a/.github/workflows/e2e-tests.yml b/.github/workflows/e2e-tests.yml new file mode 100644 index 0000000000..b39f2cd3ca --- /dev/null +++ b/.github/workflows/e2e-tests.yml @@ -0,0 +1,190 @@ +name: E2E Tests (Real Databases) + +on: + workflow_dispatch: # Manual trigger only for E2E tests + pull_request: + branches: [ main, dev ] + paths: + - 'lightrag/kg/postgres_impl.py' + - 'lightrag/kg/qdrant_impl.py' + - 'tests/test_e2e_*.py' + +jobs: + e2e-postgres: + name: E2E PostgreSQL Tests + runs-on: ubuntu-latest + + services: + postgres: + image: ankane/pgvector:latest + env: + POSTGRES_USER: lightrag + POSTGRES_PASSWORD: lightrag_test_password + POSTGRES_DB: lightrag_test + ports: + - 5432:5432 + options: >- + --health-cmd "pg_isready -U lightrag" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + strategy: + matrix: + python-version: ['3.10', '3.12'] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Cache pip packages + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-e2e-${{ hashFiles('**/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip-e2e- + ${{ runner.os }}-pip- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[api]" + pip install pytest pytest-asyncio asyncpg numpy qdrant-client + + - name: Wait for PostgreSQL + run: | + timeout 30 bash -c 'until pg_isready -h localhost -p 5432 -U lightrag; do sleep 1; done' + + - name: Setup pgvector extension + env: + PGPASSWORD: lightrag_test_password + run: | + psql -h localhost -U lightrag -d lightrag_test -c "CREATE EXTENSION IF NOT EXISTS vector;" + psql -h localhost -U lightrag -d lightrag_test -c "SELECT extname, extversion FROM pg_extension WHERE extname = 'vector';" + + - name: Run PostgreSQL E2E tests + env: + POSTGRES_HOST: localhost + POSTGRES_PORT: 5432 
+ POSTGRES_USER: lightrag + POSTGRES_PASSWORD: lightrag_test_password + POSTGRES_DATABASE: lightrag_test + run: | + pytest tests/test_e2e_multi_instance.py -k "postgres" -v --tb=short -s + timeout-minutes: 20 + + - name: Upload PostgreSQL test results + if: always() + uses: actions/upload-artifact@v4 + with: + name: e2e-postgres-results-py${{ matrix.python-version }} + path: | + .pytest_cache/ + test-results.xml + retention-days: 7 + + e2e-qdrant: + name: E2E Qdrant Tests + runs-on: ubuntu-latest + + services: + qdrant: + image: qdrant/qdrant:latest + ports: + - 6333:6333 + - 6334:6334 + + strategy: + matrix: + python-version: ['3.10', '3.12'] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Cache pip packages + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-e2e-${{ hashFiles('**/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip-e2e- + ${{ runner.os }}-pip- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[api]" + pip install pytest pytest-asyncio qdrant-client numpy + + - name: Wait for Qdrant + run: | + echo "Waiting for Qdrant to be ready..." + for i in {1..60}; do + if curl -s http://localhost:6333 > /dev/null 2>&1; then + echo "Qdrant is ready!" + break + fi + echo "Attempt $i/60: Qdrant not ready yet, waiting..." + sleep 1 + done + # Final check + if ! curl -s http://localhost:6333 > /dev/null 2>&1; then + echo "ERROR: Qdrant failed to start after 60 seconds" + exit 1 + fi + + - name: Verify Qdrant connection + run: | + echo "Verifying Qdrant API..." 
+ curl -X GET "http://localhost:6333/collections" -H "Content-Type: application/json" + echo "" + echo "Qdrant is accessible and ready for testing" + + - name: Run Qdrant E2E tests + env: + QDRANT_URL: http://localhost:6333 + QDRANT_API_KEY: "" + run: | + pytest tests/test_e2e_multi_instance.py -k "qdrant" -v --tb=short -s + timeout-minutes: 15 + + - name: Upload Qdrant test results + if: always() + uses: actions/upload-artifact@v4 + with: + name: e2e-qdrant-results-py${{ matrix.python-version }} + path: | + .pytest_cache/ + test-results.xml + retention-days: 7 + + e2e-summary: + name: E2E Test Summary + runs-on: ubuntu-latest + needs: [e2e-postgres, e2e-qdrant] + if: always() + + steps: + - name: Check test results + run: | + echo "## E2E Test Summary" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "### PostgreSQL E2E Tests" >> $GITHUB_STEP_SUMMARY + echo "Status: ${{ needs.e2e-postgres.result }}" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "### Qdrant E2E Tests" >> $GITHUB_STEP_SUMMARY + echo "Status: ${{ needs.e2e-qdrant.result }}" >> $GITHUB_STEP_SUMMARY + + - name: Fail if any test failed + if: needs.e2e-postgres.result != 'success' || needs.e2e-qdrant.result != 'success' + run: exit 1 diff --git a/.github/workflows/feature-tests.yml b/.github/workflows/feature-tests.yml new file mode 100644 index 0000000000..f46ebcf364 --- /dev/null +++ b/.github/workflows/feature-tests.yml @@ -0,0 +1,74 @@ +name: Feature Branch Tests + +on: + workflow_dispatch: # Allow manual trigger + push: + branches: + - 'feature/**' + pull_request: + branches: [ main, dev ] + +jobs: + migration-tests: + name: Vector Storage Migration Tests + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ['3.10', '3.11', '3.12'] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Cache pip packages + uses: 
actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt', '**/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[api]" + pip install pytest pytest-asyncio + + - name: Run Qdrant migration tests + run: | + pytest tests/test_qdrant_migration.py -v --tb=short + continue-on-error: false + + - name: Run PostgreSQL migration tests + run: | + pytest tests/test_postgres_migration.py -v --tb=short + continue-on-error: false + + - name: Run all unit tests (if exists) + run: | + # Run EmbeddingFunc tests + pytest tests/ -k "embedding" -v --tb=short || true + continue-on-error: true + + - name: Upload test results + if: always() + uses: actions/upload-artifact@v4 + with: + name: migration-test-results-py${{ matrix.python-version }} + path: | + .pytest_cache/ + test-results.xml + retention-days: 7 + + - name: Test Summary + if: always() + run: | + echo "## Test Summary" >> $GITHUB_STEP_SUMMARY + echo "- Python: ${{ matrix.python-version }}" >> $GITHUB_STEP_SUMMARY + echo "- Branch: ${{ github.ref_name }}" >> $GITHUB_STEP_SUMMARY + echo "- Commit: ${{ github.sha }}" >> $GITHUB_STEP_SUMMARY diff --git a/examples/multi_model_demo.py b/examples/multi_model_demo.py new file mode 100644 index 0000000000..000c841c9d --- /dev/null +++ b/examples/multi_model_demo.py @@ -0,0 +1,271 @@ +""" +Multi-Model Vector Storage Isolation Demo + +This example demonstrates LightRAG's automatic model isolation feature for vector storage. +When using different embedding models, LightRAG automatically creates separate collections/tables, +preventing dimension mismatches and data pollution. 
+ +Key Features: +- Automatic model suffix generation: {model_name}_{dim}d +- Seamless migration from legacy (no-suffix) to new (with-suffix) collections +- Support for multiple workspaces with different embedding models + +Requirements: +- OpenAI API key (or any OpenAI-compatible API) +- Qdrant or PostgreSQL for vector storage (optional, defaults to NanoVectorDB) +""" + +import asyncio +from lightrag import LightRAG, QueryParam +from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed +from lightrag.utils import EmbeddingFunc + +# Set your API key +# os.environ["OPENAI_API_KEY"] = "your-api-key-here" + + +async def scenario_1_new_workspace_with_explicit_model(): + """ + Scenario 1: Creating a new workspace with explicit model name + + Result: Creates collection/table with name like: + - Qdrant: lightrag_vdb_chunks_text_embedding_3_large_3072d + - PostgreSQL: LIGHTRAG_VDB_CHUNKS_text_embedding_3_large_3072d + """ + print("\n" + "=" * 80) + print("Scenario 1: New Workspace with Explicit Model Name") + print("=" * 80) + + # Define custom embedding function with explicit model name + async def my_embedding_func(texts: list[str]): + return await openai_embed(texts, model="text-embedding-3-large") + + # Create EmbeddingFunc with model_name specified + embedding_func = EmbeddingFunc( + embedding_dim=3072, + func=my_embedding_func, + model_name="text-embedding-3-large", # Explicit model name + ) + + rag = LightRAG( + working_dir="./workspace_large_model", + llm_model_func=gpt_4o_mini_complete, + embedding_func=embedding_func, + ) + + await rag.initialize_storages() + + # Insert sample data + await rag.ainsert("LightRAG supports automatic model isolation for vector storage.") + + # Query + result = await rag.aquery( + "What does LightRAG support?", param=QueryParam(mode="hybrid") + ) + + print(f"\nQuery Result: {result[:200]}...") + print("\n✅ Collection/table created with suffix: text_embedding_3_large_3072d") + + await rag.close() + + +async def 
scenario_2_legacy_migration(): + """ + Scenario 2: Upgrading from legacy version (without model_name) + + If you previously used LightRAG without specifying model_name, + the first run with model_name will automatically migrate your data. + + Result: Data is migrated from: + - Old: lightrag_vdb_chunks (no suffix) + - New: lightrag_vdb_chunks_text_embedding_ada_002_1536d (with suffix) + """ + print("\n" + "=" * 80) + print("Scenario 2: Automatic Migration from Legacy Format") + print("=" * 80) + + # Step 1: Simulate legacy workspace (no model_name) + print("\n[Step 1] Creating legacy workspace without model_name...") + + async def legacy_embedding_func(texts: list[str]): + return await openai_embed(texts, model="text-embedding-ada-002") + + # Legacy: No model_name specified + legacy_embedding = EmbeddingFunc( + embedding_dim=1536, + func=legacy_embedding_func, + # model_name not specified → uses "unknown" as fallback + ) + + rag_legacy = LightRAG( + working_dir="./workspace_legacy", + llm_model_func=gpt_4o_mini_complete, + embedding_func=legacy_embedding, + ) + + await rag_legacy.initialize_storages() + await rag_legacy.ainsert("Legacy data without model isolation.") + await rag_legacy.close() + + print("✅ Legacy workspace created with suffix: unknown_1536d") + + # Step 2: Upgrade to new version with model_name + print("\n[Step 2] Upgrading to new version with explicit model_name...") + + # New: With model_name specified + new_embedding = EmbeddingFunc( + embedding_dim=1536, + func=legacy_embedding_func, + model_name="text-embedding-ada-002", # Now explicitly specified + ) + + rag_new = LightRAG( + working_dir="./workspace_legacy", # Same working directory + llm_model_func=gpt_4o_mini_complete, + embedding_func=new_embedding, + ) + + # On first initialization, LightRAG will: + # 1. Detect legacy collection exists + # 2. Automatically migrate data to new collection with model suffix + # 3. 
Legacy collection remains but can be deleted after verification + await rag_new.initialize_storages() + + # Verify data is still accessible + result = await rag_new.aquery( + "What is the legacy data?", param=QueryParam(mode="hybrid") + ) + + print(f"\nQuery Result: {result[:200] if result else 'No results'}...") + print("\n✅ Data migrated to: text_embedding_ada_002_1536d") + print("ℹ️ Legacy collection can be manually deleted after verification") + + await rag_new.close() + + +async def scenario_3_multiple_models_coexistence(): + """ + Scenario 3: Multiple workspaces with different embedding models + + Different embedding models create completely isolated collections/tables, + allowing safe coexistence without dimension conflicts or data pollution. + + Result: + - Workspace A: lightrag_vdb_chunks_bge_small_768d + - Workspace B: lightrag_vdb_chunks_bge_large_1024d + """ + print("\n" + "=" * 80) + print("Scenario 3: Multiple Models Coexistence") + print("=" * 80) + + # Workspace A: Small embedding model (768 dimensions) + print("\n[Workspace A] Using bge-small model (768d)...") + + async def embedding_func_small(texts: list[str]): + # Simulate small embedding model + # In real usage, replace with actual model call + return await openai_embed(texts, model="text-embedding-3-small") + + embedding_a = EmbeddingFunc( + embedding_dim=1536, # text-embedding-3-small dimension + func=embedding_func_small, + model_name="text-embedding-3-small", + ) + + rag_a = LightRAG( + working_dir="./workspace_a", + llm_model_func=gpt_4o_mini_complete, + embedding_func=embedding_a, + ) + + await rag_a.initialize_storages() + await rag_a.ainsert("Workspace A uses small embedding model for efficiency.") + + print("✅ Workspace A created with suffix: text_embedding_3_small_1536d") + + # Workspace B: Large embedding model (3072 dimensions) + print("\n[Workspace B] Using text-embedding-3-large model (3072d)...") + + async def embedding_func_large(texts: list[str]): + # Simulate large embedding 
model + return await openai_embed(texts, model="text-embedding-3-large") + + embedding_b = EmbeddingFunc( + embedding_dim=3072, # text-embedding-3-large dimension + func=embedding_func_large, + model_name="text-embedding-3-large", + ) + + rag_b = LightRAG( + working_dir="./workspace_b", + llm_model_func=gpt_4o_mini_complete, + embedding_func=embedding_b, + ) + + await rag_b.initialize_storages() + await rag_b.ainsert("Workspace B uses large embedding model for better accuracy.") + + print("✅ Workspace B created with suffix: text_embedding_3_large_3072d") + + # Verify isolation: Query each workspace + print("\n[Verification] Querying both workspaces...") + + result_a = await rag_a.aquery( + "What model does workspace use?", param=QueryParam(mode="hybrid") + ) + result_b = await rag_b.aquery( + "What model does workspace use?", param=QueryParam(mode="hybrid") + ) + + print(f"\nWorkspace A Result: {result_a[:100] if result_a else 'No results'}...") + print(f"Workspace B Result: {result_b[:100] if result_b else 'No results'}...") + + print("\n✅ Both workspaces operate independently without interference") + + await rag_a.close() + await rag_b.close() + + +async def main(): + """ + Run all scenarios to demonstrate model isolation features + """ + print("\n" + "=" * 80) + print("LightRAG Multi-Model Vector Storage Isolation Demo") + print("=" * 80) + print("\nThis demo shows how LightRAG automatically handles:") + print("1. ✅ Automatic model suffix generation") + print("2. ✅ Seamless data migration from legacy format") + print("3. 
✅ Multiple embedding models coexistence") + + try: + # Scenario 1: New workspace with explicit model + await scenario_1_new_workspace_with_explicit_model() + + # Scenario 2: Legacy migration + await scenario_2_legacy_migration() + + # Scenario 3: Multiple models coexistence + await scenario_3_multiple_models_coexistence() + + print("\n" + "=" * 80) + print("✅ All scenarios completed successfully!") + print("=" * 80) + + print("\n📝 Key Takeaways:") + print("- Always specify `model_name` in EmbeddingFunc for clear model tracking") + print("- LightRAG automatically migrates legacy data on first run") + print("- Different embedding models create isolated collections/tables") + print("- Collection names follow pattern: {base_name}_{model_name}_{dim}d") + print("\n📚 See the plan document for more details:") + print(" .claude/plan/PR-vector-model-isolation.md") + + except Exception as e: + print(f"\n❌ Error: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/lightrag/base.py b/lightrag/base.py index bae0728b3d..9f891a7c89 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -220,6 +220,45 @@ class BaseVectorStorage(StorageNameSpace, ABC): cosine_better_than_threshold: float = field(default=0.2) meta_fields: set[str] = field(default_factory=set) + def _generate_collection_suffix(self) -> str: + """Generates collection/table suffix from embedding_func. + + Returns: + str: Suffix string, e.g. 
"text_embedding_3_large_3072d" + """ + # Try to get model identifier from the embedding function + # If it's a wrapped function (doesn't have get_model_identifier), + # fallback to the original embedding_func from global_config + if hasattr(self.embedding_func, "get_model_identifier"): + return self.embedding_func.get_model_identifier() + elif "embedding_func" in self.global_config: + original_embedding_func = self.global_config["embedding_func"] + if original_embedding_func is not None and hasattr( + original_embedding_func, "get_model_identifier" + ): + return original_embedding_func.get_model_identifier() + else: + # Debug: log why we couldn't get model identifier + from lightrag.utils import logger + + logger.debug( + f"Could not get model_identifier: embedding_func is {type(original_embedding_func)}, has method={hasattr(original_embedding_func, 'get_model_identifier') if original_embedding_func else False}" + ) + + # Fallback: no model identifier available + return "" + + def _get_legacy_collection_name(self) -> str: + """Get legacy collection/table name (without suffix). + + Used for data migration detection. 
+ """ + raise NotImplementedError("Subclasses must implement this method") + + def _get_new_collection_name(self) -> str: + """Get new collection/table name (with suffix).""" + raise NotImplementedError("Subclasses must implement this method") + @abstractmethod async def query( self, query: str, top_k: int, query_embedding: list[float] = None diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 1447a79e14..c9ea40f766 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1163,23 +1163,9 @@ async def check_tables(self): except Exception as e: logger.error(f"PostgreSQL, Failed to batch check/create indexes: {e}") - # Create vector indexs - if self.vector_index_type: - logger.info( - f"PostgreSQL, Create vector indexs, type: {self.vector_index_type}" - ) - try: - if self.vector_index_type in ["HNSW", "IVFFLAT", "VCHORDRQ"]: - await self._create_vector_indexes() - else: - logger.warning( - "Doesn't support this vector index type: {self.vector_index_type}. " - "Supported types: HNSW, IVFFLAT, VCHORDRQ" - ) - except Exception as e: - logger.error( - f"PostgreSQL, Failed to create vector index, type: {self.vector_index_type}, Got: {e}" - ) + # NOTE: Vector index creation moved to PGVectorStorage.setup_table() + # Each vector storage instance creates its own index with correct embedding_dim + # After all tables are created, attempt to migrate timestamp fields try: await self._migrate_timestamp_columns() @@ -1381,64 +1367,72 @@ async def _create_pagination_indexes(self): except Exception as e: logger.warning(f"Failed to create index {index['name']}: {e}") - async def _create_vector_indexes(self): - vdb_tables = [ - "LIGHTRAG_VDB_CHUNKS", - "LIGHTRAG_VDB_ENTITY", - "LIGHTRAG_VDB_RELATION", - ] + async def _create_vector_index(self, table_name: str, embedding_dim: int): + """ + Create vector index for a specific table. 
+ + Args: + table_name: Name of the table to create index on + embedding_dim: Embedding dimension for the vector column + """ + if not self.vector_index_type: + return create_sql = { "HNSW": f""" CREATE INDEX {{vector_index_name}} - ON {{k}} USING hnsw (content_vector vector_cosine_ops) + ON {{table_name}} USING hnsw (content_vector vector_cosine_ops) WITH (m = {self.hnsw_m}, ef_construction = {self.hnsw_ef}) """, "IVFFLAT": f""" CREATE INDEX {{vector_index_name}} - ON {{k}} USING ivfflat (content_vector vector_cosine_ops) + ON {{table_name}} USING ivfflat (content_vector vector_cosine_ops) WITH (lists = {self.ivfflat_lists}) """, "VCHORDRQ": f""" CREATE INDEX {{vector_index_name}} - ON {{k}} USING vchordrq (content_vector vector_cosine_ops) - {f'WITH (options = $${self.vchordrq_build_options}$$)' if self.vchordrq_build_options else ''} + ON {{table_name}} USING vchordrq (content_vector vector_cosine_ops) + {f"WITH (options = $${self.vchordrq_build_options}$$)" if self.vchordrq_build_options else ""} """, } - embedding_dim = int(os.environ.get("EMBEDDING_DIM", 1024)) - for k in vdb_tables: - vector_index_name = ( - f"idx_{k.lower()}_{self.vector_index_type.lower()}_cosine" + if self.vector_index_type not in create_sql: + logger.warning( + f"Unsupported vector index type: {self.vector_index_type}. 
" + "Supported types: HNSW, IVFFLAT, VCHORDRQ" ) - check_vector_index_sql = f""" - SELECT 1 FROM pg_indexes - WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}' - """ - try: - vector_index_exists = await self.query(check_vector_index_sql) - if not vector_index_exists: - # Only set vector dimension when index doesn't exist - alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})" - await self.execute(alter_sql) - logger.debug(f"Ensured vector dimension for {k}") - logger.info( - f"Creating {self.vector_index_type} index {vector_index_name} on table {k}" - ) - await self.execute( - create_sql[self.vector_index_type].format( - vector_index_name=vector_index_name, k=k - ) - ) - logger.info( - f"Successfully created vector index {vector_index_name} on table {k}" - ) - else: - logger.info( - f"{self.vector_index_type} vector index {vector_index_name} already exists on table {k}" + return + + k = table_name + vector_index_name = f"idx_{k.lower()}_{self.vector_index_type.lower()}_cosine" + check_vector_index_sql = f""" + SELECT 1 FROM pg_indexes + WHERE indexname = '{vector_index_name}' AND tablename = '{k.lower()}' + """ + try: + vector_index_exists = await self.query(check_vector_index_sql) + if not vector_index_exists: + # Only set vector dimension when index doesn't exist + alter_sql = f"ALTER TABLE {k} ALTER COLUMN content_vector TYPE VECTOR({embedding_dim})" + await self.execute(alter_sql) + logger.debug(f"Ensured vector dimension for {k}") + logger.info( + f"Creating {self.vector_index_type} index {vector_index_name} on table {k}" + ) + await self.execute( + create_sql[self.vector_index_type].format( + vector_index_name=vector_index_name, table_name=k ) - except Exception as e: - logger.error(f"Failed to create vector index on table {k}, Got: {e}") + ) + logger.info( + f"Successfully created vector index {vector_index_name} on table {k}" + ) + else: + logger.info( + f"{self.vector_index_type} vector index 
{vector_index_name} already exists on table {k}" + ) + except Exception as e: + logger.error(f"Failed to create vector index on table {k}, Got: {e}") async def query( self, @@ -2175,6 +2169,90 @@ async def drop(self) -> dict[str, str]: return {"status": "error", "message": str(e)} +async def _pg_table_exists(db: PostgreSQLDB, table_name: str) -> bool: + """Check if a table exists in PostgreSQL database""" + query = """ + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = $1 + ) + """ + result = await db.query(query, [table_name.lower()]) + return result.get("exists", False) if result else False + + +async def _pg_create_table( + db: PostgreSQLDB, table_name: str, base_table: str, embedding_dim: int +) -> None: + """Create a new vector table by replacing the table name in DDL template""" + if base_table not in TABLES: + raise ValueError(f"No DDL template found for table: {base_table}") + + ddl_template = TABLES[base_table]["ddl"] + + # Replace embedding dimension placeholder if exists + ddl = ddl_template.replace( + f"VECTOR({os.environ.get('EMBEDDING_DIM', 1024)})", f"VECTOR({embedding_dim})" + ) + + # Replace table name + ddl = ddl.replace(base_table, table_name) + + await db.execute(ddl) + + +async def _pg_migrate_workspace_data( + db: PostgreSQLDB, + legacy_table_name: str, + new_table_name: str, + workspace: str, + expected_count: int, + embedding_dim: int, +) -> int: + """Migrate workspace data from legacy table to new table""" + migrated_count = 0 + offset = 0 + batch_size = 500 + + while True: + if workspace: + select_query = f"SELECT * FROM {legacy_table_name} WHERE workspace = $1 OFFSET $2 LIMIT $3" + rows = await db.query( + select_query, [workspace, offset, batch_size], multirows=True + ) + else: + select_query = f"SELECT * FROM {legacy_table_name} OFFSET $1 LIMIT $2" + rows = await db.query(select_query, [offset, batch_size], multirows=True) + + if not rows: + break + + for row in rows: + row_dict = dict(row) + columns = 
list(row_dict.keys()) + columns_str = ", ".join(columns) + placeholders = ", ".join([f"${i + 1}" for i in range(len(columns))]) + insert_query = f""" + INSERT INTO {new_table_name} ({columns_str}) + VALUES ({placeholders}) + ON CONFLICT (workspace, id) DO NOTHING + """ + # Rebuild dict in columns order to ensure values() matches placeholders order + # Python 3.7+ dicts maintain insertion order, and execute() uses tuple(data.values()) + values = {col: row_dict[col] for col in columns} + await db.execute(insert_query, values) + + migrated_count += len(rows) + workspace_info = f" for workspace '{workspace}'" if workspace else "" + logger.info( + f"PostgreSQL: {migrated_count}/{expected_count} records migrated{workspace_info}" + ) + + offset += batch_size + + return migrated_count + + @final @dataclass class PGVectorStorage(BaseVectorStorage): @@ -2190,6 +2268,412 @@ def __post_init__(self): ) self.cosine_better_than_threshold = cosine_threshold + # Generate model suffix for table isolation + self.model_suffix = self._generate_collection_suffix() + + # Get base table name + base_table = namespace_to_table_name(self.namespace) + if not base_table: + raise ValueError(f"Unknown namespace: {self.namespace}") + + # New table name (with suffix) + # Ensure model_suffix is not empty before appending + if self.model_suffix: + self.table_name = f"{base_table}_{self.model_suffix}" + else: + # Fallback: use base table name if model_suffix is unavailable + self.table_name = base_table + logger.warning( + f"Model suffix unavailable, using base table name '{base_table}'. " + f"Ensure embedding_func has model_name for proper model isolation." 
+ ) + + # Legacy table name (without suffix, for migration) + self.legacy_table_name = base_table + + logger.debug( + f"PostgreSQL table naming: " + f"new='{self.table_name}', " + f"legacy='{self.legacy_table_name}', " + f"model_suffix='{self.model_suffix}'" + ) + + @staticmethod + async def setup_table( + db: PostgreSQLDB, + table_name: str, + legacy_table_name: str = None, + base_table: str = None, + embedding_dim: int = None, + workspace: str = None, + ): + """ + Setup PostgreSQL table with migration support from legacy tables. + + This method mirrors Qdrant's setup_collection approach to maintain consistency. + + Args: + db: PostgreSQLDB instance + table_name: Name of the new table + legacy_table_name: Name of the legacy table (if exists) + base_table: Base table name for DDL template lookup + embedding_dim: Embedding dimension for vector column + """ + new_table_exists = await _pg_table_exists(db, table_name) + legacy_exists = legacy_table_name and await _pg_table_exists( + db, legacy_table_name + ) + + # Case 1: Both new and legacy tables exist + if new_table_exists and legacy_exists: + if table_name.lower() == legacy_table_name.lower(): + logger.debug( + f"PostgreSQL: Table '{table_name}' already exists (no model suffix). Skipping Case 1 cleanup." + ) + return + + try: + workspace_info = f" for workspace '{workspace}'" if workspace else "" + + if workspace: + count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1" + count_result = await db.query(count_query, [workspace]) + else: + count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}" + count_result = await db.query(count_query, []) + + workspace_count = count_result.get("count", 0) if count_result else 0 + + if workspace_count > 0: + logger.info( + f"PostgreSQL: Found {workspace_count} records in legacy table{workspace_info}. Migrating..." 
+ ) + + legacy_dim = None + try: + dim_query = """ + SELECT + CASE + WHEN typname = 'vector' THEN + COALESCE(atttypmod, -1) + ELSE -1 + END as vector_dim + FROM pg_attribute a + JOIN pg_type t ON a.atttypid = t.oid + WHERE a.attrelid = $1::regclass + AND a.attname = 'content_vector' + """ + dim_result = await db.query(dim_query, [legacy_table_name]) + legacy_dim = ( + dim_result.get("vector_dim", -1) if dim_result else -1 + ) + + if legacy_dim <= 0: + sample_query = f"SELECT content_vector FROM {legacy_table_name} LIMIT 1" + sample_result = await db.query(sample_query, []) + if sample_result and sample_result.get("content_vector"): + vector_data = sample_result["content_vector"] + if isinstance(vector_data, (list, tuple)): + legacy_dim = len(vector_data) + elif isinstance(vector_data, str): + import json + + vector_list = json.loads(vector_data) + legacy_dim = len(vector_list) + + if ( + legacy_dim > 0 + and embedding_dim + and legacy_dim != embedding_dim + ): + logger.warning( + f"PostgreSQL: Dimension mismatch - " + f"legacy table has {legacy_dim}d vectors, " + f"new embedding model expects {embedding_dim}d. " + f"Skipping migration{workspace_info}." + ) + await db._create_vector_index(table_name, embedding_dim) + return + + except Exception as e: + logger.warning( + f"PostgreSQL: Could not verify vector dimension: {e}. Proceeding with caution..." 
+ ) + + migrated_count = await _pg_migrate_workspace_data( + db, + legacy_table_name, + table_name, + workspace, + workspace_count, + embedding_dim, + ) + + if workspace: + new_count_query = f"SELECT COUNT(*) as count FROM {table_name} WHERE workspace = $1" + new_count_result = await db.query(new_count_query, [workspace]) + else: + new_count_query = f"SELECT COUNT(*) as count FROM {table_name}" + new_count_result = await db.query(new_count_query, []) + + new_count = ( + new_count_result.get("count", 0) if new_count_result else 0 + ) + + if new_count < workspace_count: + logger.warning( + f"PostgreSQL: Expected {workspace_count} records, found {new_count}{workspace_info}. " + f"Some records may have been skipped due to conflicts." + ) + else: + logger.info( + f"PostgreSQL: Migration completed: {migrated_count} records migrated{workspace_info}" + ) + + if workspace: + delete_query = ( + f"DELETE FROM {legacy_table_name} WHERE workspace = $1" + ) + await db.execute(delete_query, {"workspace": workspace}) + logger.info( + f"PostgreSQL: Deleted workspace '{workspace}' data from legacy table" + ) + + total_count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}" + total_count_result = await db.query(total_count_query, []) + total_count = ( + total_count_result.get("count", 0) if total_count_result else 0 + ) + + if total_count == 0: + logger.info( + f"PostgreSQL: Legacy table '{legacy_table_name}' is empty. Deleting..." + ) + drop_query = f"DROP TABLE {legacy_table_name}" + await db.execute(drop_query, None) + logger.info( + f"PostgreSQL: Legacy table '{legacy_table_name}' deleted successfully" + ) + else: + logger.info( + f"PostgreSQL: Legacy table '{legacy_table_name}' preserved " + f"({total_count} records from other workspaces remain)" + ) + + except Exception as e: + logger.warning( + f"PostgreSQL: Error during Case 1 migration: {e}. Vector index will still be ensured." 
+ ) + + await db._create_vector_index(table_name, embedding_dim) + return + + # Case 2: Only new table exists - Already migrated or newly created + if new_table_exists: + logger.debug(f"PostgreSQL: Table '{table_name}' already exists") + # Ensure vector index exists with correct embedding dimension + await db._create_vector_index(table_name, embedding_dim) + return + + # Case 3: Neither exists - Create new table + if not legacy_exists: + logger.info(f"PostgreSQL: Creating new table '{table_name}'") + await _pg_create_table(db, table_name, base_table, embedding_dim) + logger.info(f"PostgreSQL: Table '{table_name}' created successfully") + # Create vector index with correct embedding dimension + await db._create_vector_index(table_name, embedding_dim) + return + + # Case 4: Only legacy exists - Migrate data + logger.info( + f"PostgreSQL: Migrating data from legacy table '{legacy_table_name}'" + ) + + try: + # Get legacy table count (with workspace filtering) + if workspace: + count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name} WHERE workspace = $1" + count_result = await db.query(count_query, [workspace]) + else: + count_query = f"SELECT COUNT(*) as count FROM {legacy_table_name}" + count_result = await db.query(count_query, []) + logger.warning( + "PostgreSQL: Migration without workspace filter - this may copy data from all workspaces!" 
+ ) + + legacy_count = count_result.get("count", 0) if count_result else 0 + workspace_info = f" for workspace '{workspace}'" if workspace else "" + logger.info( + f"PostgreSQL: Found {legacy_count} records in legacy table{workspace_info}" + ) + + if legacy_count == 0: + logger.info("PostgreSQL: Legacy table is empty, skipping migration") + await _pg_create_table(db, table_name, base_table, embedding_dim) + # Create vector index with correct embedding dimension + await db._create_vector_index(table_name, embedding_dim) + return + + # Check vector dimension compatibility before migration + legacy_dim = None + try: + # Try to get vector dimension from pg_attribute metadata + dim_query = """ + SELECT + CASE + WHEN typname = 'vector' THEN + COALESCE(atttypmod, -1) + ELSE -1 + END as vector_dim + FROM pg_attribute a + JOIN pg_type t ON a.atttypid = t.oid + WHERE a.attrelid = $1::regclass + AND a.attname = 'content_vector' + """ + dim_result = await db.query(dim_query, [legacy_table_name]) + legacy_dim = dim_result.get("vector_dim", -1) if dim_result else -1 + + if legacy_dim <= 0: + # Alternative: Try to detect by sampling a vector + logger.info( + "PostgreSQL: Metadata dimension check failed, trying vector sampling..." + ) + sample_query = ( + f"SELECT content_vector FROM {legacy_table_name} LIMIT 1" + ) + sample_result = await db.query(sample_query, []) + if sample_result and sample_result.get("content_vector"): + vector_data = sample_result["content_vector"] + # pgvector returns list directly + if isinstance(vector_data, (list, tuple)): + legacy_dim = len(vector_data) + elif isinstance(vector_data, str): + import json + + vector_list = json.loads(vector_data) + legacy_dim = len(vector_list) + + if legacy_dim > 0 and embedding_dim and legacy_dim != embedding_dim: + logger.warning( + f"PostgreSQL: Dimension mismatch detected! " + f"Legacy table '{legacy_table_name}' has {legacy_dim}d vectors, " + f"but new embedding model expects {embedding_dim}d. 
" + f"Migration skipped to prevent data loss. " + f"Legacy table preserved as '{legacy_table_name}'. " + f"Creating new empty table '{table_name}' for new data." + ) + + # Create new table but skip migration + await _pg_create_table(db, table_name, base_table, embedding_dim) + await db._create_vector_index(table_name, embedding_dim) + + logger.info( + f"PostgreSQL: New table '{table_name}' created. " + f"To query legacy data, please use a {legacy_dim}d embedding model." + ) + return + + except Exception as e: + logger.warning( + f"PostgreSQL: Could not verify legacy table vector dimension: {e}. " + f"Proceeding with caution..." + ) + + logger.info(f"PostgreSQL: Creating new table '{table_name}'") + await _pg_create_table(db, table_name, base_table, embedding_dim) + + migrated_count = await _pg_migrate_workspace_data( + db, + legacy_table_name, + table_name, + workspace, + legacy_count, + embedding_dim, + ) + + logger.info("PostgreSQL: Verifying migration...") + new_count_query = f"SELECT COUNT(*) as count FROM {table_name}" + new_count_result = await db.query(new_count_query, []) + new_count = new_count_result.get("count", 0) if new_count_result else 0 + + if new_count != legacy_count: + error_msg = ( + f"PostgreSQL: Migration verification failed, " + f"expected {legacy_count} records, got {new_count} in new table" + ) + logger.error(error_msg) + raise PostgreSQLMigrationError(error_msg) + + logger.info( + f"PostgreSQL: Migration completed successfully: {migrated_count} records migrated" + ) + logger.info( + f"PostgreSQL: Migration from '{legacy_table_name}' to '{table_name}' completed successfully" + ) + + await db._create_vector_index(table_name, embedding_dim) + + try: + if workspace: + logger.info( + f"PostgreSQL: Deleting migrated workspace '{workspace}' data from legacy table '{legacy_table_name}'..." 
+ ) + delete_query = ( + f"DELETE FROM {legacy_table_name} WHERE workspace = $1" + ) + await db.execute(delete_query, {"workspace": workspace}) + logger.info( + f"PostgreSQL: Deleted workspace '{workspace}' data from legacy table" + ) + + remaining_query = ( + f"SELECT COUNT(*) as count FROM {legacy_table_name}" + ) + remaining_result = await db.query(remaining_query, []) + remaining_count = ( + remaining_result.get("count", 0) if remaining_result else 0 + ) + + if remaining_count == 0: + logger.info( + f"PostgreSQL: Legacy table '{legacy_table_name}' is empty, deleting..." + ) + drop_query = f"DROP TABLE {legacy_table_name}" + await db.execute(drop_query, None) + logger.info( + f"PostgreSQL: Legacy table '{legacy_table_name}' deleted successfully" + ) + else: + logger.info( + f"PostgreSQL: Legacy table '{legacy_table_name}' preserved ({remaining_count} records from other workspaces remain)" + ) + else: + logger.warning( + f"PostgreSQL: No workspace specified, deleting entire legacy table '{legacy_table_name}'..." + ) + drop_query = f"DROP TABLE {legacy_table_name}" + await db.execute(drop_query, None) + logger.info( + f"PostgreSQL: Legacy table '{legacy_table_name}' deleted" + ) + + except Exception as delete_error: + # If cleanup fails, log warning but don't fail migration + logger.warning( + f"PostgreSQL: Failed to clean up legacy table '{legacy_table_name}': {delete_error}. " + "Migration succeeded, but manual cleanup may be needed." 
+ ) + + except PostgreSQLMigrationError: + # Re-raise migration errors without wrapping + raise + except Exception as e: + error_msg = f"PostgreSQL: Migration failed with error: {e}" + logger.error(error_msg) + # Mirror Qdrant behavior: no automatic rollback + # Reason: partial data can be continued by re-running migration + raise PostgreSQLMigrationError(error_msg) from e + async def initialize(self): async with get_data_init_lock(): if self.db is None: @@ -2206,6 +2690,16 @@ async def initialize(self): # Use "default" for compatibility (lowest priority) self.workspace = "default" + # Setup table (create if not exists and handle migration) + await PGVectorStorage.setup_table( + self.db, + self.table_name, + legacy_table_name=self.legacy_table_name, + base_table=self.legacy_table_name, # base_table for DDL template lookup + embedding_dim=self.embedding_func.embedding_dim, + workspace=self.workspace, # CRITICAL: Filter migration by workspace + ) + async def finalize(self): if self.db is not None: await ClientManager.release_client(self.db) @@ -2215,7 +2709,9 @@ def _upsert_chunks( self, item: dict[str, Any], current_time: datetime.datetime ) -> tuple[str, dict[str, Any]]: try: - upsert_sql = SQL_TEMPLATES["upsert_chunk"] + upsert_sql = SQL_TEMPLATES["upsert_chunk"].format( + table_name=self.table_name + ) data: dict[str, Any] = { "workspace": self.workspace, "id": item["__id__"], @@ -2239,7 +2735,7 @@ def _upsert_chunks( def _upsert_entities( self, item: dict[str, Any], current_time: datetime.datetime ) -> tuple[str, dict[str, Any]]: - upsert_sql = SQL_TEMPLATES["upsert_entity"] + upsert_sql = SQL_TEMPLATES["upsert_entity"].format(table_name=self.table_name) source_id = item["source_id"] if isinstance(source_id, str) and "" in source_id: chunk_ids = source_id.split("") @@ -2262,7 +2758,9 @@ def _upsert_entities( def _upsert_relationships( self, item: dict[str, Any], current_time: datetime.datetime ) -> tuple[str, dict[str, Any]]: - upsert_sql = 
SQL_TEMPLATES["upsert_relationship"] + upsert_sql = SQL_TEMPLATES["upsert_relationship"].format( + table_name=self.table_name + ) source_id = item["source_id"] if isinstance(source_id, str) and "" in source_id: chunk_ids = source_id.split("") @@ -2335,7 +2833,9 @@ async def query( embedding_string = ",".join(map(str, embedding)) - sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string) + sql = SQL_TEMPLATES[self.namespace].format( + embedding_string=embedding_string, table_name=self.table_name + ) params = { "workspace": self.workspace, "closer_than_threshold": 1 - self.cosine_better_than_threshold, @@ -2357,14 +2857,9 @@ async def delete(self, ids: list[str]) -> None: if not ids: return - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error( - f"[{self.workspace}] Unknown namespace for vector deletion: {self.namespace}" - ) - return - - delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)" + delete_sql = ( + f"DELETE FROM {self.table_name} WHERE workspace=$1 AND id = ANY($2)" + ) try: await self.db.execute(delete_sql, {"workspace": self.workspace, "ids": ids}) @@ -2383,8 +2878,8 @@ async def delete_entity(self, entity_name: str) -> None: entity_name: The name of the entity to delete """ try: - # Construct SQL to delete the entity - delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY + # Construct SQL to delete the entity using dynamic table name + delete_sql = f"""DELETE FROM {self.table_name} WHERE workspace=$1 AND entity_name=$2""" await self.db.execute( @@ -2404,7 +2899,7 @@ async def delete_entity_relation(self, entity_name: str) -> None: """ try: # Delete relations where the entity is either the source or target - delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION + delete_sql = f"""DELETE FROM {self.table_name} WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)""" await self.db.execute( @@ -2427,14 +2922,7 @@ async def get_by_id(self, id: str) -> dict[str, Any] | None: Returns: 
The vector data if found, or None if not found """ - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error( - f"[{self.workspace}] Unknown namespace for ID lookup: {self.namespace}" - ) - return None - - query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id=$2" + query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {self.table_name} WHERE workspace=$1 AND id=$2" params = {"workspace": self.workspace, "id": id} try: @@ -2460,15 +2948,8 @@ async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: if not ids: return [] - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error( - f"[{self.workspace}] Unknown namespace for IDs lookup: {self.namespace}" - ) - return [] - ids_str = ",".join([f"'{id}'" for id in ids]) - query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})" + query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {self.table_name} WHERE workspace=$1 AND id IN ({ids_str})" params = {"workspace": self.workspace} try: @@ -2509,15 +2990,8 @@ async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]: if not ids: return {} - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error( - f"[{self.workspace}] Unknown namespace for vector lookup: {self.namespace}" - ) - return {} - ids_str = ",".join([f"'{id}'" for id in ids]) - query = f"SELECT id, content_vector FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})" + query = f"SELECT id, content_vector FROM {self.table_name} WHERE workspace=$1 AND id IN ({ids_str})" params = {"workspace": self.workspace} try: @@ -2546,15 +3020,8 @@ async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]: async def drop(self) -> dict[str, str]: """Drop the storage""" try: - table_name = 
namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name + table_name=self.table_name ) await self.db.execute(drop_sql, {"workspace": self.workspace}) return {"status": "success", "message": "data dropped"} @@ -2593,6 +3060,9 @@ async def initialize(self): # Use "default" for compatibility (lowest priority) self.workspace = "default" + # NOTE: Table creation is handled by PostgreSQLDB.initdb() during initialization + # No need to create table here as it's already created in the TABLES dict + async def finalize(self): if self.db is not None: await ClientManager.release_client(self.db) @@ -3188,6 +3658,12 @@ async def drop(self) -> dict[str, str]: return {"status": "error", "message": str(e)} +class PostgreSQLMigrationError(Exception): + """Exception for PostgreSQL table migration errors.""" + + pass + + class PGGraphQueryException(Exception): """Exception for the AGE queries.""" @@ -5047,7 +5523,7 @@ def namespace_to_table_name(namespace: str) -> str: update_time = EXCLUDED.update_time """, # SQL for VectorStorage - "upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens, + "upsert_chunk": """INSERT INTO {table_name} (workspace, id, tokens, chunk_order_index, full_doc_id, content, content_vector, file_path, create_time, update_time) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) @@ -5060,7 +5536,7 @@ def namespace_to_table_name(namespace: str) -> str: file_path=EXCLUDED.file_path, update_time = EXCLUDED.update_time """, - "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, + "upsert_entity": """INSERT INTO {table_name} (workspace, id, entity_name, content, content_vector, chunk_ids, file_path, create_time, update_time) VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9) ON CONFLICT (workspace,id) DO UPDATE @@ -5071,7 +5547,7 
@@ def namespace_to_table_name(namespace: str) -> str: file_path=EXCLUDED.file_path, update_time=EXCLUDED.update_time """, - "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id, + "upsert_relationship": """INSERT INTO {table_name} (workspace, id, source_id, target_id, content, content_vector, chunk_ids, file_path, create_time, update_time) VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8, $9, $10) ON CONFLICT (workspace,id) DO UPDATE @@ -5087,7 +5563,7 @@ def namespace_to_table_name(namespace: str) -> str: SELECT r.source_id AS src_id, r.target_id AS tgt_id, EXTRACT(EPOCH FROM r.create_time)::BIGINT AS created_at - FROM LIGHTRAG_VDB_RELATION r + FROM {table_name} r WHERE r.workspace = $1 AND r.content_vector <=> '[{embedding_string}]'::vector < $2 ORDER BY r.content_vector <=> '[{embedding_string}]'::vector @@ -5096,7 +5572,7 @@ def namespace_to_table_name(namespace: str) -> str: "entities": """ SELECT e.entity_name, EXTRACT(EPOCH FROM e.create_time)::BIGINT AS created_at - FROM LIGHTRAG_VDB_ENTITY e + FROM {table_name} e WHERE e.workspace = $1 AND e.content_vector <=> '[{embedding_string}]'::vector < $2 ORDER BY e.content_vector <=> '[{embedding_string}]'::vector @@ -5107,7 +5583,7 @@ def namespace_to_table_name(namespace: str) -> str: c.content, c.file_path, EXTRACT(EPOCH FROM c.create_time)::BIGINT AS created_at - FROM LIGHTRAG_VDB_CHUNKS c + FROM {table_name} c WHERE c.workspace = $1 AND c.content_vector <=> '[{embedding_string}]'::vector < $2 ORDER BY c.content_vector <=> '[{embedding_string}]'::vector diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index 75de261365..5f8cb6426d 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -66,6 +66,48 @@ def workspace_filter_condition(workspace: str) -> models.FieldCondition: ) +def _find_legacy_collection( + client: QdrantClient, namespace: str, workspace: str = None +) -> str | None: + """ + Find legacy collection with backward 
compatibility support. + + This function tries multiple naming patterns to locate legacy collections + created by older versions of LightRAG: + + 1. {workspace}_{namespace} - Old format with workspace (pre-model-isolation) - HIGHEST PRIORITY + 2. lightrag_vdb_{namespace} - Current legacy format + 3. {namespace} - Old format without workspace (pre-model-isolation) + + Args: + client: QdrantClient instance + namespace: Base namespace (e.g., "chunks", "entities") + workspace: Optional workspace identifier + + Returns: + Collection name if found, None otherwise + """ + # Try multiple naming patterns for backward compatibility + # More specific names (with workspace) have higher priority + candidates = [ + f"{workspace}_{namespace}" + if workspace + else None, # Old format with workspace - most specific + f"lightrag_vdb_{namespace}", # New legacy format + namespace, # Old format without workspace - most generic + ] + + for candidate in candidates: + if candidate and client.collection_exists(candidate): + logger.info( + f"Qdrant: Found legacy collection '{candidate}' " + f"(namespace={namespace}, workspace={workspace or 'none'})" + ) + return candidate + + return None + + @final @dataclass class QdrantVectorDBStorage(BaseVectorStorage): @@ -85,28 +127,73 @@ def __init__( def setup_collection( client: QdrantClient, collection_name: str, - legacy_namespace: str = None, + namespace: str = None, workspace: str = None, **kwargs, ): """ Setup Qdrant collection with migration support from legacy collections. + This method now supports backward compatibility by automatically detecting + legacy collections created by older versions of LightRAG using multiple + naming patterns. 
+ Args: client: QdrantClient instance collection_name: Name of the new collection - legacy_namespace: Name of the legacy collection (if exists) + namespace: Base namespace (e.g., "chunks", "entities") workspace: Workspace identifier for data isolation **kwargs: Additional arguments for collection creation (vectors_config, hnsw_config, etc.) """ new_collection_exists = client.collection_exists(collection_name) - legacy_exists = legacy_namespace and client.collection_exists(legacy_namespace) - # Case 1: Both new and legacy collections exist - Warning only (no migration) + # Try to find legacy collection with backward compatibility + legacy_collection = ( + _find_legacy_collection(client, namespace, workspace) if namespace else None + ) + legacy_exists = legacy_collection is not None + + # Case 1: Both new and legacy collections exist + # This can happen if: + # 1. Previous migration failed to delete the legacy collection + # 2. User manually created both collections + # 3. No model suffix (collection_name == legacy_collection) + # Strategy: Only delete legacy if it's empty (safe cleanup) and it's not the same as new collection if new_collection_exists and legacy_exists: - logger.warning( - f"Qdrant: Legacy collection '{legacy_namespace}' still exist. Remove it if migration is complete." - ) + # CRITICAL: Check if new and legacy are the same collection + # This happens when model_suffix is empty (no model_name provided) + if collection_name == legacy_collection: + logger.debug( + f"Qdrant: Collection '{collection_name}' already exists (no model suffix). Skipping Case 1 cleanup." + ) + return + + try: + # Check if legacy collection is empty + legacy_count = client.count( + collection_name=legacy_collection, exact=True + ).count + + if legacy_count == 0: + # Legacy collection is empty, safe to delete without data loss + logger.info( + f"Qdrant: Legacy collection '{legacy_collection}' is empty. Deleting..." 
+ ) + client.delete_collection(collection_name=legacy_collection) + logger.info( + f"Qdrant: Legacy collection '{legacy_collection}' deleted successfully" + ) + else: + # Legacy collection still has data - don't risk deleting it + logger.warning( + f"Qdrant: Legacy collection '{legacy_collection}' still contains {legacy_count} records. " + f"Manual intervention required to verify and delete." + ) + except Exception as e: + logger.warning( + f"Qdrant: Could not check or cleanup legacy collection '{legacy_collection}': {e}. " + "You may need to delete it manually." + ) return # Case 2: Only new collection exists - Ensure index exists @@ -149,13 +236,13 @@ def setup_collection( # Case 4: Only legacy exists - Migrate data logger.info( - f"Qdrant: Migrating data from legacy collection '{legacy_namespace}'" + f"Qdrant: Migrating data from legacy collection '{legacy_collection}'" ) try: # Get legacy collection count legacy_count = client.count( - collection_name=legacy_namespace, exact=True + collection_name=legacy_collection, exact=True ).count logger.info(f"Qdrant: Found {legacy_count} records in legacy collection") @@ -173,6 +260,51 @@ def setup_collection( ) return + # Check vector dimension compatibility before migration + try: + legacy_info = client.get_collection(legacy_collection) + legacy_dim = legacy_info.config.params.vectors.size + + # Get expected dimension from kwargs + new_dim = ( + kwargs.get("vectors_config").size + if "vectors_config" in kwargs + else None + ) + + if new_dim and legacy_dim != new_dim: + logger.warning( + f"Qdrant: Dimension mismatch detected! " + f"Legacy collection '{legacy_collection}' has {legacy_dim}d vectors, " + f"but new embedding model expects {new_dim}d. " + f"Migration skipped to prevent data loss. " + f"Legacy collection preserved as '{legacy_collection}'. " + f"Creating new empty collection '{collection_name}' for new data." 
+ ) + + # Create new collection but skip migration + client.create_collection(collection_name, **kwargs) + client.create_payload_index( + collection_name=collection_name, + field_name=WORKSPACE_ID_FIELD, + field_schema=models.KeywordIndexParams( + type=models.KeywordIndexType.KEYWORD, + is_tenant=True, + ), + ) + + logger.info( + f"Qdrant: New collection '{collection_name}' created. " + f"To query legacy data, please use a {legacy_dim}d embedding model." + ) + return + + except Exception as e: + logger.warning( + f"Qdrant: Could not verify legacy collection dimension: {e}. " + f"Proceeding with caution..." + ) + # Create new collection first logger.info(f"Qdrant: Creating new collection '{collection_name}'") client.create_collection(collection_name, **kwargs) @@ -185,7 +317,7 @@ def setup_collection( while True: # Scroll through legacy data result = client.scroll( - collection_name=legacy_namespace, + collection_name=legacy_collection, limit=batch_size, offset=offset, with_vectors=True, @@ -258,9 +390,27 @@ def setup_collection( ), ) logger.info( - f"Qdrant: Migration from '{legacy_namespace}' to '{collection_name}' completed successfully" + f"Qdrant: Migration from '{legacy_collection}' to '{collection_name}' completed successfully" ) + # Delete legacy collection after successful migration + # Data has been verified to match, so legacy collection is no longer needed + # and keeping it would cause Case 1 warnings on next startup + try: + logger.info( + f"Qdrant: Deleting legacy collection '{legacy_collection}'..." + ) + client.delete_collection(collection_name=legacy_collection) + logger.info( + f"Qdrant: Legacy collection '{legacy_collection}' deleted successfully" + ) + except Exception as delete_error: + # If deletion fails, user will see Case 1 warning on next startup + logger.warning( + f"Qdrant: Failed to delete legacy collection '{legacy_collection}': {delete_error}. " + "You may need to delete it manually." 
+ ) + except QdrantMigrationError: # Re-raise migration errors without wrapping raise @@ -287,19 +437,34 @@ def __post_init__(self): f"Using passed workspace parameter: '{effective_workspace}'" ) - # Get legacy namespace for data migration from old version - if effective_workspace: - self.legacy_namespace = f"{effective_workspace}_{self.namespace}" - else: - self.legacy_namespace = self.namespace - self.effective_workspace = effective_workspace or DEFAULT_WORKSPACE - # Use a shared collection with payload-based partitioning (Qdrant's recommended approach) - # Ref: https://qdrant.tech/documentation/guides/multiple-partitions/ - self.final_namespace = f"lightrag_vdb_{self.namespace}" - logger.debug( - f"Using shared collection '{self.final_namespace}' with workspace '{self.effective_workspace}' for payload-based partitioning" + # Generate model suffix + model_suffix = self._generate_collection_suffix() + + # Legacy collection name (without model suffix, for migration) + # This matches the old naming scheme before model isolation was implemented + # Example: "lightrag_vdb_chunks" (without model suffix) + self.legacy_namespace = f"lightrag_vdb_{self.namespace}" + + # New naming scheme with model isolation + # Example: "lightrag_vdb_chunks_text_embedding_ada_002_1536d" + # Ensure model_suffix is not empty before appending + if model_suffix: + self.final_namespace = f"lightrag_vdb_{self.namespace}_{model_suffix}" + else: + # Fallback: use legacy namespace if model_suffix is unavailable + self.final_namespace = self.legacy_namespace + logger.warning( + f"Model suffix unavailable, using legacy collection name '{self.legacy_namespace}'. " + f"Ensure embedding_func has model_name for proper model isolation." 
+ ) + + logger.info( + f"Qdrant collection naming: " + f"new='{self.final_namespace}', " + f"legacy='{self.legacy_namespace}', " + f"model_suffix='{model_suffix}'" ) kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) @@ -315,6 +480,12 @@ def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] self._initialized = False + def _get_legacy_collection_name(self) -> str: + return self.legacy_namespace + + def _get_new_collection_name(self) -> str: + return self.final_namespace + async def initialize(self): """Initialize Qdrant collection""" async with get_data_init_lock(): @@ -338,11 +509,11 @@ async def initialize(self): ) # Setup collection (create if not exists and configure indexes) - # Pass legacy_namespace and workspace for migration support + # Pass namespace and workspace for backward-compatible migration support QdrantVectorDBStorage.setup_collection( self._client, self.final_namespace, - legacy_namespace=self.legacy_namespace, + namespace=self.namespace, workspace=self.effective_workspace, vectors_config=models.VectorParams( size=self.embedding_func.embedding_dim, @@ -354,6 +525,9 @@ async def initialize(self): ), ) + # Initialize max batch size from config + self._max_batch_size = self.global_config["embedding_batch_num"] + self._initialized = True logger.info( f"[{self.workspace}] Qdrant collection '{self.namespace}' initialized successfully" diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 834cdc8f99..d69cd08d78 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -164,16 +164,29 @@ async def __aenter__(self) -> "UnifiedLock[T]": ) # Then acquire the main lock - if self._is_async: - await self._lock.acquire() - else: - self._lock.acquire() + if self._lock is not None: + if self._is_async: + await self._lock.acquire() + else: + self._lock.acquire() - direct_log( - f"== Lock == Process {self._pid}: Acquired lock {self._name} 
(async={self._is_async})", - level="INFO", - enable_output=self._enable_logging, - ) + direct_log( + f"== Lock == Process {self._pid}: Acquired lock {self._name} (async={self._is_async})", + level="INFO", + enable_output=self._enable_logging, + ) + else: + # CRITICAL: Raise exception instead of allowing unprotected execution + error_msg = ( + f"CRITICAL: Lock '{self._name}' is None - shared data not initialized. " + f"Call initialize_share_data() before using locks!" + ) + direct_log( + f"== Lock == Process {self._pid}: {error_msg}", + level="ERROR", + enable_output=True, + ) + raise RuntimeError(error_msg) return self except Exception as e: # If main lock acquisition fails, release the async lock if it was acquired @@ -193,19 +206,21 @@ async def __aenter__(self) -> "UnifiedLock[T]": async def __aexit__(self, exc_type, exc_val, exc_tb): main_lock_released = False + async_lock_released = False try: # Release main lock first - if self._is_async: - self._lock.release() - else: - self._lock.release() - main_lock_released = True + if self._lock is not None: + if self._is_async: + self._lock.release() + else: + self._lock.release() - direct_log( - f"== Lock == Process {self._pid}: Released lock {self._name} (async={self._is_async})", - level="INFO", - enable_output=self._enable_logging, - ) + direct_log( + f"== Lock == Process {self._pid}: Released lock {self._name} (async={self._is_async})", + level="INFO", + enable_output=self._enable_logging, + ) + main_lock_released = True # Then release async lock if in multiprocess mode if not self._is_async and self._async_lock is not None: @@ -215,6 +230,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): level="DEBUG", enable_output=self._enable_logging, ) + async_lock_released = True except Exception as e: direct_log( @@ -223,9 +239,10 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): enable_output=True, ) - # If main lock release failed but async lock hasn't been released, try to release it + # If main lock 
release failed but async lock hasn't been attempted yet, try to release it if ( not main_lock_released + and not async_lock_released and not self._is_async and self._async_lock is not None ): diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 8a6387591b..6618c95568 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -518,14 +518,10 @@ def __post_init__(self): f"max_total_tokens({self.summary_max_tokens}) should greater than summary_length_recommended({self.summary_length_recommended})" ) - # Fix global_config now - global_config = asdict(self) - - _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) - logger.debug(f"LightRAG init with param:\n {_print_config}\n") - # Init Embedding - # Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes) + # Step 1: Capture embedding_func and max_token_size before applying decorator + # (decorator strips dataclass attributes, and asdict() converts EmbeddingFunc to dict) + original_embedding_func = self.embedding_func embedding_max_token_size = None if self.embedding_func and hasattr(self.embedding_func, "max_token_size"): embedding_max_token_size = self.embedding_func.max_token_size @@ -534,6 +530,14 @@ def __post_init__(self): ) self.embedding_token_limit = embedding_max_token_size + # Fix global_config now + global_config = asdict(self) + # Restore original EmbeddingFunc object (asdict converts it to dict) + global_config["embedding_func"] = original_embedding_func + + _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) + logger.debug(f"LightRAG init with param:\n {_print_config}\n") + # Step 2: Apply priority wrapper decorator self.embedding_func = priority_limit_async_func_call( self.embedding_func_max_async, diff --git a/lightrag/utils.py b/lightrag/utils.py index 8c9b7776d6..3a640d0713 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -370,6 +370,19 @@ class EmbeddingFunc: send_dimensions: bool = ( 
False # Control whether to send embedding_dim to the function ) + model_name: str | None = None + + def get_model_identifier(self) -> str: + """Generates model identifier for collection/table suffix. + + Returns: + str: Format "{model_name}_{dim}d", e.g. "text_embedding_3_large_3072d" + If model_name is not specified, returns "unknown_{dim}d" + """ + model_part = self.model_name if self.model_name else "unknown" + # Clean model name: remove special chars, convert to lower, replace - with _ + safe_model_name = re.sub(r"[^a-zA-Z0-9_]", "_", model_part.lower()) + return f"{safe_model_name}_{self.embedding_dim}d" async def __call__(self, *args, **kwargs) -> np.ndarray: # Only inject embedding_dim when send_dimensions is True diff --git a/tests/test_base_storage_integrity.py b/tests/test_base_storage_integrity.py new file mode 100644 index 0000000000..1bd247773f --- /dev/null +++ b/tests/test_base_storage_integrity.py @@ -0,0 +1,55 @@ +import pytest +from lightrag.base import BaseVectorStorage +from lightrag.utils import EmbeddingFunc + + +def test_base_vector_storage_integrity(): + # Just checking if we can import and inspect the class + assert hasattr(BaseVectorStorage, "_generate_collection_suffix") + assert hasattr(BaseVectorStorage, "_get_legacy_collection_name") + assert hasattr(BaseVectorStorage, "_get_new_collection_name") + + # Verify methods raise NotImplementedError + class ConcreteStorage(BaseVectorStorage): + async def query(self, *args, **kwargs): + pass + + async def upsert(self, *args, **kwargs): + pass + + async def delete_entity(self, *args, **kwargs): + pass + + async def delete_entity_relation(self, *args, **kwargs): + pass + + async def get_by_id(self, *args, **kwargs): + pass + + async def get_by_ids(self, *args, **kwargs): + pass + + async def delete(self, *args, **kwargs): + pass + + async def get_vectors_by_ids(self, *args, **kwargs): + pass + + async def index_done_callback(self): + pass + + async def drop(self): + pass + + func = 
EmbeddingFunc(embedding_dim=128, func=lambda x: x) + storage = ConcreteStorage( + namespace="test", workspace="test", global_config={}, embedding_func=func + ) + + assert storage._generate_collection_suffix() == "unknown_128d" + + with pytest.raises(NotImplementedError): + storage._get_legacy_collection_name() + + with pytest.raises(NotImplementedError): + storage._get_new_collection_name() diff --git a/tests/test_dimension_mismatch.py b/tests/test_dimension_mismatch.py new file mode 100644 index 0000000000..67bf4c71ab --- /dev/null +++ b/tests/test_dimension_mismatch.py @@ -0,0 +1,316 @@ +""" +Tests for dimension mismatch handling during migration. + +This test module verifies that both PostgreSQL and Qdrant storage backends +properly detect and handle vector dimension mismatches when migrating from +legacy collections/tables to new ones with different embedding models. +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from lightrag.kg.qdrant_impl import QdrantVectorDBStorage +from lightrag.kg.postgres_impl import PGVectorStorage + + +# Note: Tests should use proper table names that have DDL templates +# Valid base tables: LIGHTRAG_VDB_CHUNKS, LIGHTRAG_VDB_ENTITIES, LIGHTRAG_VDB_RELATIONSHIPS, +# LIGHTRAG_DOC_CHUNKS, LIGHTRAG_DOC_FULL_DOCS, LIGHTRAG_DOC_TEXT_CHUNKS + + +class TestQdrantDimensionMismatch: + """Test suite for Qdrant dimension mismatch handling.""" + + def test_qdrant_dimension_mismatch_skip_migration(self): + """ + Test that Qdrant skips migration when dimensions don't match. + + Scenario: Legacy collection has 1536d vectors, new model expects 3072d. + Expected: Migration skipped, new empty collection created, legacy preserved. 
+ """ + from qdrant_client import models + + # Setup mock client + client = MagicMock() + + # Mock legacy collection with 1536d vectors + legacy_collection_info = MagicMock() + legacy_collection_info.config.params.vectors.size = 1536 + + # Setup collection existence checks + def collection_exists_side_effect(name): + if name == "lightrag_chunks": # legacy + return True + elif name == "lightrag_chunks_model_3072d": # new + return False + return False + + client.collection_exists.side_effect = collection_exists_side_effect + client.get_collection.return_value = legacy_collection_info + client.count.return_value.count = 100 # Legacy has data + + # Call setup_collection with 3072d (different from legacy 1536d) + QdrantVectorDBStorage.setup_collection( + client, + "lightrag_chunks_model_3072d", + namespace="chunks", + workspace="test", + vectors_config=models.VectorParams( + size=3072, distance=models.Distance.COSINE + ), + ) + + # Verify new collection was created + client.create_collection.assert_called_once() + + # Verify migration was NOT attempted (no scroll/upsert calls) + client.scroll.assert_not_called() + client.upsert.assert_not_called() + + def test_qdrant_dimension_match_proceed_migration(self): + """ + Test that Qdrant proceeds with migration when dimensions match. + + Scenario: Legacy collection has 1536d vectors, new model also expects 1536d. + Expected: Migration proceeds normally. 
+ """ + from qdrant_client import models + + client = MagicMock() + + # Mock legacy collection with 1536d vectors (matching new) + legacy_collection_info = MagicMock() + legacy_collection_info.config.params.vectors.size = 1536 + + def collection_exists_side_effect(name): + if name == "lightrag_chunks": # legacy + return True + elif name == "lightrag_chunks_model_1536d": # new + return False + return False + + client.collection_exists.side_effect = collection_exists_side_effect + client.get_collection.return_value = legacy_collection_info + client.count.return_value.count = 100 # Legacy has data + + # Mock scroll to return sample data + sample_point = MagicMock() + sample_point.id = "test_id" + sample_point.vector = [0.1] * 1536 + sample_point.payload = {"id": "test"} + client.scroll.return_value = ([sample_point], None) + + # Mock _find_legacy_collection to return the legacy collection name + with patch( + "lightrag.kg.qdrant_impl._find_legacy_collection", + return_value="lightrag_chunks", + ): + # Call setup_collection with matching 1536d + QdrantVectorDBStorage.setup_collection( + client, + "lightrag_chunks_model_1536d", + namespace="chunks", + workspace="test", + vectors_config=models.VectorParams( + size=1536, distance=models.Distance.COSINE + ), + ) + + # Verify migration WAS attempted + client.create_collection.assert_called_once() + client.scroll.assert_called() + client.upsert.assert_called() + + +class TestPostgresDimensionMismatch: + """Test suite for PostgreSQL dimension mismatch handling.""" + + @pytest.mark.asyncio + async def test_postgres_dimension_mismatch_skip_migration_metadata(self): + """ + Test that PostgreSQL skips migration when dimensions don't match (via metadata). + + Scenario: Legacy table has 1536d vectors (detected via pg_attribute), + new model expects 3072d. + Expected: Migration skipped, new empty table created, legacy preserved. 
+ """ + # Setup mock database + db = AsyncMock() + + # Mock table existence and dimension checks + async def query_side_effect(query, params, **kwargs): + if "information_schema.tables" in query: + if params[0] == "LIGHTRAG_DOC_CHUNKS": # legacy + return {"exists": True} + elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new + return {"exists": False} + elif "COUNT(*)" in query: + return {"count": 100} # Legacy has data + elif "pg_attribute" in query: + return {"vector_dim": 1536} # Legacy has 1536d vectors + return {} + + db.query.side_effect = query_side_effect + db.execute = AsyncMock() + db._create_vector_index = AsyncMock() + + # Call setup_table with 3072d (different from legacy 1536d) + await PGVectorStorage.setup_table( + db, + "LIGHTRAG_DOC_CHUNKS_model_3072d", + legacy_table_name="LIGHTRAG_DOC_CHUNKS", + base_table="LIGHTRAG_DOC_CHUNKS", + embedding_dim=3072, + workspace="test", + ) + + # Verify migration was NOT attempted (no INSERT calls) + # Note: _pg_create_table is mocked, so we check INSERT calls to verify migration was skipped + insert_calls = [ + call + for call in db.execute.call_args_list + if call[0][0] and "INSERT INTO" in call[0][0] + ] + assert ( + len(insert_calls) == 0 + ), "Migration should be skipped due to dimension mismatch" + + @pytest.mark.asyncio + async def test_postgres_dimension_mismatch_skip_migration_sampling(self): + """ + Test that PostgreSQL skips migration when dimensions don't match (via sampling). + + Scenario: Legacy table dimension detection fails via metadata, + falls back to vector sampling, detects 1536d vs expected 3072d. + Expected: Migration skipped, new empty table created, legacy preserved. 
+ """ + db = AsyncMock() + + # Mock table existence and dimension checks + async def query_side_effect(query, params, **kwargs): + if "information_schema.tables" in query: + if params[0] == "LIGHTRAG_DOC_CHUNKS": # legacy + return {"exists": True} + elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_3072d": # new + return {"exists": False} + elif "COUNT(*)" in query: + return {"count": 100} # Legacy has data + elif "pg_attribute" in query: + return {"vector_dim": -1} # Metadata check fails + elif "SELECT content_vector FROM" in query: + # Return sample vector with 1536 dimensions + return {"content_vector": [0.1] * 1536} + return {} + + db.query.side_effect = query_side_effect + db.execute = AsyncMock() + db._create_vector_index = AsyncMock() + + # Call setup_table with 3072d (different from legacy 1536d) + await PGVectorStorage.setup_table( + db, + "LIGHTRAG_DOC_CHUNKS_model_3072d", + legacy_table_name="LIGHTRAG_DOC_CHUNKS", + base_table="LIGHTRAG_DOC_CHUNKS", + embedding_dim=3072, + workspace="test", + ) + + # Verify new table was created + create_table_calls = [ + call + for call in db.execute.call_args_list + if call[0][0] and "CREATE TABLE" in call[0][0] + ] + assert len(create_table_calls) > 0, "New table should be created" + + # Verify migration was NOT attempted + insert_calls = [ + call + for call in db.execute.call_args_list + if call[0][0] and "INSERT INTO" in call[0][0] + ] + assert len(insert_calls) == 0, "Migration should be skipped" + + @pytest.mark.asyncio + async def test_postgres_dimension_match_proceed_migration(self): + """ + Test that PostgreSQL proceeds with migration when dimensions match. + + Scenario: Legacy table has 1536d vectors, new model also expects 1536d. + Expected: Migration proceeds normally. 
+ """ + db = AsyncMock() + + async def query_side_effect(query, params, **kwargs): + multirows = kwargs.get("multirows", False) + + if "information_schema.tables" in query: + if params[0] == "LIGHTRAG_DOC_CHUNKS": # legacy + return {"exists": True} + elif params[0] == "LIGHTRAG_DOC_CHUNKS_model_1536d": # new + return {"exists": False} + elif "COUNT(*)" in query: + return {"count": 100} # Legacy has data + elif "pg_attribute" in query: + return {"vector_dim": 1536} # Legacy has matching 1536d + elif "SELECT * FROM" in query and multirows: + # Return sample data for migration (first batch) + # Handle workspace filtering: params = [workspace, offset, limit] + if "WHERE workspace" in query: + offset = params[1] if len(params) > 1 else 0 + else: + offset = params[0] if params else 0 + + if offset == 0: # First batch + return [ + { + "id": "test1", + "content_vector": [0.1] * 1536, + "workspace": "test", + }, + { + "id": "test2", + "content_vector": [0.2] * 1536, + "workspace": "test", + }, + ] + else: # offset > 0 + return [] # No more data + return {} + + db.query.side_effect = query_side_effect + db.execute = AsyncMock() + db._create_vector_index = AsyncMock() + + # Mock _pg_table_exists + async def mock_table_exists(db_inst, name): + if name == "LIGHTRAG_DOC_CHUNKS": # legacy exists + return True + elif name == "LIGHTRAG_DOC_CHUNKS_model_1536d": # new doesn't exist + return False + return False + + with patch( + "lightrag.kg.postgres_impl._pg_table_exists", + side_effect=mock_table_exists, + ): + # Call setup_table with matching 1536d + await PGVectorStorage.setup_table( + db, + "LIGHTRAG_DOC_CHUNKS_model_1536d", + legacy_table_name="LIGHTRAG_DOC_CHUNKS", + base_table="LIGHTRAG_DOC_CHUNKS", + embedding_dim=1536, + workspace="test", + ) + + # Verify migration WAS attempted (INSERT calls made) + insert_calls = [ + call + for call in db.execute.call_args_list + if call[0][0] and "INSERT INTO" in call[0][0] + ] + assert ( + len(insert_calls) > 0 + ), "Migration should 
proceed with matching dimensions" diff --git a/tests/test_e2e_multi_instance.py b/tests/test_e2e_multi_instance.py new file mode 100644 index 0000000000..dcb875b85a --- /dev/null +++ b/tests/test_e2e_multi_instance.py @@ -0,0 +1,1639 @@ +""" +E2E Tests for Multi-Instance LightRAG with Multiple Workspaces + +These tests verify: +1. Legacy data migration from tables/collections without model suffix +2. Multiple LightRAG instances with different embedding models +3. Multiple workspaces isolation +4. Both PostgreSQL and Qdrant vector storage +5. Real document insertion and query operations + +Prerequisites: +- PostgreSQL with pgvector extension +- Qdrant server running +- Environment variables configured +""" + +import os +import pytest +import asyncio +import numpy as np +import tempfile +import shutil +from lightrag import LightRAG +from lightrag.utils import EmbeddingFunc +from lightrag.kg.postgres_impl import PostgreSQLDB + +# Conditional import for E2E dependencies +# This prevents offline tests from failing due to missing E2E dependencies +qdrant_client = pytest.importorskip( + "qdrant_client", reason="Qdrant client required for E2E tests" +) +QdrantClient = qdrant_client.QdrantClient + + +# Configuration fixtures +@pytest.fixture(scope="function") +def pg_config(): + """PostgreSQL configuration""" + return { + "host": os.getenv("POSTGRES_HOST", "localhost"), + "port": int(os.getenv("POSTGRES_PORT", "5432")), + "user": os.getenv("POSTGRES_USER", "lightrag"), + "password": os.getenv("POSTGRES_PASSWORD", "lightrag_test_password"), + "database": os.getenv("POSTGRES_DB", "lightrag_test"), + "workspace": "multi_instance_test", + "max_connections": 10, + "connection_retry_attempts": 3, + "connection_retry_backoff": 0.5, + "connection_retry_backoff_max": 5.0, + "pool_close_timeout": 5.0, + } + + +@pytest.fixture(scope="function") +def qdrant_config(): + """Qdrant configuration""" + return { + "url": os.getenv("QDRANT_URL", "http://localhost:6333"), + "api_key": 
os.getenv("QDRANT_API_KEY", None), + } + + +# Cleanup fixtures +@pytest.fixture(scope="function") +async def pg_cleanup(pg_config): + """Cleanup PostgreSQL tables before and after test""" + db = PostgreSQLDB(pg_config) + await db.initdb() + + tables_to_drop = [ + "lightrag_doc_full", + "lightrag_doc_chunks", + "lightrag_vdb_chunks", + "lightrag_vdb_chunks_text_embedding_ada_002_1536d", + "lightrag_vdb_chunks_text_embedding_3_large_3072d", + "lightrag_vdb_chunks_model_a_768d", + "lightrag_vdb_chunks_model_b_1024d", + "lightrag_vdb_entity", + "lightrag_vdb_relation", + "lightrag_llm_cache", + "lightrag_doc_status", + "lightrag_full_entities", + "lightrag_full_relations", + "lightrag_entity_chunks", + "lightrag_relation_chunks", + ] + + # Cleanup before + for table in tables_to_drop: + try: + await db.execute(f"DROP TABLE IF EXISTS {table} CASCADE", None) + except Exception: + pass + + yield db + + # Cleanup after + for table in tables_to_drop: + try: + await db.execute(f"DROP TABLE IF EXISTS {table} CASCADE", None) + except Exception: + pass + + if db.pool: + await db.pool.close() + + +@pytest.fixture(scope="function") +def qdrant_cleanup(qdrant_config): + """Cleanup Qdrant collections before and after test""" + client = QdrantClient( + url=qdrant_config["url"], + api_key=qdrant_config["api_key"], + timeout=60, + ) + + collections_to_delete = [ + "lightrag_vdb_chunks", # Legacy collection (no model suffix) + "lightrag_vdb_chunks_text_embedding_ada_002_1536d", # Migrated collection + "lightrag_vdb_chunks_model_a_768d", + "lightrag_vdb_chunks_model_b_1024d", + ] + + # Cleanup before + for collection in collections_to_delete: + try: + if client.collection_exists(collection): + client.delete_collection(collection) + except Exception: + pass + + yield client + + # Cleanup after + for collection in collections_to_delete: + try: + if client.collection_exists(collection): + client.delete_collection(collection) + except Exception: + pass + + +@pytest.fixture +def 
temp_working_dirs(): + """Create multiple temporary working directories""" + dirs = { + "workspace_a": tempfile.mkdtemp(prefix="lightrag_workspace_a_"), + "workspace_b": tempfile.mkdtemp(prefix="lightrag_workspace_b_"), + } + yield dirs + # Cleanup + for dir_path in dirs.values(): + shutil.rmtree(dir_path, ignore_errors=True) + + +@pytest.fixture +def mock_llm_func(): + """Mock LLM function that returns proper entity/relation format""" + + async def llm_func(prompt, system_prompt=None, history_messages=[], **kwargs): + await asyncio.sleep(0) # Simulate async I/O + return """entity<|#|>Artificial Intelligence<|#|>concept<|#|>AI is a field of computer science. +entity<|#|>Machine Learning<|#|>concept<|#|>ML is a subset of AI. +relation<|#|>Machine Learning<|#|>Artificial Intelligence<|#|>subset<|#|>ML is a subset of AI. +<|COMPLETE|>""" + + return llm_func + + +@pytest.fixture +def mock_tokenizer(): + """Create a mock tokenizer""" + from lightrag.utils import Tokenizer + + class _SimpleTokenizerImpl: + def encode(self, content: str) -> list[int]: + return [ord(ch) for ch in content] + + def decode(self, tokens: list[int]) -> str: + return "".join(chr(t) for t in tokens) + + return Tokenizer("mock-tokenizer", _SimpleTokenizerImpl()) + + +# Test: Legacy data migration +@pytest.mark.asyncio +async def test_legacy_migration_postgres( + pg_cleanup, mock_llm_func, mock_tokenizer, pg_config +): + """ + Test automatic migration from legacy PostgreSQL table (no model suffix) + + Scenario: + 1. Create legacy table without model suffix + 2. Insert test data with 1536d vectors + 3. Initialize LightRAG with model_name (triggers migration) + 4. 
Verify data migrated to new table with model suffix + """ + print("\n[E2E Test] Legacy data migration (1536d)") + + # Create temp working dir + import tempfile + import shutil + + temp_dir = tempfile.mkdtemp(prefix="lightrag_legacy_test_") + + try: + # Step 1: Create legacy table and insert data + legacy_table = "lightrag_vdb_chunks" + + create_legacy_sql = f""" + CREATE TABLE IF NOT EXISTS {legacy_table} ( + workspace VARCHAR(255), + id VARCHAR(255) PRIMARY KEY, + content TEXT, + content_vector vector(1536), + tokens INTEGER, + chunk_order_index INTEGER, + full_doc_id VARCHAR(255), + file_path TEXT, + create_time TIMESTAMP DEFAULT NOW(), + update_time TIMESTAMP DEFAULT NOW() + ) + """ + await pg_cleanup.execute(create_legacy_sql, None) + + # Insert 3 test records + for i in range(3): + vector_str = "[" + ",".join(["0.1"] * 1536) + "]" + insert_sql = f""" + INSERT INTO {legacy_table} + (workspace, id, content, content_vector, tokens, chunk_order_index, full_doc_id, file_path) + VALUES ($1, $2, $3, $4::vector, $5, $6, $7, $8) + """ + await pg_cleanup.execute( + insert_sql, + { + "workspace": pg_config["workspace"], + "id": f"legacy_{i}", + "content": f"Legacy content {i}", + "content_vector": vector_str, + "tokens": 100, + "chunk_order_index": i, + "full_doc_id": "legacy_doc", + "file_path": "/test/path", + }, + ) + + # Verify legacy data + count_result = await pg_cleanup.query( + f"SELECT COUNT(*) as count FROM {legacy_table} WHERE workspace=$1", + [pg_config["workspace"]], + ) + legacy_count = count_result.get("count", 0) + print(f"✅ Legacy table created with {legacy_count} records") + + # Step 2: Initialize LightRAG with model_name (triggers migration) + async def embed_func(texts): + await asyncio.sleep(0) + return np.random.rand(len(texts), 1536) + + embedding_func = EmbeddingFunc( + embedding_dim=1536, + max_token_size=8192, + func=embed_func, + model_name="text-embedding-ada-002", + ) + + rag = LightRAG( + working_dir=temp_dir, + 
workspace=pg_config["workspace"], # Match workspace with test data + llm_model_func=mock_llm_func, + embedding_func=embedding_func, + tokenizer=mock_tokenizer, + kv_storage="PGKVStorage", + vector_storage="PGVectorStorage", + # Use default NetworkXStorage for graph storage (AGE extension not available in CI) + doc_status_storage="PGDocStatusStorage", + vector_db_storage_cls_kwargs={ + **pg_config, + "cosine_better_than_threshold": 0.8, + }, + ) + + print("🔄 Initializing LightRAG (triggers migration)...") + await rag.initialize_storages() + + # Step 3: Verify migration + new_table = rag.chunks_vdb.table_name + assert "text_embedding_ada_002_1536d" in new_table.lower() + + new_count_result = await pg_cleanup.query( + f"SELECT COUNT(*) as count FROM {new_table} WHERE workspace=$1", + [pg_config["workspace"]], + ) + new_count = new_count_result.get("count", 0) + + assert ( + new_count == legacy_count + ), f"Expected {legacy_count} records migrated, got {new_count}" + print(f"✅ Migration successful: {new_count}/{legacy_count} records migrated") + print(f"✅ New table: {new_table}") + + # Verify legacy table was automatically deleted after migration (Case 4) + check_legacy_query = """ + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = $1 + ) + """ + legacy_result = await pg_cleanup.query( + check_legacy_query, [legacy_table.lower()] + ) + legacy_exists = legacy_result.get("exists", True) + assert ( + not legacy_exists + ), f"Legacy table '{legacy_table}' should be deleted after successful migration" + print(f"✅ Legacy table '{legacy_table}' automatically deleted after migration") + + await rag.finalize_storages() + + finally: + # Cleanup temp dir + shutil.rmtree(temp_dir, ignore_errors=True) + + +# Test: Workspace migration isolation (P0 Bug Fix Verification) +@pytest.mark.asyncio +async def test_workspace_migration_isolation_e2e_postgres( + pg_cleanup, mock_llm_func, mock_tokenizer, pg_config +): + """ + E2E Test: Workspace isolation during 
PostgreSQL migration + + Critical P0 Bug Verification: + - Legacy table contains MIXED data from workspace_a and workspace_b + - Initialize LightRAG for workspace_a only + - Verify ONLY workspace_a data migrated to new table + - Verify workspace_b data NOT leaked to workspace_a's table + - Verify workspace_b data preserved in legacy table + + This test validates the fix for the cross-workspace data leakage bug + where setup_table() was copying ALL records regardless of workspace. + """ + print("\n[E2E P0 Bug Fix] Workspace migration isolation (PostgreSQL)") + + import tempfile + import shutil + + temp_dir = tempfile.mkdtemp(prefix="lightrag_workspace_isolation_") + + try: + # Step 1: Create legacy table with MIXED workspace data + legacy_table = "lightrag_vdb_chunks" + + create_legacy_sql = f""" + CREATE TABLE IF NOT EXISTS {legacy_table} ( + workspace VARCHAR(255), + id VARCHAR(255) PRIMARY KEY, + content TEXT, + content_vector vector(1536), + tokens INTEGER, + chunk_order_index INTEGER, + full_doc_id VARCHAR(255), + file_path TEXT, + create_time TIMESTAMP DEFAULT NOW(), + update_time TIMESTAMP DEFAULT NOW() + ) + """ + await pg_cleanup.execute(create_legacy_sql, None) + + # Insert 3 records for workspace_a + for i in range(3): + vector_str = "[" + ",".join([str(0.1 + i * 0.01)] * 1536) + "]" + insert_sql = f""" + INSERT INTO {legacy_table} + (workspace, id, content, content_vector, tokens, chunk_order_index, full_doc_id, file_path) + VALUES ($1, $2, $3, $4::vector, $5, $6, $7, $8) + """ + await pg_cleanup.execute( + insert_sql, + { + "workspace": "workspace_a", + "id": f"a_{i}", + "content": f"Workspace A content {i}", + "content_vector": vector_str, + "tokens": 100, + "chunk_order_index": i, + "full_doc_id": "doc_a", + "file_path": "/workspace_a/doc.txt", + }, + ) + + # Insert 3 records for workspace_b + for i in range(3): + vector_str = "[" + ",".join([str(0.5 + i * 0.01)] * 1536) + "]" + insert_sql = f""" + INSERT INTO {legacy_table} + (workspace, id, content, 
content_vector, tokens, chunk_order_index, full_doc_id, file_path) + VALUES ($1, $2, $3, $4::vector, $5, $6, $7, $8) + """ + await pg_cleanup.execute( + insert_sql, + { + "workspace": "workspace_b", + "id": f"b_{i}", + "content": f"Workspace B content {i}", + "content_vector": vector_str, + "tokens": 100, + "chunk_order_index": i, + "full_doc_id": "doc_b", + "file_path": "/workspace_b/doc.txt", + }, + ) + + # Verify legacy table has BOTH workspaces' data + total_count_result = await pg_cleanup.query( + f"SELECT COUNT(*) as count FROM {legacy_table}", [] + ) + total_count = total_count_result.get("count", 0) + assert total_count == 6, f"Expected 6 total records, got {total_count}" + + workspace_a_count_result = await pg_cleanup.query( + f"SELECT COUNT(*) as count FROM {legacy_table} WHERE workspace=$1", + ["workspace_a"], + ) + workspace_a_count = workspace_a_count_result.get("count", 0) + assert ( + workspace_a_count == 3 + ), f"Expected 3 workspace_a records, got {workspace_a_count}" + + workspace_b_count_result = await pg_cleanup.query( + f"SELECT COUNT(*) as count FROM {legacy_table} WHERE workspace=$1", + ["workspace_b"], + ) + workspace_b_count = workspace_b_count_result.get("count", 0) + assert ( + workspace_b_count == 3 + ), f"Expected 3 workspace_b records, got {workspace_b_count}" + + print( + f"✅ Legacy table created: {total_count} records (workspace_a: {workspace_a_count}, workspace_b: {workspace_b_count})" + ) + + # Step 2: Initialize LightRAG for workspace_a ONLY + async def embed_func(texts): + await asyncio.sleep(0) + return np.random.rand(len(texts), 1536) + + embedding_func = EmbeddingFunc( + embedding_dim=1536, + max_token_size=8192, + func=embed_func, + model_name="text-embedding-ada-002", + ) + + rag = LightRAG( + working_dir=temp_dir, + workspace="workspace_a", # CRITICAL: Only workspace_a + llm_model_func=mock_llm_func, + embedding_func=embedding_func, + tokenizer=mock_tokenizer, + kv_storage="PGKVStorage", + vector_storage="PGVectorStorage", 
+ doc_status_storage="PGDocStatusStorage", + vector_db_storage_cls_kwargs={ + **pg_config, + "workspace": "workspace_a", # CRITICAL: Filter by workspace_a + "cosine_better_than_threshold": 0.8, + }, + ) + + print("🔄 Initializing LightRAG for workspace_a (triggers migration)...") + await rag.initialize_storages() + + # Step 3: Verify workspace isolation + new_table = rag.chunks_vdb.table_name + assert "text_embedding_ada_002_1536d" in new_table.lower() + print(f"✅ New table created: {new_table}") + + # Verify: NEW table contains ONLY workspace_a data (3 records) + new_workspace_a_result = await pg_cleanup.query( + f"SELECT COUNT(*) as count FROM {new_table} WHERE workspace=$1", + ["workspace_a"], + ) + new_workspace_a_count = new_workspace_a_result.get("count", 0) + assert ( + new_workspace_a_count == 3 + ), f"Expected 3 workspace_a records in new table, got {new_workspace_a_count}" + print( + f"✅ Migration successful: {new_workspace_a_count} workspace_a records migrated" + ) + + # Verify: NEW table does NOT contain workspace_b data (0 records) + new_workspace_b_result = await pg_cleanup.query( + f"SELECT COUNT(*) as count FROM {new_table} WHERE workspace=$1", + ["workspace_b"], + ) + new_workspace_b_count = new_workspace_b_result.get("count", 0) + assert ( + new_workspace_b_count == 0 + ), f"workspace_b data leaked! 
Found {new_workspace_b_count} records in new table" + print("✅ No data leakage: 0 workspace_b records in new table (isolated)") + + # Verify: LEGACY table still exists (because workspace_b data remains) + check_legacy_query = """ + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = $1 + ) + """ + legacy_result = await pg_cleanup.query( + check_legacy_query, [legacy_table.lower()] + ) + legacy_exists = legacy_result.get("exists", False) + assert ( + legacy_exists + ), f"Legacy table '{legacy_table}' should still exist (has workspace_b data)" + + # Verify: LEGACY table still has workspace_b data (3 records) + legacy_workspace_b_result = await pg_cleanup.query( + f"SELECT COUNT(*) as count FROM {legacy_table} WHERE workspace=$1", + ["workspace_b"], + ) + legacy_workspace_b_count = legacy_workspace_b_result.get("count", 0) + assert ( + legacy_workspace_b_count == 3 + ), f"workspace_b data lost! Only {legacy_workspace_b_count} remain in legacy table" + print( + f"✅ Legacy table preserved: {legacy_workspace_b_count} workspace_b records remain (not migrated)" + ) + + # Verify: LEGACY table does NOT have workspace_a data (migrated and deleted) + legacy_workspace_a_result = await pg_cleanup.query( + f"SELECT COUNT(*) as count FROM {legacy_table} WHERE workspace=$1", + ["workspace_a"], + ) + legacy_workspace_a_count = legacy_workspace_a_result.get("count", 0) + assert ( + legacy_workspace_a_count == 0 + ), f"workspace_a data should be removed from legacy after migration, found {legacy_workspace_a_count}" + print( + "✅ Legacy cleanup verified: 0 workspace_a records in legacy (cleaned after migration)" + ) + + print( + "\n🎉 P0 Bug Fix Verified: Workspace migration isolation working correctly!" 
+ ) + print( + " - workspace_a: 3 records migrated to new table, 0 in legacy (migrated)" + ) + print( + " - workspace_b: 0 records in new table (isolated), 3 in legacy (preserved)" + ) + + await rag.finalize_storages() + + finally: + # Cleanup temp dir + shutil.rmtree(temp_dir, ignore_errors=True) + + +# Test: Qdrant legacy data migration +@pytest.mark.asyncio +async def test_legacy_migration_qdrant( + qdrant_cleanup, mock_llm_func, mock_tokenizer, qdrant_config +): + """ + Test automatic migration from legacy Qdrant collection (no model suffix) + + Scenario: + 1. Create legacy collection without model suffix + 2. Insert test vectors with 1536d + 3. Initialize LightRAG with model_name (triggers migration) + 4. Verify data migrated to new collection with model suffix + """ + print("\n[E2E Test] Qdrant legacy data migration (1536d)") + + # Create temp working dir + import tempfile + import shutil + + temp_dir = tempfile.mkdtemp(prefix="lightrag_qdrant_legacy_") + + try: + # Step 1: Create legacy collection and insert data + legacy_collection = "lightrag_vdb_chunks" + + # Create legacy collection without model suffix + from qdrant_client.models import Distance, VectorParams + + qdrant_cleanup.create_collection( + collection_name=legacy_collection, + vectors_config=VectorParams(size=1536, distance=Distance.COSINE), + ) + print(f"✅ Created legacy collection: {legacy_collection}") + + # Insert 3 test records + from qdrant_client.models import PointStruct + + test_vectors = [] + for i in range(3): + vector = np.random.rand(1536).tolist() + point = PointStruct( + id=i, + vector=vector, + payload={ + "id": f"legacy_{i}", + "content": f"Legacy content {i}", + "tokens": 100, + "chunk_order_index": i, + "full_doc_id": "legacy_doc", + "file_path": "/test/path", + }, + ) + test_vectors.append(point) + + qdrant_cleanup.upsert(collection_name=legacy_collection, points=test_vectors) + + # Verify legacy data + legacy_count = qdrant_cleanup.count(legacy_collection).count + print(f"✅ 
Legacy collection created with {legacy_count} vectors") + + # Step 2: Initialize LightRAG with model_name (triggers migration) + async def embed_func(texts): + await asyncio.sleep(0) + return np.random.rand(len(texts), 1536) + + embedding_func = EmbeddingFunc( + embedding_dim=1536, + max_token_size=8192, + func=embed_func, + model_name="text-embedding-ada-002", + ) + + rag = LightRAG( + working_dir=temp_dir, + llm_model_func=mock_llm_func, + embedding_func=embedding_func, + tokenizer=mock_tokenizer, + vector_storage="QdrantVectorDBStorage", + vector_db_storage_cls_kwargs={ + **qdrant_config, + "cosine_better_than_threshold": 0.8, + }, + ) + + print("🔄 Initializing LightRAG (triggers migration)...") + await rag.initialize_storages() + + # Step 3: Verify migration + new_collection = rag.chunks_vdb.final_namespace + assert "text_embedding_ada_002_1536d" in new_collection + + # Verify new collection exists + assert qdrant_cleanup.collection_exists( + new_collection + ), f"New collection {new_collection} should exist" + + new_count = qdrant_cleanup.count(new_collection).count + + assert ( + new_count == legacy_count + ), f"Expected {legacy_count} vectors migrated, got {new_count}" + print(f"✅ Migration successful: {new_count}/{legacy_count} vectors migrated") + print(f"✅ New collection: {new_collection}") + + # Verify vector dimension + collection_info = qdrant_cleanup.get_collection(new_collection) + assert ( + collection_info.config.params.vectors.size == 1536 + ), "Migrated collection should have 1536 dimensions" + print( + f"✅ Vector dimension verified: {collection_info.config.params.vectors.size}d" + ) + + # Verify legacy collection was automatically deleted after migration (Case 4) + legacy_exists = qdrant_cleanup.collection_exists(legacy_collection) + assert not legacy_exists, f"Legacy collection '{legacy_collection}' should be deleted after successful migration" + print( + f"✅ Legacy collection '{legacy_collection}' automatically deleted after migration" + ) + + 
await rag.finalize_storages() + + finally: + # Cleanup temp dir + shutil.rmtree(temp_dir, ignore_errors=True) + + +# Test: Multiple LightRAG instances with PostgreSQL +@pytest.mark.asyncio +async def test_multi_instance_postgres( + pg_cleanup, temp_working_dirs, mock_llm_func, mock_tokenizer, pg_config +): + """ + Test multiple LightRAG instances with different dimensions and model names + + Scenarios: + - Instance A: model-a (768d) - explicit model name + - Instance B: model-b (1024d) - explicit model name + - Both instances insert documents independently + - Verify separate tables created for each model+dimension combination + - Verify data isolation between instances + """ + print("\n[E2E Multi-Instance] PostgreSQL with 2 models (768d vs 1024d)") + + # Instance A: 768d with model-a + async def embed_func_a(texts): + await asyncio.sleep(0) + return np.random.rand(len(texts), 768) + + embedding_func_a = EmbeddingFunc( + embedding_dim=768, max_token_size=8192, func=embed_func_a, model_name="model-a" + ) + + # Instance B: 1024d with model-b + async def embed_func_b(texts): + await asyncio.sleep(0) + return np.random.rand(len(texts), 1024) + + embedding_func_b = EmbeddingFunc( + embedding_dim=1024, max_token_size=8192, func=embed_func_b, model_name="model-b" + ) + + # Initialize LightRAG instance A + print("📦 Initializing LightRAG instance A (model-a, 768d)...") + rag_a = LightRAG( + working_dir=temp_working_dirs["workspace_a"], + workspace=pg_config["workspace"], # Use same workspace to test model isolation + llm_model_func=mock_llm_func, + embedding_func=embedding_func_a, + tokenizer=mock_tokenizer, + kv_storage="PGKVStorage", + vector_storage="PGVectorStorage", + # Use default NetworkXStorage for graph storage (AGE extension not available in CI) + doc_status_storage="PGDocStatusStorage", + vector_db_storage_cls_kwargs={**pg_config, "cosine_better_than_threshold": 0.8}, + ) + + await rag_a.initialize_storages() + table_a = rag_a.chunks_vdb.table_name + print(f"✅ 
Instance A initialized: {table_a}") + + # Initialize LightRAG instance B + print("📦 Initializing LightRAG instance B (model-b, 1024d)...") + rag_b = LightRAG( + working_dir=temp_working_dirs["workspace_b"], + workspace=pg_config["workspace"], # Use same workspace to test model isolation + llm_model_func=mock_llm_func, + embedding_func=embedding_func_b, + tokenizer=mock_tokenizer, + kv_storage="PGKVStorage", + vector_storage="PGVectorStorage", + # Use default NetworkXStorage for graph storage (AGE extension not available in CI) + doc_status_storage="PGDocStatusStorage", + vector_db_storage_cls_kwargs={**pg_config, "cosine_better_than_threshold": 0.8}, + ) + + await rag_b.initialize_storages() + table_b = rag_b.chunks_vdb.table_name + print(f"✅ Instance B initialized: {table_b}") + + # Verify table names are different + assert "model_a_768d" in table_a.lower() + assert "model_b_1024d" in table_b.lower() + assert table_a != table_b + print(f"✅ Table isolation verified: {table_a} != {table_b}") + + # Verify both tables exist in database + check_query = """ + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = $1 + ) + """ + result_a = await pg_cleanup.query(check_query, [table_a.lower()]) + result_b = await pg_cleanup.query(check_query, [table_b.lower()]) + + assert result_a.get("exists") is True, f"Table {table_a} should exist" + assert result_b.get("exists") is True, f"Table {table_b} should exist" + print("✅ Both tables exist in PostgreSQL") + + # Insert documents in instance A + print("📝 Inserting document in instance A...") + await rag_a.ainsert( + "Document A: This is about artificial intelligence and neural networks." 
+ ) + + # Insert documents in instance B + print("📝 Inserting document in instance B...") + await rag_b.ainsert("Document B: This is about machine learning and deep learning.") + + # Verify data isolation + count_a_result = await pg_cleanup.query( + f"SELECT COUNT(*) as count FROM {table_a}", [] + ) + count_b_result = await pg_cleanup.query( + f"SELECT COUNT(*) as count FROM {table_b}", [] + ) + + count_a = count_a_result.get("count", 0) + count_b = count_b_result.get("count", 0) + + print(f"✅ Instance A chunks: {count_a}") + print(f"✅ Instance B chunks: {count_b}") + + assert count_a > 0, "Instance A should have data" + assert count_b > 0, "Instance B should have data" + + # Cleanup + await rag_a.finalize_storages() + await rag_b.finalize_storages() + + print("✅ Multi-instance PostgreSQL test passed!") + + +# Test: Multiple LightRAG instances with Qdrant +@pytest.mark.asyncio +async def test_multi_instance_qdrant( + qdrant_cleanup, temp_working_dirs, mock_llm_func, mock_tokenizer, qdrant_config +): + """ + Test multiple LightRAG instances with different models using Qdrant + + Scenario: + - Instance A: model-a (768d) + - Instance B: model-b (1024d) + - Both insert documents independently + - Verify separate collections created and data isolated + """ + print("\n[E2E Multi-Instance] Qdrant with 2 models (768d vs 1024d)") + + # Create embedding function for model A (768d) + async def embed_func_a(texts): + await asyncio.sleep(0) + return np.random.rand(len(texts), 768) + + embedding_func_a = EmbeddingFunc( + embedding_dim=768, max_token_size=8192, func=embed_func_a, model_name="model-a" + ) + + # Create embedding function for model B (1024d) + async def embed_func_b(texts): + await asyncio.sleep(0) + return np.random.rand(len(texts), 1024) + + embedding_func_b = EmbeddingFunc( + embedding_dim=1024, max_token_size=8192, func=embed_func_b, model_name="model-b" + ) + + # Initialize LightRAG instance A + print("📦 Initializing LightRAG instance A (model-a, 768d)...") + 
rag_a = LightRAG( + working_dir=temp_working_dirs["workspace_a"], + llm_model_func=mock_llm_func, + embedding_func=embedding_func_a, + tokenizer=mock_tokenizer, + vector_storage="QdrantVectorDBStorage", + vector_db_storage_cls_kwargs={ + **qdrant_config, + "cosine_better_than_threshold": 0.8, + }, + ) + + await rag_a.initialize_storages() + collection_a = rag_a.chunks_vdb.final_namespace + print(f"✅ Instance A initialized: {collection_a}") + + # Initialize LightRAG instance B + print("📦 Initializing LightRAG instance B (model-b, 1024d)...") + rag_b = LightRAG( + working_dir=temp_working_dirs["workspace_b"], + llm_model_func=mock_llm_func, + embedding_func=embedding_func_b, + tokenizer=mock_tokenizer, + vector_storage="QdrantVectorDBStorage", + vector_db_storage_cls_kwargs={ + **qdrant_config, + "cosine_better_than_threshold": 0.8, + }, + ) + + await rag_b.initialize_storages() + collection_b = rag_b.chunks_vdb.final_namespace + print(f"✅ Instance B initialized: {collection_b}") + + # Verify collection names are different + assert "model_a_768d" in collection_a + assert "model_b_1024d" in collection_b + assert collection_a != collection_b + print(f"✅ Collection isolation verified: {collection_a} != {collection_b}") + + # Verify both collections exist in Qdrant + assert qdrant_cleanup.collection_exists( + collection_a + ), f"Collection {collection_a} should exist" + assert qdrant_cleanup.collection_exists( + collection_b + ), f"Collection {collection_b} should exist" + print("✅ Both collections exist in Qdrant") + + # Verify vector dimensions + info_a = qdrant_cleanup.get_collection(collection_a) + info_b = qdrant_cleanup.get_collection(collection_b) + + assert info_a.config.params.vectors.size == 768, "Model A should use 768 dimensions" + assert ( + info_b.config.params.vectors.size == 1024 + ), "Model B should use 1024 dimensions" + print( + f"✅ Vector dimensions verified: {info_a.config.params.vectors.size}d vs {info_b.config.params.vectors.size}d" + ) + + # 
Insert documents in instance A + print("📝 Inserting document in instance A...") + await rag_a.ainsert( + "Document A: This is about artificial intelligence and neural networks." + ) + + # Insert documents in instance B + print("📝 Inserting document in instance B...") + await rag_b.ainsert("Document B: This is about machine learning and deep learning.") + + # Verify data isolation + count_a = qdrant_cleanup.count(collection_a).count + count_b = qdrant_cleanup.count(collection_b).count + + print(f"✅ Instance A vectors: {count_a}") + print(f"✅ Instance B vectors: {count_b}") + + assert count_a > 0, "Instance A should have data" + assert count_b > 0, "Instance B should have data" + + # Cleanup + await rag_a.finalize_storages() + await rag_b.finalize_storages() + + print("✅ Multi-instance Qdrant test passed!") + + +# ============================================================================ +# Complete Migration Scenario Tests with Real Databases +# ============================================================================ + + +@pytest.mark.asyncio +async def test_case1_both_exist_with_data_qdrant( + qdrant_cleanup, mock_llm_func, mock_tokenizer, qdrant_config +): + """ + E2E Case 1b: Both new and legacy collections exist, legacy has data + Expected: Log warning, do not delete legacy (preserve data), use new collection + """ + print("\n[E2E Case 1b] Both collections exist with data - preservation scenario") + + import tempfile + import shutil + from qdrant_client.models import Distance, VectorParams, PointStruct + + temp_dir = tempfile.mkdtemp(prefix="lightrag_case1_") + + try: + # Step 1: Create both legacy and new collection + legacy_collection = "lightrag_vdb_chunks" + new_collection = "lightrag_vdb_chunks_text_embedding_ada_002_1536d" + + # Create legacy collection with data + qdrant_cleanup.create_collection( + collection_name=legacy_collection, + vectors_config=VectorParams(size=1536, distance=Distance.COSINE), + ) + legacy_points = [ + PointStruct( + id=i, + 
vector=np.random.rand(1536).tolist(), + payload={"id": f"legacy_{i}", "content": f"Legacy doc {i}"}, + ) + for i in range(3) + ] + qdrant_cleanup.upsert(collection_name=legacy_collection, points=legacy_points) + print(f"✅ Created legacy collection with {len(legacy_points)} points") + + # Create new collection (simulate already migrated) + qdrant_cleanup.create_collection( + collection_name=new_collection, + vectors_config=VectorParams(size=1536, distance=Distance.COSINE), + ) + print(f"✅ Created new collection '{new_collection}'") + + # Step 2: Initialize LightRAG (should detect both and warn) + async def embed_func(texts): + await asyncio.sleep(0) + return np.random.rand(len(texts), 1536) + + embedding_func = EmbeddingFunc( + embedding_dim=1536, + max_token_size=8192, + func=embed_func, + model_name="text-embedding-ada-002", + ) + + rag = LightRAG( + working_dir=temp_dir, + llm_model_func=mock_llm_func, + embedding_func=embedding_func, + tokenizer=mock_tokenizer, + vector_storage="QdrantVectorDBStorage", + vector_db_storage_cls_kwargs={ + **qdrant_config, + "cosine_better_than_threshold": 0.8, + }, + ) + + await rag.initialize_storages() + + # Step 3: Verify behavior + # Should use new collection (not migrate) + assert rag.chunks_vdb.final_namespace == new_collection + + # Verify legacy collection still exists (Case 1b: has data, should NOT be deleted) + legacy_exists = qdrant_cleanup.collection_exists(legacy_collection) + assert legacy_exists, "Legacy collection with data should NOT be deleted" + + legacy_count = qdrant_cleanup.count(legacy_collection).count + # Legacy should still have its data (not migrated, not deleted) + assert legacy_count == 3 + print( + f"✅ Legacy collection still has {legacy_count} points (preserved, not deleted)" + ) + + await rag.finalize_storages() + + finally: + shutil.rmtree(temp_dir, ignore_errors=True) + + +@pytest.mark.asyncio +async def test_case2_only_new_exists_qdrant( + qdrant_cleanup, mock_llm_func, mock_tokenizer, 
qdrant_config +): + """ + E2E Case 2: Only new collection exists (already migrated scenario) + Expected: Use existing collection, no migration + """ + print("\n[E2E Case 2] Only new collection exists - already migrated") + + import tempfile + import shutil + from qdrant_client.models import Distance, VectorParams, PointStruct + + temp_dir = tempfile.mkdtemp(prefix="lightrag_case2_") + + try: + # Step 1: Create only new collection with data + new_collection = "lightrag_vdb_chunks_text_embedding_ada_002_1536d" + + qdrant_cleanup.create_collection( + collection_name=new_collection, + vectors_config=VectorParams(size=1536, distance=Distance.COSINE), + ) + + # Add some existing data + existing_points = [ + PointStruct( + id=i, + vector=np.random.rand(1536).tolist(), + payload={ + "id": f"existing_{i}", + "content": f"Existing doc {i}", + "workspace_id": "test_ws", + }, + ) + for i in range(5) + ] + qdrant_cleanup.upsert(collection_name=new_collection, points=existing_points) + print(f"✅ Created new collection with {len(existing_points)} existing points") + + # Step 2: Initialize LightRAG + async def embed_func(texts): + await asyncio.sleep(0) + return np.random.rand(len(texts), 1536) + + embedding_func = EmbeddingFunc( + embedding_dim=1536, + max_token_size=8192, + func=embed_func, + model_name="text-embedding-ada-002", + ) + + rag = LightRAG( + working_dir=temp_dir, + llm_model_func=mock_llm_func, + embedding_func=embedding_func, + tokenizer=mock_tokenizer, + vector_storage="QdrantVectorDBStorage", + vector_db_storage_cls_kwargs={ + **qdrant_config, + "cosine_better_than_threshold": 0.8, + }, + ) + + await rag.initialize_storages() + + # Step 3: Verify collection reused + assert rag.chunks_vdb.final_namespace == new_collection + count = qdrant_cleanup.count(new_collection).count + assert count == 5 # Existing data preserved + print(f"✅ Reused existing collection with {count} points") + + await rag.finalize_storages() + + finally: + shutil.rmtree(temp_dir, 
ignore_errors=True) + + +@pytest.mark.asyncio +async def test_backward_compat_old_workspace_naming_qdrant( + qdrant_cleanup, mock_llm_func, mock_tokenizer, qdrant_config +): + """ + E2E: Backward compatibility with old workspace-based naming + Old format: {workspace}_{namespace} + """ + print("\n[E2E Backward Compat] Old workspace naming migration") + + import tempfile + import shutil + from qdrant_client.models import Distance, VectorParams, PointStruct + + temp_dir = tempfile.mkdtemp(prefix="lightrag_backward_compat_") + + try: + # Step 1: Create old-style collection + old_collection = "prod_chunks" # Old format: {workspace}_{namespace} + + qdrant_cleanup.create_collection( + collection_name=old_collection, + vectors_config=VectorParams(size=1536, distance=Distance.COSINE), + ) + + # Add legacy data + legacy_points = [ + PointStruct( + id=i, + vector=np.random.rand(1536).tolist(), + payload={"id": f"old_{i}", "content": f"Old document {i}"}, + ) + for i in range(10) + ] + qdrant_cleanup.upsert(collection_name=old_collection, points=legacy_points) + print( + f"✅ Created old-style collection '{old_collection}' with {len(legacy_points)} points" + ) + + # Step 2: Initialize LightRAG with prod workspace + async def embed_func(texts): + await asyncio.sleep(0) + return np.random.rand(len(texts), 1536) + + embedding_func = EmbeddingFunc( + embedding_dim=1536, + max_token_size=8192, + func=embed_func, + model_name="text-embedding-ada-002", + ) + + # Important: Use "prod" workspace to match old naming + rag = LightRAG( + working_dir=temp_dir, + workspace="prod", # Pass workspace to LightRAG instance + llm_model_func=mock_llm_func, + embedding_func=embedding_func, + tokenizer=mock_tokenizer, + vector_storage="QdrantVectorDBStorage", + vector_db_storage_cls_kwargs={ + **qdrant_config, + "cosine_better_than_threshold": 0.8, + }, + ) + + print( + "🔄 Initializing with 'prod' workspace (triggers backward-compat migration)..." 
+ ) + await rag.initialize_storages() + + # Step 3: Verify migration + new_collection = rag.chunks_vdb.final_namespace + new_count = qdrant_cleanup.count(new_collection).count + + assert new_count == len(legacy_points) + print( + f"✅ Migrated {new_count} points from old collection '{old_collection}' to '{new_collection}'" + ) + + await rag.finalize_storages() + + finally: + shutil.rmtree(temp_dir, ignore_errors=True) + + +@pytest.mark.asyncio +async def test_empty_legacy_qdrant( + qdrant_cleanup, mock_llm_func, mock_tokenizer, qdrant_config +): + """ + E2E: Empty legacy collection migration + Expected: Skip data migration, create new collection + """ + print("\n[E2E Empty Legacy] Empty collection migration") + + import tempfile + import shutil + from qdrant_client.models import Distance, VectorParams + + temp_dir = tempfile.mkdtemp(prefix="lightrag_empty_legacy_") + + try: + # Step 1: Create empty legacy collection + legacy_collection = "lightrag_vdb_chunks" + + qdrant_cleanup.create_collection( + collection_name=legacy_collection, + vectors_config=VectorParams(size=1536, distance=Distance.COSINE), + ) + print(f"✅ Created empty legacy collection '{legacy_collection}'") + + # Step 2: Initialize LightRAG + async def embed_func(texts): + await asyncio.sleep(0) + return np.random.rand(len(texts), 1536) + + embedding_func = EmbeddingFunc( + embedding_dim=1536, + max_token_size=8192, + func=embed_func, + model_name="text-embedding-ada-002", + ) + + rag = LightRAG( + working_dir=temp_dir, + llm_model_func=mock_llm_func, + embedding_func=embedding_func, + tokenizer=mock_tokenizer, + vector_storage="QdrantVectorDBStorage", + vector_db_storage_cls_kwargs={ + **qdrant_config, + "cosine_better_than_threshold": 0.8, + }, + ) + + print("🔄 Initializing (should skip data migration for empty collection)...") + await rag.initialize_storages() + + # Step 3: Verify new collection created + new_collection = rag.chunks_vdb.final_namespace + assert 
qdrant_cleanup.collection_exists(new_collection) + print(f"✅ New collection '{new_collection}' created (data migration skipped)") + + await rag.finalize_storages() + + finally: + shutil.rmtree(temp_dir, ignore_errors=True) + + +@pytest.mark.asyncio +async def test_workspace_isolation_e2e_qdrant( + qdrant_cleanup, temp_working_dirs, mock_llm_func, mock_tokenizer, qdrant_config +): + """ + E2E: Workspace isolation within same collection + Expected: Same model+dim uses same collection, isolated by workspace_id + """ + print("\n[E2E Workspace Isolation] Same collection, different workspaces") + + async def embed_func(texts): + await asyncio.sleep(0) + return np.random.rand(len(texts), 768) + + embedding_func = EmbeddingFunc( + embedding_dim=768, max_token_size=8192, func=embed_func, model_name="test-model" + ) + + # Instance A: workspace_a + rag_a = LightRAG( + working_dir=temp_working_dirs["workspace_a"], + workspace="workspace_a", # Pass workspace to LightRAG instance + llm_model_func=mock_llm_func, + embedding_func=embedding_func, + tokenizer=mock_tokenizer, + vector_storage="QdrantVectorDBStorage", + vector_db_storage_cls_kwargs={ + **qdrant_config, + "cosine_better_than_threshold": 0.8, + }, + ) + + # Instance B: workspace_b + rag_b = LightRAG( + working_dir=temp_working_dirs["workspace_b"], + workspace="workspace_b", # Pass workspace to LightRAG instance + llm_model_func=mock_llm_func, + embedding_func=embedding_func, + tokenizer=mock_tokenizer, + vector_storage="QdrantVectorDBStorage", + vector_db_storage_cls_kwargs={ + **qdrant_config, + "cosine_better_than_threshold": 0.8, + }, + ) + + await rag_a.initialize_storages() + await rag_b.initialize_storages() + + # Verify: Same collection + collection_a = rag_a.chunks_vdb.final_namespace + collection_b = rag_b.chunks_vdb.final_namespace + assert collection_a == collection_b + print(f"✅ Both use same collection: '{collection_a}'") + + # Insert data to different workspaces + await rag_a.ainsert("Document A for 
workspace A") + await rag_b.ainsert("Document B for workspace B") + + # Verify isolation: Each workspace should see only its own data + # This is ensured by workspace_id filtering in queries + + await rag_a.finalize_storages() + await rag_b.finalize_storages() + + print("✅ Workspace isolation verified (same collection, isolated data)") + + +# Test: Dimension mismatch during migration (PostgreSQL) +@pytest.mark.asyncio +async def test_dimension_mismatch_postgres( + pg_cleanup, mock_llm_func, mock_tokenizer, pg_config +): + """ + Test dimension mismatch scenario - upgrading from 1536d to 3072d model + + Scenario: + 1. Create legacy table with 1536d vectors + 2. Insert test data + 3. Initialize LightRAG with 3072d model + 4. Verify system handles dimension mismatch gracefully + """ + print("\n[E2E Test] Dimension mismatch: 1536d -> 3072d (PostgreSQL)") + + import tempfile + import shutil + + temp_dir = tempfile.mkdtemp(prefix="lightrag_dim_test_") + + try: + # Step 1: Create legacy table with 1536d vectors + legacy_table = "lightrag_vdb_chunks" + + create_legacy_sql = f""" + CREATE TABLE IF NOT EXISTS {legacy_table} ( + workspace VARCHAR(255), + id VARCHAR(255) PRIMARY KEY, + content TEXT, + content_vector vector(1536), + tokens INTEGER, + chunk_order_index INTEGER, + full_doc_id VARCHAR(255), + file_path TEXT, + create_time TIMESTAMP DEFAULT NOW(), + update_time TIMESTAMP DEFAULT NOW() + ) + """ + await pg_cleanup.execute(create_legacy_sql, None) + + # Insert test records with 1536d vectors + for i in range(3): + vector_str = "[" + ",".join(["0.1"] * 1536) + "]" + insert_sql = f""" + INSERT INTO {legacy_table} + (workspace, id, content, content_vector, tokens, chunk_order_index, full_doc_id, file_path) + VALUES ($1, $2, $3, $4::vector, $5, $6, $7, $8) + """ + await pg_cleanup.execute( + insert_sql, + { + "workspace": pg_config["workspace"], + "id": f"legacy_{i}", + "content": f"Legacy content {i}", + "content_vector": vector_str, + "tokens": 100, + 
"chunk_order_index": i, + "full_doc_id": "legacy_doc", + "file_path": "/test/path", + }, + ) + + print("✅ Legacy table created with 3 records (1536d)") + + # Step 2: Try to initialize LightRAG with NEW model (3072d) + async def embed_func_new(texts): + await asyncio.sleep(0) + return np.random.rand(len(texts), 3072) # NEW dimension + + embedding_func_new = EmbeddingFunc( + embedding_dim=3072, # NEW dimension + max_token_size=8192, + func=embed_func_new, + model_name="text-embedding-3-large", + ) + + print("📦 Initializing LightRAG with new model (3072d)...") + + # With our fix, this should handle dimension mismatch gracefully: + # Expected behavior: + # 1. Detect dimension mismatch (1536d legacy vs 3072d new) + # 2. Skip migration to prevent data corruption + # 3. Preserve legacy table with original data + # 4. Create new empty table for 3072d model + # 5. System initializes successfully + + rag = LightRAG( + working_dir=temp_dir, + workspace=pg_config["workspace"], # Match workspace with test data + llm_model_func=mock_llm_func, + embedding_func=embedding_func_new, + tokenizer=mock_tokenizer, + kv_storage="PGKVStorage", + vector_storage="PGVectorStorage", + doc_status_storage="PGDocStatusStorage", + vector_db_storage_cls_kwargs={ + **pg_config, + "cosine_better_than_threshold": 0.8, + }, + ) + + await rag.initialize_storages() + + # Verify expected behavior + new_table = rag.chunks_vdb.table_name + print(f"✅ Initialization succeeded, new table: {new_table}") + + # 1. New table should exist and be created with model suffix + assert "text_embedding_3_large_3072d" in new_table.lower() + check_new = f"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = '{new_table.lower()}')" + new_exists = await pg_cleanup.query(check_new, []) + assert new_exists.get("exists") is True, "New table should exist" + print(f"✅ New table created: {new_table}") + + # 2. 
Legacy table should be preserved (not deleted) + check_legacy = f"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = '{legacy_table}')" + legacy_exists = await pg_cleanup.query(check_legacy, []) + assert ( + legacy_exists.get("exists") is True + ), "Legacy table should be preserved when dimensions don't match" + print(f"✅ Legacy table preserved: {legacy_table}") + + # 3. Legacy table should still have original data (not migrated) + legacy_count_result = await pg_cleanup.query( + f"SELECT COUNT(*) as count FROM {legacy_table}", [] + ) + legacy_count = legacy_count_result.get("count", 0) + assert ( + legacy_count == 3 + ), f"Legacy table should still have 3 records, got {legacy_count}" + print(f"✅ Legacy data preserved: {legacy_count} records") + + # 4. New table should be empty (migration skipped) + new_count_result = await pg_cleanup.query( + f"SELECT COUNT(*) as count FROM {new_table}", [] + ) + new_count = new_count_result.get("count", 0) + assert ( + new_count == 0 + ), f"New table should be empty (migration skipped), got {new_count}" + print( + f"✅ New table is empty (migration correctly skipped): {new_count} records" + ) + + # 5. System should be operational + print("✅ System initialized successfully despite dimension mismatch") + + await rag.finalize_storages() + + finally: + shutil.rmtree(temp_dir, ignore_errors=True) + + +# Test: Dimension mismatch during migration (Qdrant) +@pytest.mark.asyncio +async def test_dimension_mismatch_qdrant( + qdrant_cleanup, mock_llm_func, mock_tokenizer, qdrant_config +): + """ + Test dimension mismatch scenario - upgrading from 768d to 1024d model + + Scenario: + 1. Create legacy collection with 768d vectors + 2. Insert test data + 3. Initialize LightRAG with 1024d model + 4. 
Verify system handles dimension mismatch gracefully + """ + print("\n[E2E Test] Dimension mismatch: 768d -> 1024d (Qdrant)") + + import tempfile + import shutil + + temp_dir = tempfile.mkdtemp(prefix="lightrag_qdrant_dim_test_") + + try: + # Step 1: Create legacy collection with 768d vectors + legacy_collection = "lightrag_vdb_chunks" + + client = QdrantClient(**qdrant_config) + + # Delete if exists + try: + client.delete_collection(legacy_collection) + except Exception: + pass + + # Create legacy collection with 768d + from qdrant_client import models + + client.create_collection( + collection_name=legacy_collection, + vectors_config=models.VectorParams( + size=768, distance=models.Distance.COSINE + ), + ) + + # Insert test points with 768d vectors + points = [] + for i in range(3): + points.append( + models.PointStruct( + id=i, # Use integer ID instead of string + vector=[0.1] * 768, # OLD dimension + payload={"content": f"Legacy content {i}", "id": f"doc_{i}"}, + ) + ) + + client.upsert(collection_name=legacy_collection, points=points, wait=True) + print("✅ Legacy collection created with 3 records (768d)") + + # Step 2: Try to initialize LightRAG with NEW model (1024d) + async def embed_func_new(texts): + await asyncio.sleep(0) + return np.random.rand(len(texts), 1024) # NEW dimension + + embedding_func_new = EmbeddingFunc( + embedding_dim=1024, # NEW dimension + max_token_size=8192, + func=embed_func_new, + model_name="bge-large", + ) + + print("📦 Initializing LightRAG with new model (1024d)...") + + # With our fix, this should handle dimension mismatch gracefully: + # Expected behavior: + # 1. Detect dimension mismatch (768d legacy vs 1024d new) + # 2. Skip migration to prevent data corruption + # 3. Preserve legacy collection with original data + # 4. Create new empty collection for 1024d model + # 5. 
System initializes successfully + + rag = LightRAG( + working_dir=temp_dir, + llm_model_func=mock_llm_func, + embedding_func=embedding_func_new, + tokenizer=mock_tokenizer, + vector_storage="QdrantVectorDBStorage", + vector_db_storage_cls_kwargs={ + **qdrant_config, + "cosine_better_than_threshold": 0.8, + }, + ) + + await rag.initialize_storages() + + # Verify expected behavior + new_collection = rag.chunks_vdb.final_namespace + print(f"✅ Initialization succeeded, new collection: {new_collection}") + + # 1. New collection should exist with model suffix + assert "bge_large_1024d" in new_collection + assert client.collection_exists( + new_collection + ), f"New collection {new_collection} should exist" + print(f"✅ New collection created: {new_collection}") + + # 2. Legacy collection should be preserved (not deleted) + legacy_exists = client.collection_exists(legacy_collection) + assert ( + legacy_exists + ), "Legacy collection should be preserved when dimensions don't match" + print(f"✅ Legacy collection preserved: {legacy_collection}") + + # 3. Legacy collection should still have original data (not migrated) + legacy_count = client.count(legacy_collection).count + assert ( + legacy_count == 3 + ), f"Legacy collection should still have 3 vectors, got {legacy_count}" + print(f"✅ Legacy data preserved: {legacy_count} vectors") + + # 4. New collection should be empty (migration skipped) + new_count = client.count(new_collection).count + assert ( + new_count == 0 + ), f"New collection should be empty (migration skipped), got {new_count}" + print( + f"✅ New collection is empty (migration correctly skipped): {new_count} vectors" + ) + + # 5. Verify new collection has correct dimension + collection_info = client.get_collection(new_collection) + new_dim = collection_info.config.params.vectors.size + assert new_dim == 1024, f"New collection should have 1024d, got {new_dim}d" + print(f"✅ New collection dimension verified: {new_dim}d") + + # 6. 
System should be operational + print("✅ System initialized successfully despite dimension mismatch") + + await rag.finalize_storages() + + finally: + shutil.rmtree(temp_dir, ignore_errors=True) + # Cleanup collections + try: + for coll in client.get_collections().collections: + if "lightrag" in coll.name.lower(): + client.delete_collection(coll.name) + except Exception: + pass + + +if __name__ == "__main__": + # Run tests with pytest + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/test_embedding_func.py b/tests/test_embedding_func.py new file mode 100644 index 0000000000..8997a13a1c --- /dev/null +++ b/tests/test_embedding_func.py @@ -0,0 +1,31 @@ +from lightrag.utils import EmbeddingFunc + + +def dummy_func(*args, **kwargs): + pass + + +def test_embedding_func_with_model_name(): + func = EmbeddingFunc( + embedding_dim=1536, func=dummy_func, model_name="text-embedding-ada-002" + ) + assert func.get_model_identifier() == "text_embedding_ada_002_1536d" + + +def test_embedding_func_without_model_name(): + func = EmbeddingFunc(embedding_dim=768, func=dummy_func) + assert func.get_model_identifier() == "unknown_768d" + + +def test_model_name_sanitization(): + func = EmbeddingFunc( + embedding_dim=1024, + func=dummy_func, + model_name="models/text-embedding-004", # Contains special chars + ) + assert func.get_model_identifier() == "models_text_embedding_004_1024d" + + +def test_model_name_with_uppercase(): + func = EmbeddingFunc(embedding_dim=512, func=dummy_func, model_name="My-Model-V1") + assert func.get_model_identifier() == "my_model_v1_512d" diff --git a/tests/test_no_model_suffix_safety.py b/tests/test_no_model_suffix_safety.py new file mode 100644 index 0000000000..b1dca80c77 --- /dev/null +++ b/tests/test_no_model_suffix_safety.py @@ -0,0 +1,213 @@ +""" +Tests for safety when model suffix is absent (no model_name provided). 
+ +This test module verifies that the system correctly handles the case when +no model_name is provided, preventing accidental deletion of the only table/collection +on restart. + +Critical Bug: When model_suffix is empty, table_name == legacy_table_name. +On second startup, Case 1 logic would delete the only table/collection thinking +it's "legacy", causing all subsequent operations to fail. +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from lightrag.kg.qdrant_impl import QdrantVectorDBStorage +from lightrag.kg.postgres_impl import PGVectorStorage + + +class TestNoModelSuffixSafety: + """Test suite for preventing data loss when model_suffix is absent.""" + + def test_qdrant_no_suffix_second_startup(self): + """ + Test Qdrant doesn't delete collection on second startup when no model_name. + + Scenario: + 1. First startup: Creates collection without suffix + 2. Collection is empty + 3. Second startup: Should NOT delete the collection + + Bug: Without fix, Case 1 would delete the only collection. 
+ """ + from qdrant_client import models + + client = MagicMock() + + # Simulate second startup: collection already exists and is empty + # IMPORTANT: Without suffix, collection_name == legacy collection name + collection_name = "lightrag_vdb_chunks" # No suffix, same as legacy + + # Both exist (they're the same collection) + client.collection_exists.return_value = True + + # Collection is empty + client.count.return_value.count = 0 + + # Call setup_collection + # This should detect that new == legacy and skip deletion + QdrantVectorDBStorage.setup_collection( + client, + collection_name, + namespace="chunks", + workspace=None, + vectors_config=models.VectorParams( + size=1536, distance=models.Distance.COSINE + ), + ) + + # CRITICAL: Collection should NOT be deleted + client.delete_collection.assert_not_called() + + # Verify we returned early (skipped Case 1 cleanup) + # The collection_exists was checked, but we didn't proceed to count + # because we detected same name + assert client.collection_exists.call_count >= 1 + + @pytest.mark.asyncio + async def test_postgres_no_suffix_second_startup(self): + """ + Test PostgreSQL doesn't delete table on second startup when no model_name. + + Scenario: + 1. First startup: Creates table without suffix + 2. Table is empty + 3. Second startup: Should NOT delete the table + + Bug: Without fix, Case 1 would delete the only table. 
+ """ + db = AsyncMock() + + # Simulate second startup: table already exists and is empty + # IMPORTANT: table_name and legacy_table_name are THE SAME + table_name = "LIGHTRAG_VDB_CHUNKS" # No suffix + legacy_table_name = "LIGHTRAG_VDB_CHUNKS" # Same as new + + # Setup mock responses + async def table_exists_side_effect(db_instance, name): + # Both tables exist (they're the same) + return True + + # Mock _pg_table_exists function + with patch( + "lightrag.kg.postgres_impl._pg_table_exists", + side_effect=table_exists_side_effect, + ): + # Call setup_table + # This should detect that new == legacy and skip deletion + await PGVectorStorage.setup_table( + db, + table_name, + legacy_table_name=legacy_table_name, + base_table="LIGHTRAG_VDB_CHUNKS", + embedding_dim=1536, + ) + + # CRITICAL: Table should NOT be deleted (no DROP TABLE) + drop_calls = [ + call + for call in db.execute.call_args_list + if call[0][0] and "DROP TABLE" in call[0][0] + ] + assert ( + len(drop_calls) == 0 + ), "Should not drop table when new and legacy are the same" + + # Also should not try to count (we returned early) + count_calls = [ + call + for call in db.query.call_args_list + if call[0][0] and "COUNT(*)" in call[0][0] + ] + assert ( + len(count_calls) == 0 + ), "Should not check count when new and legacy are the same" + + def test_qdrant_with_suffix_case1_still_works(self): + """ + Test that Case 1 cleanup still works when there IS a suffix. + + This ensures our fix doesn't break the normal Case 1 scenario. 
+ """ + from qdrant_client import models + + client = MagicMock() + + # Different names (normal case) + collection_name = "lightrag_vdb_chunks_ada_002_1536d" # With suffix + legacy_collection = "lightrag_vdb_chunks" # Without suffix + + # Setup: both exist + def collection_exists_side_effect(name): + return name in [collection_name, legacy_collection] + + client.collection_exists.side_effect = collection_exists_side_effect + + # Legacy is empty + client.count.return_value.count = 0 + + # Call setup_collection + QdrantVectorDBStorage.setup_collection( + client, + collection_name, + namespace="chunks", + workspace=None, + vectors_config=models.VectorParams( + size=1536, distance=models.Distance.COSINE + ), + ) + + # SHOULD delete legacy (normal Case 1 behavior) + client.delete_collection.assert_called_once_with( + collection_name=legacy_collection + ) + + @pytest.mark.asyncio + async def test_postgres_with_suffix_case1_still_works(self): + """ + Test that Case 1 cleanup still works when there IS a suffix. + + This ensures our fix doesn't break the normal Case 1 scenario. 
+ """ + db = AsyncMock() + + # Different names (normal case) + table_name = "LIGHTRAG_VDB_CHUNKS_ADA_002_1536D" # With suffix + legacy_table_name = "LIGHTRAG_VDB_CHUNKS" # Without suffix + + # Setup mock responses + async def table_exists_side_effect(db_instance, name): + # Both tables exist + return True + + # Mock empty table + async def query_side_effect(sql, params, **kwargs): + if "COUNT(*)" in sql: + return {"count": 0} + return {} + + db.query.side_effect = query_side_effect + + # Mock _pg_table_exists function + with patch( + "lightrag.kg.postgres_impl._pg_table_exists", + side_effect=table_exists_side_effect, + ): + # Call setup_table + await PGVectorStorage.setup_table( + db, + table_name, + legacy_table_name=legacy_table_name, + base_table="LIGHTRAG_VDB_CHUNKS", + embedding_dim=1536, + ) + + # SHOULD delete legacy (normal Case 1 behavior) + drop_calls = [ + call + for call in db.execute.call_args_list + if call[0][0] and "DROP TABLE" in call[0][0] + ] + assert len(drop_calls) == 1, "Should drop legacy table in normal Case 1" + assert legacy_table_name in drop_calls[0][0][0] diff --git a/tests/test_postgres_migration.py b/tests/test_postgres_migration.py new file mode 100644 index 0000000000..df88e70096 --- /dev/null +++ b/tests/test_postgres_migration.py @@ -0,0 +1,805 @@ +import pytest +from unittest.mock import patch, AsyncMock +import numpy as np +from lightrag.utils import EmbeddingFunc +from lightrag.kg.postgres_impl import ( + PGVectorStorage, +) +from lightrag.namespace import NameSpace + + +# Mock PostgreSQLDB +@pytest.fixture +def mock_pg_db(): + """Mock PostgreSQL database connection""" + db = AsyncMock() + db.workspace = "test_workspace" + + # Mock query responses with multirows support + async def mock_query(sql, params=None, multirows=False, **kwargs): + # Default return value + if multirows: + return [] # Return empty list for multirows + return {"exists": False, "count": 0} + + # Mock for execute that mimics PostgreSQLDB.execute() behavior 
+ async def mock_execute(sql, data=None, **kwargs): + """ + Mock that mimics PostgreSQLDB.execute() behavior: + - Accepts data as dict[str, Any] | None (second parameter) + - Internally converts dict.values() to tuple for AsyncPG + """ + # Mimic real execute() which accepts dict and converts to tuple + if data is not None and not isinstance(data, dict): + raise TypeError( + f"PostgreSQLDB.execute() expects data as dict, got {type(data).__name__}" + ) + return None + + db.query = AsyncMock(side_effect=mock_query) + db.execute = AsyncMock(side_effect=mock_execute) + + return db + + +# Mock get_data_init_lock to avoid async lock issues in tests +@pytest.fixture(autouse=True) +def mock_data_init_lock(): + with patch("lightrag.kg.postgres_impl.get_data_init_lock") as mock_lock: + mock_lock_ctx = AsyncMock() + mock_lock.return_value = mock_lock_ctx + yield mock_lock + + +# Mock ClientManager +@pytest.fixture +def mock_client_manager(mock_pg_db): + with patch("lightrag.kg.postgres_impl.ClientManager") as mock_manager: + mock_manager.get_client = AsyncMock(return_value=mock_pg_db) + mock_manager.release_client = AsyncMock() + yield mock_manager + + +# Mock Embedding function +@pytest.fixture +def mock_embedding_func(): + async def embed_func(texts, **kwargs): + return np.array([[0.1] * 768 for _ in texts]) + + func = EmbeddingFunc(embedding_dim=768, func=embed_func, model_name="test_model") + return func + + +@pytest.mark.asyncio +async def test_postgres_table_naming( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """Test if table name is correctly generated with model suffix""" + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + storage = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=mock_embedding_func, + workspace="test_ws", + ) + + # Verify table name contains model suffix + expected_suffix = "test_model_768d" + assert 
expected_suffix in storage.table_name + assert storage.table_name == f"LIGHTRAG_VDB_CHUNKS_{expected_suffix}" + + # Verify legacy table name + assert storage.legacy_table_name == "LIGHTRAG_VDB_CHUNKS" + + +@pytest.mark.asyncio +async def test_postgres_migration_trigger( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """Test if migration logic is triggered correctly""" + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + storage = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=mock_embedding_func, + workspace="test_ws", + ) + + # Setup mocks for migration scenario + # 1. New table does not exist, legacy table exists + async def mock_table_exists(db, table_name): + return table_name == storage.legacy_table_name + + # 2. Legacy table has 100 records + mock_rows = [ + {"id": f"test_id_{i}", "content": f"content_{i}", "workspace": "test_ws"} + for i in range(100) + ] + + async def mock_query(sql, params=None, multirows=False, **kwargs): + if "COUNT(*)" in sql: + return {"count": 100} + elif multirows and "SELECT *" in sql: + # Mock batch fetch for migration + # Handle workspace filtering: params = [workspace, offset, limit] or [offset, limit] + if "WHERE workspace" in sql: + # With workspace filter: params[0]=workspace, params[1]=offset, params[2]=limit + offset = params[1] if len(params) > 1 else 0 + limit = params[2] if len(params) > 2 else 500 + else: + # No workspace filter: params[0]=offset, params[1]=limit + offset = params[0] if params else 0 + limit = params[1] if len(params) > 1 else 500 + start = offset + end = min(offset + limit, len(mock_rows)) + return mock_rows[start:end] + return {} + + mock_pg_db.query = AsyncMock(side_effect=mock_query) + + with ( + patch( + "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists + ), + patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()), + ): + # 
Initialize storage (should trigger migration) + await storage.initialize() + + # Verify migration was executed + # Check that execute was called for inserting rows + assert mock_pg_db.execute.call_count > 0 + + +@pytest.mark.asyncio +async def test_postgres_no_migration_needed( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """Test scenario where new table already exists (no migration needed)""" + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + storage = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=mock_embedding_func, + workspace="test_ws", + ) + + # Mock: new table already exists + async def mock_table_exists(db, table_name): + return table_name == storage.table_name + + with ( + patch( + "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists + ), + patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create, + ): + await storage.initialize() + + # Verify no table creation was attempted + mock_create.assert_not_called() + + +@pytest.mark.asyncio +async def test_scenario_1_new_workspace_creation( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """ + Scenario 1: New workspace creation + + Expected behavior: + - No legacy table exists + - Directly create new table with model suffix + - No migration needed + """ + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + embedding_func = EmbeddingFunc( + embedding_dim=3072, + func=mock_embedding_func.func, + model_name="text-embedding-3-large", + ) + + storage = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=embedding_func, + workspace="new_workspace", + ) + + # Mock: neither table exists + async def mock_table_exists(db, table_name): + return False + + with ( + patch( + 
"lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists + ), + patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create, + ): + await storage.initialize() + + # Verify table name format + assert "text_embedding_3_large_3072d" in storage.table_name + + # Verify new table creation was called + mock_create.assert_called_once() + call_args = mock_create.call_args + assert ( + call_args[0][1] == storage.table_name + ) # table_name is second positional arg + + +@pytest.mark.asyncio +async def test_scenario_2_legacy_upgrade_migration( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """ + Scenario 2: Upgrade from legacy version + + Expected behavior: + - Legacy table exists (without model suffix) + - New table doesn't exist + - Automatically migrate data to new table with suffix + """ + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + embedding_func = EmbeddingFunc( + embedding_dim=1536, + func=mock_embedding_func.func, + model_name="text-embedding-ada-002", + ) + + storage = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=embedding_func, + workspace="legacy_workspace", + ) + + # Mock: only legacy table exists + async def mock_table_exists(db, table_name): + return table_name == storage.legacy_table_name + + # Mock: legacy table has 50 records + mock_rows = [ + { + "id": f"legacy_id_{i}", + "content": f"legacy_content_{i}", + "workspace": "legacy_workspace", + } + for i in range(50) + ] + + # Track which queries have been made for proper response + query_history = [] + + async def mock_query(sql, params=None, multirows=False, **kwargs): + query_history.append(sql) + + if "COUNT(*)" in sql: + # Determine table type: + # - Legacy: contains base name but NOT model suffix + # - New: contains model suffix (e.g., text_embedding_ada_002_1536d) + sql_upper = sql.upper() + base_name = 
storage.legacy_table_name.upper() + + # Check if this is querying the new table (has model suffix) + has_model_suffix = any( + suffix in sql_upper + for suffix in ["TEXT_EMBEDDING", "_1536D", "_768D", "_1024D", "_3072D"] + ) + + is_legacy_table = base_name in sql_upper and not has_model_suffix + is_new_table = has_model_suffix + has_workspace_filter = "WHERE workspace" in sql + + if is_legacy_table and has_workspace_filter: + # Count for legacy table with workspace filter (before migration) + return {"count": 50} + elif is_legacy_table and not has_workspace_filter: + # Total count for legacy table (after deletion, checking remaining) + return {"count": 0} + elif is_new_table: + # Count for new table (verification after migration) + return {"count": 50} + else: + # Fallback + return {"count": 0} + elif multirows and "SELECT *" in sql: + # Mock batch fetch for migration + # Handle workspace filtering: params = [workspace, offset, limit] or [offset, limit] + if "WHERE workspace" in sql: + # With workspace filter: params[0]=workspace, params[1]=offset, params[2]=limit + offset = params[1] if len(params) > 1 else 0 + limit = params[2] if len(params) > 2 else 500 + else: + # No workspace filter: params[0]=offset, params[1]=limit + offset = params[0] if params else 0 + limit = params[1] if len(params) > 1 else 500 + start = offset + end = min(offset + limit, len(mock_rows)) + return mock_rows[start:end] + return {} + + mock_pg_db.query = AsyncMock(side_effect=mock_query) + + with ( + patch( + "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists + ), + patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create, + ): + await storage.initialize() + + # Verify table name contains ada-002 + assert "text_embedding_ada_002_1536d" in storage.table_name + + # Verify migration was executed + assert mock_pg_db.execute.call_count >= 50 # At least one execute per row + mock_create.assert_called_once() + + # Verify legacy table was 
automatically deleted after successful migration + # This prevents Case 1 warnings on next startup + delete_calls = [ + call + for call in mock_pg_db.execute.call_args_list + if call[0][0] and "DROP TABLE" in call[0][0] + ] + assert ( + len(delete_calls) >= 1 + ), "Legacy table should be deleted after successful migration" + # Check if legacy table was dropped + dropped_table = storage.legacy_table_name + assert any( + dropped_table in str(call) for call in delete_calls + ), f"Expected to drop '{dropped_table}'" + + +@pytest.mark.asyncio +async def test_scenario_3_multi_model_coexistence( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """ + Scenario 3: Multiple embedding models coexist + + Expected behavior: + - Different embedding models create separate tables + - Tables are isolated by model suffix + - No interference between different models + """ + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + # Workspace A: uses bge-small (768d) + embedding_func_a = EmbeddingFunc( + embedding_dim=768, func=mock_embedding_func.func, model_name="bge-small" + ) + + storage_a = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=embedding_func_a, + workspace="workspace_a", + ) + + # Workspace B: uses bge-large (1024d) + async def embed_func_b(texts, **kwargs): + return np.array([[0.1] * 1024 for _ in texts]) + + embedding_func_b = EmbeddingFunc( + embedding_dim=1024, func=embed_func_b, model_name="bge-large" + ) + + storage_b = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=embedding_func_b, + workspace="workspace_b", + ) + + # Verify different table names + assert storage_a.table_name != storage_b.table_name + assert "bge_small_768d" in storage_a.table_name + assert "bge_large_1024d" in storage_b.table_name + + # Mock: both tables don't exist yet + async def mock_table_exists(db, 
table_name): + return False + + with ( + patch( + "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists + ), + patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()) as mock_create, + ): + # Initialize both storages + await storage_a.initialize() + await storage_b.initialize() + + # Verify two separate tables were created + assert mock_create.call_count == 2 + + # Verify table names are different + call_args_list = mock_create.call_args_list + table_names = [call[0][1] for call in call_args_list] # Second positional arg + assert len(set(table_names)) == 2 # Two unique table names + assert storage_a.table_name in table_names + assert storage_b.table_name in table_names + + +@pytest.mark.asyncio +async def test_case1_empty_legacy_auto_cleanup( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """ + Case 1a: Both new and legacy tables exist, but legacy is EMPTY + Expected: Automatically delete empty legacy table (safe cleanup) + """ + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + embedding_func = EmbeddingFunc( + embedding_dim=1536, + func=mock_embedding_func.func, + model_name="test-model", + ) + + storage = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=embedding_func, + workspace="test_ws", + ) + + # Mock: Both tables exist + async def mock_table_exists(db, table_name): + return True # Both new and legacy exist + + # Mock: Legacy table is empty (0 records) + async def mock_query(sql, params=None, multirows=False, **kwargs): + if "COUNT(*)" in sql: + if storage.legacy_table_name in sql: + return {"count": 0} # Empty legacy table + else: + return {"count": 100} # New table has data + return {} + + mock_pg_db.query = AsyncMock(side_effect=mock_query) + + with patch( + "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists + ): + await storage.initialize() + + # Verify: Empty 
legacy table should be automatically cleaned up + # Empty tables are safe to delete without data loss risk + delete_calls = [ + call + for call in mock_pg_db.execute.call_args_list + if call[0][0] and "DROP TABLE" in call[0][0] + ] + assert len(delete_calls) >= 1, "Empty legacy table should be auto-deleted" + # Check if legacy table was dropped + dropped_table = storage.legacy_table_name + assert any( + dropped_table in str(call) for call in delete_calls + ), f"Expected to drop empty legacy table '{dropped_table}'" + + print( + f"✅ Case 1a: Empty legacy table '{dropped_table}' auto-deleted successfully" + ) + + +@pytest.mark.asyncio +async def test_case1_nonempty_legacy_warning( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """ + Case 1b: Both new and legacy tables exist, and legacy HAS DATA + Expected: Log warning, do not delete legacy (preserve data) + """ + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + embedding_func = EmbeddingFunc( + embedding_dim=1536, + func=mock_embedding_func.func, + model_name="test-model", + ) + + storage = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=embedding_func, + workspace="test_ws", + ) + + # Mock: Both tables exist + async def mock_table_exists(db, table_name): + return True # Both new and legacy exist + + # Mock: Legacy table has data (50 records) + async def mock_query(sql, params=None, multirows=False, **kwargs): + if "COUNT(*)" in sql: + if storage.legacy_table_name in sql: + return {"count": 50} # Legacy has data + else: + return {"count": 100} # New table has data + return {} + + mock_pg_db.query = AsyncMock(side_effect=mock_query) + + with patch( + "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists + ): + await storage.initialize() + + # Verify: Legacy table with data should be preserved + # We never auto-delete tables that contain data to prevent 
accidental data loss + delete_calls = [ + call + for call in mock_pg_db.execute.call_args_list + if call[0][0] and "DROP TABLE" in call[0][0] + ] + # Check if legacy table was deleted (it should not be) + dropped_table = storage.legacy_table_name + legacy_deleted = any(dropped_table in str(call) for call in delete_calls) + assert not legacy_deleted, "Legacy table with data should NOT be auto-deleted" + + print( + f"✅ Case 1b: Legacy table '{dropped_table}' with data preserved (warning only)" + ) + + +@pytest.mark.asyncio +async def test_case1_sequential_workspace_migration( + mock_client_manager, mock_pg_db, mock_embedding_func +): + """ + Case 1c: Sequential workspace migration (Multi-tenant scenario) + + Critical bug fix verification: + Timeline: + 1. Legacy table has workspace_a (3 records) + workspace_b (3 records) + 2. Workspace A initializes first → Case 4 (only legacy exists) → migrates A's data + 3. Workspace B initializes later → Case 1 (both tables exist) → should migrate B's data + 4. Verify workspace B's data is correctly migrated to new table + 5. Verify legacy table is cleaned up after both workspaces migrate + + This test verifies the fix where Case 1 now checks and migrates current + workspace's data instead of just checking if legacy table is empty globally. 
+ """ + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + embedding_func = EmbeddingFunc( + embedding_dim=1536, + func=mock_embedding_func.func, + model_name="test-model", + ) + + # Mock data: Legacy table has 6 records total (3 from workspace_a, 3 from workspace_b) + mock_rows_a = [ + {"id": f"a_{i}", "content": f"A content {i}", "workspace": "workspace_a"} + for i in range(3) + ] + mock_rows_b = [ + {"id": f"b_{i}", "content": f"B content {i}", "workspace": "workspace_b"} + for i in range(3) + ] + + # Track migration state + migration_state = {"new_table_exists": False, "workspace_a_migrated": False} + + # Step 1: Simulate workspace_a initialization (Case 4) + # CRITICAL: Set db.workspace to workspace_a + mock_pg_db.workspace = "workspace_a" + + storage_a = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=embedding_func, + workspace="workspace_a", + ) + + # Mock table_exists for workspace_a + async def mock_table_exists_a(db, table_name): + if table_name == storage_a.legacy_table_name: + return True + if table_name == storage_a.table_name: + return migration_state["new_table_exists"] + return False + + # Track inserted records count for verification + inserted_count = {"workspace_a": 0} + + # Mock execute to track inserts + async def mock_execute_a(sql, data=None, **kwargs): + if sql and "INSERT INTO" in sql.upper(): + inserted_count["workspace_a"] += 1 + return None + + # Mock query for workspace_a (Case 4) + async def mock_query_a(sql, params=None, multirows=False, **kwargs): + sql_upper = sql.upper() + base_name = storage_a.legacy_table_name.upper() + + if "COUNT(*)" in sql: + has_model_suffix = "TEST_MODEL_1536D" in sql_upper + is_legacy = base_name in sql_upper and not has_model_suffix + has_workspace_filter = "WHERE workspace" in sql + + if is_legacy and has_workspace_filter: + workspace = params[0] if params and len(params) > 0 else 
None + if workspace == "workspace_a": + # After migration starts, pretend legacy is empty for this workspace + return {"count": 3 - inserted_count["workspace_a"]} + elif workspace == "workspace_b": + return {"count": 3} + elif is_legacy and not has_workspace_filter: + # Global count in legacy table + remaining = 6 - inserted_count["workspace_a"] + return {"count": remaining} + elif has_model_suffix: + # New table count (for verification) + return {"count": inserted_count["workspace_a"]} + elif multirows and "SELECT *" in sql: + if "WHERE workspace" in sql: + workspace = params[0] if params and len(params) > 0 else None + if workspace == "workspace_a": + offset = params[1] if len(params) > 1 else 0 + limit = params[2] if len(params) > 2 else 500 + return mock_rows_a[offset : offset + limit] + return {} + + mock_pg_db.query = AsyncMock(side_effect=mock_query_a) + mock_pg_db.execute = AsyncMock(side_effect=mock_execute_a) + + # Initialize workspace_a (Case 4) + with ( + patch( + "lightrag.kg.postgres_impl._pg_table_exists", + side_effect=mock_table_exists_a, + ), + patch("lightrag.kg.postgres_impl._pg_create_table", AsyncMock()), + ): + await storage_a.initialize() + migration_state["new_table_exists"] = True + migration_state["workspace_a_migrated"] = True + + print("✅ Step 1: Workspace A initialized (Case 4)") + assert mock_pg_db.execute.call_count >= 3 + print(f"✅ Step 1: {mock_pg_db.execute.call_count} execute calls") + + # Step 2: Simulate workspace_b initialization (Case 1) + # CRITICAL: Set db.workspace to workspace_b + mock_pg_db.workspace = "workspace_b" + + storage_b = PGVectorStorage( + namespace=NameSpace.VECTOR_STORE_CHUNKS, + global_config=config, + embedding_func=embedding_func, + workspace="workspace_b", + ) + + mock_pg_db.reset_mock() + migration_state["workspace_b_migrated"] = False + + # Mock table_exists for workspace_b (both exist) + async def mock_table_exists_b(db, table_name): + return True + + # Track inserted records count for workspace_b + 
inserted_count["workspace_b"] = 0 + + # Mock execute for workspace_b to track inserts + async def mock_execute_b(sql, data=None, **kwargs): + if sql and "INSERT INTO" in sql.upper(): + inserted_count["workspace_b"] += 1 + return None + + # Mock query for workspace_b (Case 1) + async def mock_query_b(sql, params=None, multirows=False, **kwargs): + sql_upper = sql.upper() + base_name = storage_b.legacy_table_name.upper() + + if "COUNT(*)" in sql: + has_model_suffix = "TEST_MODEL_1536D" in sql_upper + is_legacy = base_name in sql_upper and not has_model_suffix + has_workspace_filter = "WHERE workspace" in sql + + if is_legacy and has_workspace_filter: + workspace = params[0] if params and len(params) > 0 else None + if workspace == "workspace_b": + # After migration starts, pretend legacy is empty for this workspace + return {"count": 3 - inserted_count["workspace_b"]} + elif workspace == "workspace_a": + return {"count": 0} # Already migrated + elif is_legacy and not has_workspace_filter: + # Global count: only workspace_b data remains + return {"count": 3 - inserted_count["workspace_b"]} + elif has_model_suffix: + # New table total count (workspace_a: 3 + workspace_b: inserted) + if has_workspace_filter: + workspace = params[0] if params and len(params) > 0 else None + if workspace == "workspace_b": + return {"count": inserted_count["workspace_b"]} + elif workspace == "workspace_a": + return {"count": 3} + else: + # Total count in new table (for verification) + return {"count": 3 + inserted_count["workspace_b"]} + elif multirows and "SELECT *" in sql: + if "WHERE workspace" in sql: + workspace = params[0] if params and len(params) > 0 else None + if workspace == "workspace_b": + offset = params[1] if len(params) > 1 else 0 + limit = params[2] if len(params) > 2 else 500 + return mock_rows_b[offset : offset + limit] + return {} + + mock_pg_db.query = AsyncMock(side_effect=mock_query_b) + mock_pg_db.execute = AsyncMock(side_effect=mock_execute_b) + + # Initialize 
workspace_b (Case 1) + with patch( + "lightrag.kg.postgres_impl._pg_table_exists", side_effect=mock_table_exists_b + ): + await storage_b.initialize() + migration_state["workspace_b_migrated"] = True + + print("✅ Step 2: Workspace B initialized (Case 1)") + + # Verify workspace_b migration happened + execute_calls = mock_pg_db.execute.call_args_list + insert_calls = [ + call for call in execute_calls if call[0][0] and "INSERT INTO" in call[0][0] + ] + assert len(insert_calls) >= 3, f"Expected >= 3 inserts, got {len(insert_calls)}" + print(f"✅ Step 2: {len(insert_calls)} insert calls") + + # Verify DELETE and DROP TABLE + delete_calls = [ + call + for call in execute_calls + if call[0][0] + and "DELETE FROM" in call[0][0] + and "WHERE workspace" in call[0][0] + ] + assert len(delete_calls) >= 1, "Expected DELETE workspace_b data" + print("✅ Step 2: DELETE workspace_b from legacy") + + drop_calls = [ + call for call in execute_calls if call[0][0] and "DROP TABLE" in call[0][0] + ] + assert len(drop_calls) >= 1, "Expected DROP TABLE" + print("✅ Step 2: Legacy table dropped") + + print("\n🎉 Case 1c: Sequential workspace migration verified!") diff --git a/tests/test_qdrant_migration.py b/tests/test_qdrant_migration.py new file mode 100644 index 0000000000..0da237b8f5 --- /dev/null +++ b/tests/test_qdrant_migration.py @@ -0,0 +1,522 @@ +import pytest +from unittest.mock import MagicMock, patch, AsyncMock +import numpy as np +from lightrag.utils import EmbeddingFunc +from lightrag.kg.qdrant_impl import QdrantVectorDBStorage + + +# Mock QdrantClient +@pytest.fixture +def mock_qdrant_client(): + with patch("lightrag.kg.qdrant_impl.QdrantClient") as mock_client_cls: + client = mock_client_cls.return_value + client.collection_exists.return_value = False + client.count.return_value.count = 0 + # Mock payload schema and vector config for get_collection + collection_info = MagicMock() + collection_info.payload_schema = {} + # Mock vector dimension to match mock_embedding_func 
(768d) + collection_info.config.params.vectors.size = 768 + client.get_collection.return_value = collection_info + yield client + + +# Mock get_data_init_lock to avoid async lock issues in tests +@pytest.fixture(autouse=True) +def mock_data_init_lock(): + with patch("lightrag.kg.qdrant_impl.get_data_init_lock") as mock_lock: + mock_lock_ctx = AsyncMock() + mock_lock.return_value = mock_lock_ctx + yield mock_lock + + +# Mock Embedding function +@pytest.fixture +def mock_embedding_func(): + async def embed_func(texts, **kwargs): + return np.array([[0.1] * 768 for _ in texts]) + + func = EmbeddingFunc(embedding_dim=768, func=embed_func, model_name="test-model") + return func + + +@pytest.mark.asyncio +async def test_qdrant_collection_naming(mock_qdrant_client, mock_embedding_func): + """Test if collection name is correctly generated with model suffix""" + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + storage = QdrantVectorDBStorage( + namespace="chunks", + global_config=config, + embedding_func=mock_embedding_func, + workspace="test_ws", + ) + + # Verify collection name contains model suffix + expected_suffix = "test_model_768d" + assert expected_suffix in storage.final_namespace + assert storage.final_namespace == f"lightrag_vdb_chunks_{expected_suffix}" + + # Verify legacy namespace (should not include workspace, just the base collection name) + assert storage.legacy_namespace == "lightrag_vdb_chunks" + + +@pytest.mark.asyncio +async def test_qdrant_migration_trigger(mock_qdrant_client, mock_embedding_func): + """Test if migration logic is triggered correctly""" + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + storage = QdrantVectorDBStorage( + namespace="chunks", + global_config=config, + embedding_func=mock_embedding_func, + workspace="test_ws", + ) + + # Setup mocks for migration scenario + # 1. 
New collection does not exist + mock_qdrant_client.collection_exists.side_effect = ( + lambda name: name == storage.legacy_namespace + ) + + # 2. Legacy collection exists and has data + mock_qdrant_client.count.return_value.count = 100 + + # 3. Mock scroll for data migration + + mock_point = MagicMock() + mock_point.id = "old_id" + mock_point.vector = [0.1] * 768 + mock_point.payload = {"content": "test"} + + # First call returns points, second call returns empty (end of scroll) + mock_qdrant_client.scroll.side_effect = [([mock_point], "next_offset"), ([], None)] + + # Initialize storage (triggers migration) + await storage.initialize() + + # Verify migration steps + # 1. Legacy count checked + mock_qdrant_client.count.assert_any_call( + collection_name=storage.legacy_namespace, exact=True + ) + + # 2. New collection created + mock_qdrant_client.create_collection.assert_called() + + # 3. Data scrolled from legacy + assert mock_qdrant_client.scroll.call_count >= 1 + call_args = mock_qdrant_client.scroll.call_args_list[0] + assert call_args.kwargs["collection_name"] == storage.legacy_namespace + assert call_args.kwargs["limit"] == 500 + + # 4. Data upserted to new + mock_qdrant_client.upsert.assert_called() + + # 5. 
Payload index created
+    mock_qdrant_client.create_payload_index.assert_called()
+
+
+@pytest.mark.asyncio
+async def test_qdrant_no_migration_needed(mock_qdrant_client, mock_embedding_func):
+    """Test scenario where new collection already exists"""
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+    }
+
+    storage = QdrantVectorDBStorage(
+        namespace="chunks",
+        global_config=config,
+        embedding_func=mock_embedding_func,
+        workspace="test_ws",
+    )
+
+    # New collection exists and Legacy exists (warning case)
+    # or New collection exists and Legacy does not exist (normal case)
+    # Mocking case where both exist to test logic flow but without migration
+
+    # Logic in code:
+    # Case 1: Both exist -> Warning only
+    # Case 2: Only new exists -> Ensure index
+
+    # Let's test Case 2: Only new collection exists
+    mock_qdrant_client.collection_exists.side_effect = (
+        lambda name: name == storage.final_namespace
+    )
+
+    # Initialize
+    await storage.initialize()
+
+    # Should check index but NOT migrate
+    # In Qdrant implementation, Case 2 calls get_collection
+    mock_qdrant_client.get_collection.assert_called_with(storage.final_namespace)
+    mock_qdrant_client.scroll.assert_not_called()
+
+
+# ============================================================================
+# Tests for scenarios described in design document (Lines 606-649)
+# ============================================================================
+
+
+@pytest.mark.asyncio
+async def test_scenario_1_new_workspace_creation(
+    mock_qdrant_client, mock_embedding_func
+):
+    """
+    Scenario 1: Creating a brand-new workspace.
+    Expected: directly create lightrag_vdb_chunks_text_embedding_3_large_3072d
+    """
+    # Use a large embedding model
+    large_model_func = EmbeddingFunc(
+        embedding_dim=3072,
+        func=mock_embedding_func.func,
+        model_name="text-embedding-3-large",
+    )
+
+    config = {
+        "embedding_batch_num": 10,
+        "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8},
+    }
+
+    storage = 
QdrantVectorDBStorage( + namespace="chunks", + global_config=config, + embedding_func=large_model_func, + workspace="test_new", + ) + + # Case 3: Neither legacy nor new collection exists + mock_qdrant_client.collection_exists.return_value = False + + # Initialize storage + await storage.initialize() + + # Verify: Should create new collection with model suffix + expected_collection = "lightrag_vdb_chunks_text_embedding_3_large_3072d" + assert storage.final_namespace == expected_collection + + # Verify create_collection was called with correct name + create_calls = [ + call for call in mock_qdrant_client.create_collection.call_args_list + ] + assert len(create_calls) > 0 + assert ( + create_calls[0][0][0] == expected_collection + or create_calls[0].kwargs.get("collection_name") == expected_collection + ) + + # Verify no migration was attempted + mock_qdrant_client.scroll.assert_not_called() + + print( + f"✅ Scenario 1: New workspace created with collection '{expected_collection}'" + ) + + +@pytest.mark.asyncio +async def test_scenario_2_legacy_upgrade_migration( + mock_qdrant_client, mock_embedding_func +): + """ + 场景2:从旧版本升级 + 已存在lightrag_vdb_chunks(无后缀) + 预期:自动迁移数据到lightrag_vdb_chunks_text_embedding_ada_002_1536d + """ + # Use ada-002 model + ada_func = EmbeddingFunc( + embedding_dim=1536, + func=mock_embedding_func.func, + model_name="text-embedding-ada-002", + ) + + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + storage = QdrantVectorDBStorage( + namespace="chunks", + global_config=config, + embedding_func=ada_func, + workspace="test_legacy", + ) + + legacy_collection = storage.legacy_namespace + new_collection = storage.final_namespace + + # Case 4: Only legacy collection exists + mock_qdrant_client.collection_exists.side_effect = ( + lambda name: name == legacy_collection + ) + + # Mock legacy collection info with 1536d vectors + legacy_collection_info = MagicMock() + 
legacy_collection_info.payload_schema = {} + legacy_collection_info.config.params.vectors.size = 1536 + mock_qdrant_client.get_collection.return_value = legacy_collection_info + + # Mock legacy data + mock_qdrant_client.count.return_value.count = 150 + + # Mock scroll results (simulate migration in batches) + + mock_points = [] + for i in range(10): + point = MagicMock() + point.id = f"legacy-{i}" + point.vector = [0.1] * 1536 + point.payload = {"content": f"Legacy document {i}", "id": f"doc-{i}"} + mock_points.append(point) + + # First batch returns points, second batch returns empty + mock_qdrant_client.scroll.side_effect = [(mock_points, "offset1"), ([], None)] + + # Initialize (triggers migration) + await storage.initialize() + + # Verify: New collection should be created + expected_new_collection = "lightrag_vdb_chunks_text_embedding_ada_002_1536d" + assert storage.final_namespace == expected_new_collection + + # Verify migration steps + # 1. Check legacy count + mock_qdrant_client.count.assert_any_call( + collection_name=legacy_collection, exact=True + ) + + # 2. Create new collection + mock_qdrant_client.create_collection.assert_called() + + # 3. Scroll legacy data + scroll_calls = [call for call in mock_qdrant_client.scroll.call_args_list] + assert len(scroll_calls) >= 1 + assert scroll_calls[0].kwargs["collection_name"] == legacy_collection + + # 4. Upsert to new collection + upsert_calls = [call for call in mock_qdrant_client.upsert.call_args_list] + assert len(upsert_calls) >= 1 + assert upsert_calls[0].kwargs["collection_name"] == new_collection + + # 5. 
Verify legacy collection was automatically deleted after successful migration + # This prevents Case 1 warnings on next startup + delete_calls = [ + call for call in mock_qdrant_client.delete_collection.call_args_list + ] + assert ( + len(delete_calls) >= 1 + ), "Legacy collection should be deleted after successful migration" + # Check if legacy_collection was passed to delete_collection + deleted_collection = ( + delete_calls[0][0][0] + if delete_calls[0][0] + else delete_calls[0].kwargs.get("collection_name") + ) + assert ( + deleted_collection == legacy_collection + ), f"Expected to delete '{legacy_collection}', but deleted '{deleted_collection}'" + + print( + f"✅ Scenario 2: Legacy data migrated from '{legacy_collection}' to '{expected_new_collection}' and legacy collection deleted" + ) + + +@pytest.mark.asyncio +async def test_scenario_3_multi_model_coexistence(mock_qdrant_client): + """ + 场景3:多模型并存 + 预期:两个独立的collection,互不干扰 + """ + + # Model A: bge-small with 768d + async def embed_func_a(texts, **kwargs): + return np.array([[0.1] * 768 for _ in texts]) + + model_a_func = EmbeddingFunc( + embedding_dim=768, func=embed_func_a, model_name="bge-small" + ) + + # Model B: bge-large with 1024d + async def embed_func_b(texts, **kwargs): + return np.array([[0.2] * 1024 for _ in texts]) + + model_b_func = EmbeddingFunc( + embedding_dim=1024, func=embed_func_b, model_name="bge-large" + ) + + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + # Create storage for workspace A with model A + storage_a = QdrantVectorDBStorage( + namespace="chunks", + global_config=config, + embedding_func=model_a_func, + workspace="workspace_a", + ) + + # Create storage for workspace B with model B + storage_b = QdrantVectorDBStorage( + namespace="chunks", + global_config=config, + embedding_func=model_b_func, + workspace="workspace_b", + ) + + # Verify: Collection names are different + assert storage_a.final_namespace 
!= storage_b.final_namespace + + # Verify: Model A collection + expected_collection_a = "lightrag_vdb_chunks_bge_small_768d" + assert storage_a.final_namespace == expected_collection_a + + # Verify: Model B collection + expected_collection_b = "lightrag_vdb_chunks_bge_large_1024d" + assert storage_b.final_namespace == expected_collection_b + + # Verify: Different embedding dimensions are preserved + assert storage_a.embedding_func.embedding_dim == 768 + assert storage_b.embedding_func.embedding_dim == 1024 + + print("✅ Scenario 3: Multi-model coexistence verified") + print(f" - Workspace A: {expected_collection_a} (768d)") + print(f" - Workspace B: {expected_collection_b} (1024d)") + print(" - Collections are independent") + + +@pytest.mark.asyncio +async def test_case1_empty_legacy_auto_cleanup(mock_qdrant_client, mock_embedding_func): + """ + Case 1a: 新旧collection都存在,且旧库为空 + 预期:自动删除旧库 + """ + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + storage = QdrantVectorDBStorage( + namespace="chunks", + global_config=config, + embedding_func=mock_embedding_func, + workspace="test_ws", + ) + + legacy_collection = storage.legacy_namespace + new_collection = storage.final_namespace + + # Mock: Both collections exist + mock_qdrant_client.collection_exists.side_effect = lambda name: name in [ + legacy_collection, + new_collection, + ] + + # Mock: Legacy collection is empty (0 records) + def count_mock(collection_name, exact=True): + mock_result = MagicMock() + if collection_name == legacy_collection: + mock_result.count = 0 # Empty legacy collection + else: + mock_result.count = 100 # New collection has data + return mock_result + + mock_qdrant_client.count.side_effect = count_mock + + # Mock get_collection for Case 2 check + collection_info = MagicMock() + collection_info.payload_schema = {"workspace_id": True} + mock_qdrant_client.get_collection.return_value = collection_info + + # Initialize storage + 
await storage.initialize() + + # Verify: Empty legacy collection should be automatically cleaned up + # Empty collections are safe to delete without data loss risk + delete_calls = [ + call for call in mock_qdrant_client.delete_collection.call_args_list + ] + assert len(delete_calls) >= 1, "Empty legacy collection should be auto-deleted" + deleted_collection = ( + delete_calls[0][0][0] + if delete_calls[0][0] + else delete_calls[0].kwargs.get("collection_name") + ) + assert ( + deleted_collection == legacy_collection + ), f"Expected to delete '{legacy_collection}', but deleted '{deleted_collection}'" + + print( + f"✅ Case 1a: Empty legacy collection '{legacy_collection}' auto-deleted successfully" + ) + + +@pytest.mark.asyncio +async def test_case1_nonempty_legacy_warning(mock_qdrant_client, mock_embedding_func): + """ + Case 1b: 新旧collection都存在,且旧库有数据 + 预期:警告但不删除 + """ + config = { + "embedding_batch_num": 10, + "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.8}, + } + + storage = QdrantVectorDBStorage( + namespace="chunks", + global_config=config, + embedding_func=mock_embedding_func, + workspace="test_ws", + ) + + legacy_collection = storage.legacy_namespace + new_collection = storage.final_namespace + + # Mock: Both collections exist + mock_qdrant_client.collection_exists.side_effect = lambda name: name in [ + legacy_collection, + new_collection, + ] + + # Mock: Legacy collection has data (50 records) + def count_mock(collection_name, exact=True): + mock_result = MagicMock() + if collection_name == legacy_collection: + mock_result.count = 50 # Legacy has data + else: + mock_result.count = 100 # New collection has data + return mock_result + + mock_qdrant_client.count.side_effect = count_mock + + # Mock get_collection for Case 2 check + collection_info = MagicMock() + collection_info.payload_schema = {"workspace_id": True} + mock_qdrant_client.get_collection.return_value = collection_info + + # Initialize storage + await storage.initialize() + 
+ # Verify: Legacy collection with data should be preserved + # We never auto-delete collections that contain data to prevent accidental data loss + delete_calls = [ + call for call in mock_qdrant_client.delete_collection.call_args_list + ] + # Check if legacy collection was deleted (it should not be) + legacy_deleted = any( + (call[0][0] if call[0] else call.kwargs.get("collection_name")) + == legacy_collection + for call in delete_calls + ) + assert not legacy_deleted, "Legacy collection with data should NOT be auto-deleted" + + print( + f"✅ Case 1b: Legacy collection '{legacy_collection}' with data preserved (warning only)" + ) diff --git a/tests/test_unified_lock_safety.py b/tests/test_unified_lock_safety.py new file mode 100644 index 0000000000..41d2ec19a9 --- /dev/null +++ b/tests/test_unified_lock_safety.py @@ -0,0 +1,191 @@ +""" +Tests for UnifiedLock safety when lock is None. + +This test module verifies that UnifiedLock raises RuntimeError instead of +allowing unprotected execution when the underlying lock is None, preventing +false security and potential race conditions. + +Critical Bug 1: When self._lock is None, __aenter__ used to log WARNING but +still return successfully, allowing critical sections to run without lock +protection, causing race conditions and data corruption. + +Critical Bug 2: In __aexit__, when async_lock.release() fails, the error +recovery logic would attempt to release it again, causing double-release issues. +""" + +import pytest +from unittest.mock import MagicMock, AsyncMock +from lightrag.kg.shared_storage import UnifiedLock + + +class TestUnifiedLockSafety: + """Test suite for UnifiedLock None safety checks.""" + + @pytest.mark.asyncio + async def test_unified_lock_raises_on_none_async(self): + """ + Test that UnifiedLock raises RuntimeError when lock is None (async mode). + + Scenario: Attempt to use UnifiedLock before initialize_share_data() is called. 
+ Expected: RuntimeError raised, preventing unprotected critical section execution. + """ + lock = UnifiedLock( + lock=None, is_async=True, name="test_async_lock", enable_logging=False + ) + + with pytest.raises( + RuntimeError, match="shared data not initialized|Lock.*is None" + ): + async with lock: + # This code should NEVER execute + pytest.fail( + "Code inside lock context should not execute when lock is None" + ) + + @pytest.mark.asyncio + async def test_unified_lock_raises_on_none_sync(self): + """ + Test that UnifiedLock raises RuntimeError when lock is None (sync mode). + + Scenario: Attempt to use UnifiedLock with None lock in sync mode. + Expected: RuntimeError raised with clear error message. + """ + lock = UnifiedLock( + lock=None, is_async=False, name="test_sync_lock", enable_logging=False + ) + + with pytest.raises( + RuntimeError, match="shared data not initialized|Lock.*is None" + ): + async with lock: + # This code should NEVER execute + pytest.fail( + "Code inside lock context should not execute when lock is None" + ) + + @pytest.mark.asyncio + async def test_error_message_clarity(self): + """ + Test that the error message clearly indicates the problem and solution. + + Scenario: Lock is None and user tries to acquire it. + Expected: Error message mentions 'shared data not initialized' and + 'initialize_share_data()'. 
+ """ + lock = UnifiedLock( + lock=None, + is_async=True, + name="test_error_message", + enable_logging=False, + ) + + with pytest.raises(RuntimeError) as exc_info: + async with lock: + pass + + error_message = str(exc_info.value) + # Verify error message contains helpful information + assert ( + "shared data not initialized" in error_message.lower() + or "lock" in error_message.lower() + ) + assert "initialize_share_data" in error_message or "None" in error_message + + @pytest.mark.asyncio + async def test_aexit_no_double_release_on_async_lock_failure(self): + """ + Test that __aexit__ doesn't attempt to release async_lock twice when it fails. + + Scenario: async_lock.release() fails during normal release. + Expected: Recovery logic should NOT attempt to release async_lock again, + preventing double-release issues. + + This tests Bug 2 fix: async_lock_released tracking prevents double release. + """ + # Create mock locks + main_lock = MagicMock() + main_lock.acquire = MagicMock() + main_lock.release = MagicMock() + + async_lock = AsyncMock() + async_lock.acquire = AsyncMock() + + # Make async_lock.release() fail + release_call_count = 0 + + def mock_release_fail(): + nonlocal release_call_count + release_call_count += 1 + raise RuntimeError("Async lock release failed") + + async_lock.release = MagicMock(side_effect=mock_release_fail) + + # Create UnifiedLock with both locks (sync mode with async_lock) + lock = UnifiedLock( + lock=main_lock, + is_async=False, + name="test_double_release", + enable_logging=False, + ) + lock._async_lock = async_lock + + # Try to use the lock - should fail during __aexit__ + try: + async with lock: + pass + except RuntimeError as e: + # Should get the async lock release error + assert "Async lock release failed" in str(e) + + # Verify async_lock.release() was called only ONCE, not twice + assert ( + release_call_count == 1 + ), f"async_lock.release() should be called only once, but was called {release_call_count} times" + + # Main 
lock should have been released successfully + main_lock.release.assert_called_once() + + @pytest.mark.asyncio + async def test_aexit_recovery_on_main_lock_failure(self): + """ + Test that __aexit__ recovery logic works when main lock release fails. + + Scenario: main_lock.release() fails before async_lock is attempted. + Expected: Recovery logic should attempt to release async_lock to prevent + resource leaks. + + This verifies the recovery logic still works correctly with async_lock_released tracking. + """ + # Create mock locks + main_lock = MagicMock() + main_lock.acquire = MagicMock() + + # Make main_lock.release() fail + def mock_main_release_fail(): + raise RuntimeError("Main lock release failed") + + main_lock.release = MagicMock(side_effect=mock_main_release_fail) + + async_lock = AsyncMock() + async_lock.acquire = AsyncMock() + async_lock.release = MagicMock() + + # Create UnifiedLock with both locks (sync mode with async_lock) + lock = UnifiedLock( + lock=main_lock, is_async=False, name="test_recovery", enable_logging=False + ) + lock._async_lock = async_lock + + # Try to use the lock - should fail during __aexit__ + try: + async with lock: + pass + except RuntimeError as e: + # Should get the main lock release error + assert "Main lock release failed" in str(e) + + # Main lock release should have been attempted + main_lock.release.assert_called_once() + + # Recovery logic should have attempted to release async_lock + async_lock.release.assert_called_once() diff --git a/tests/test_workspace_migration_isolation.py b/tests/test_workspace_migration_isolation.py new file mode 100644 index 0000000000..07b8920cd0 --- /dev/null +++ b/tests/test_workspace_migration_isolation.py @@ -0,0 +1,308 @@ +""" +Tests for workspace isolation during PostgreSQL migration. + +This test module verifies that setup_table() properly filters migration data +by workspace, preventing cross-workspace data leakage during legacy table migration. 
+ +Critical Bug: Migration copied ALL records from legacy table regardless of workspace, +causing workspace A to receive workspace B's data, violating multi-tenant isolation. +""" + +import pytest +from unittest.mock import AsyncMock + +from lightrag.kg.postgres_impl import PGVectorStorage + + +class TestWorkspaceMigrationIsolation: + """Test suite for workspace-scoped migration in PostgreSQL.""" + + @pytest.mark.asyncio + async def test_migration_filters_by_workspace(self): + """ + Test that migration only copies data from the specified workspace. + + Scenario: Legacy table contains data from multiple workspaces. + Migrate only workspace_a's data to new table. + Expected: New table contains only workspace_a data, workspace_b data excluded. + """ + db = AsyncMock() + + # Mock table existence checks + async def table_exists_side_effect(db_instance, name): + if name == "lightrag_doc_chunks": # legacy + return True + elif name == "lightrag_doc_chunks_model_1536d": # new + return False + return False + + # Mock query responses + async def query_side_effect(sql, params, **kwargs): + multirows = kwargs.get("multirows", False) + + # Table existence check + if "information_schema.tables" in sql: + if params[0] == "lightrag_doc_chunks": + return {"exists": True} + elif params[0] == "lightrag_doc_chunks_model_1536d": + return {"exists": False} + + # Count query with workspace filter (legacy table) + elif "COUNT(*)" in sql and "WHERE workspace" in sql: + if params[0] == "workspace_a": + return {"count": 2} # workspace_a has 2 records + elif params[0] == "workspace_b": + return {"count": 3} # workspace_b has 3 records + return {"count": 0} + + # Count query for new table (verification) + elif "COUNT(*)" in sql and "lightrag_doc_chunks_model_1536d" in sql: + return {"count": 2} # Verification: 2 records migrated + + # Count query for legacy table (no filter) + elif "COUNT(*)" in sql and "lightrag_doc_chunks" in sql: + return {"count": 5} # Total records in legacy + + # 
Dimension check + elif "pg_attribute" in sql: + return {"vector_dim": 1536} + + # SELECT with workspace filter + elif "SELECT * FROM" in sql and "WHERE workspace" in sql and multirows: + workspace = params[0] + if workspace == "workspace_a" and params[1] == 0: # offset = 0 + # Return only workspace_a data + return [ + { + "id": "a1", + "workspace": "workspace_a", + "content": "content_a1", + "content_vector": [0.1] * 1536, + }, + { + "id": "a2", + "workspace": "workspace_a", + "content": "content_a2", + "content_vector": [0.2] * 1536, + }, + ] + else: + return [] # No more data + + return {} + + db.query.side_effect = query_side_effect + db.execute = AsyncMock() + db._create_vector_index = AsyncMock() + + # Mock _pg_table_exists and _pg_create_table + from unittest.mock import patch + + with ( + patch( + "lightrag.kg.postgres_impl._pg_table_exists", + side_effect=table_exists_side_effect, + ), + patch("lightrag.kg.postgres_impl._pg_create_table", new=AsyncMock()), + ): + # Migrate for workspace_a only + await PGVectorStorage.setup_table( + db, + "lightrag_doc_chunks_model_1536d", + legacy_table_name="lightrag_doc_chunks", + base_table="lightrag_doc_chunks", + embedding_dim=1536, + workspace="workspace_a", # CRITICAL: Only migrate workspace_a + ) + + # Verify workspace filter was used in queries + count_calls = [ + call + for call in db.query.call_args_list + if call[0][0] + and "COUNT(*)" in call[0][0] + and "WHERE workspace" in call[0][0] + ] + assert len(count_calls) > 0, "Count query should use workspace filter" + assert ( + count_calls[0][0][1][0] == "workspace_a" + ), "Count should filter by workspace_a" + + select_calls = [ + call + for call in db.query.call_args_list + if call[0][0] + and "SELECT * FROM" in call[0][0] + and "WHERE workspace" in call[0][0] + ] + assert len(select_calls) > 0, "Select query should use workspace filter" + assert ( + select_calls[0][0][1][0] == "workspace_a" + ), "Select should filter by workspace_a" + + # Verify INSERT was 
called (migration happened) + insert_calls = [ + call + for call in db.execute.call_args_list + if call[0][0] and "INSERT INTO" in call[0][0] + ] + assert len(insert_calls) == 2, "Should insert 2 records from workspace_a" + + @pytest.mark.asyncio + async def test_migration_without_workspace_warns(self): + """ + Test that migration without workspace parameter logs a warning. + + Scenario: setup_table called without workspace parameter. + Expected: Warning logged about potential cross-workspace data copying. + """ + db = AsyncMock() + + async def table_exists_side_effect(db_instance, name): + if name == "lightrag_doc_chunks": + return True + elif name == "lightrag_doc_chunks_model_1536d": + return False + return False + + async def query_side_effect(sql, params, **kwargs): + if "information_schema.tables" in sql: + return {"exists": params[0] == "lightrag_doc_chunks"} + elif "COUNT(*)" in sql: + return {"count": 5} # 5 records total + elif "pg_attribute" in sql: + return {"vector_dim": 1536} + elif "SELECT * FROM" in sql and kwargs.get("multirows"): + if params[0] == 0: # offset = 0 + return [ + { + "id": "1", + "workspace": "workspace_a", + "content_vector": [0.1] * 1536, + }, + { + "id": "2", + "workspace": "workspace_b", + "content_vector": [0.2] * 1536, + }, + ] + else: + return [] + return {} + + db.query.side_effect = query_side_effect + db.execute = AsyncMock() + db._create_vector_index = AsyncMock() + + from unittest.mock import patch + + with ( + patch( + "lightrag.kg.postgres_impl._pg_table_exists", + side_effect=table_exists_side_effect, + ), + patch("lightrag.kg.postgres_impl._pg_create_table", new=AsyncMock()), + ): + # Migrate WITHOUT workspace parameter (dangerous!) + await PGVectorStorage.setup_table( + db, + "lightrag_doc_chunks_model_1536d", + legacy_table_name="lightrag_doc_chunks", + base_table="lightrag_doc_chunks", + embedding_dim=1536, + workspace=None, # No workspace filter! 
+ ) + + # Verify queries do NOT use workspace filter + count_calls = [ + call + for call in db.query.call_args_list + if call[0][0] and "COUNT(*)" in call[0][0] + ] + assert len(count_calls) > 0, "Count query should be executed" + # Check that workspace filter was NOT used + has_workspace_filter = any( + "WHERE workspace" in call[0][0] for call in count_calls + ) + assert ( + not has_workspace_filter + ), "Count should NOT filter by workspace when workspace=None" + + @pytest.mark.asyncio + async def test_no_cross_workspace_contamination(self): + """ + Test that workspace B's migration doesn't include workspace A's data. + + Scenario: Two separate migrations for workspace_a and workspace_b. + Expected: Each workspace only gets its own data. + """ + db = AsyncMock() + + # Track which workspace is being queried + queried_workspace = None + + async def table_exists_side_effect(db_instance, name): + return "lightrag_doc_chunks" in name and "model" not in name + + async def query_side_effect(sql, params, **kwargs): + nonlocal queried_workspace + multirows = kwargs.get("multirows", False) + + if "information_schema.tables" in sql: + return {"exists": "lightrag_doc_chunks" in params[0]} + elif "COUNT(*)" in sql and "WHERE workspace" in sql: + queried_workspace = params[0] + return {"count": 1} + elif "COUNT(*)" in sql and "lightrag_doc_chunks_model_1536d" in sql: + return {"count": 1} # Verification count + elif "pg_attribute" in sql: + return {"vector_dim": 1536} + elif "SELECT * FROM" in sql and "WHERE workspace" in sql and multirows: + workspace = params[0] + if params[1] == 0: # offset = 0 + # Return data ONLY for the queried workspace + return [ + { + "id": f"{workspace}_1", + "workspace": workspace, + "content": f"content_{workspace}", + "content_vector": [0.1] * 1536, + } + ] + else: + return [] + return {} + + db.query.side_effect = query_side_effect + db.execute = AsyncMock() + db._create_vector_index = AsyncMock() + + from unittest.mock import patch + + with ( + 
patch( + "lightrag.kg.postgres_impl._pg_table_exists", + side_effect=table_exists_side_effect, + ), + patch("lightrag.kg.postgres_impl._pg_create_table", new=AsyncMock()), + ): + # Migrate workspace_b + await PGVectorStorage.setup_table( + db, + "lightrag_doc_chunks_model_1536d", + legacy_table_name="lightrag_doc_chunks", + base_table="lightrag_doc_chunks", + embedding_dim=1536, + workspace="workspace_b", + ) + + # Verify only workspace_b was queried + assert queried_workspace == "workspace_b", "Should only query workspace_b" + + # Verify INSERT contains workspace_b data only + insert_calls = [ + call + for call in db.execute.call_args_list + if call[0][0] and "INSERT INTO" in call[0][0] + ] + assert len(insert_calls) > 0, "Should have INSERT calls"