Skip to content

Commit 9eee28c

Browse files
cleanup
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
1 parent 4d8a34d commit 9eee28c

3 files changed

Lines changed: 176 additions & 58 deletions

File tree

src/llama_stack/core/routers/vector_io.py

Lines changed: 78 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,67 @@ async def insert_chunks(
135135
)
136136
return await self.routing_table.insert_chunks(vector_store_id, chunks, ttl_seconds)
137137

138+
def _parse_filter(self, filter_data: Any) -> Any:
139+
"""Recursively parse filter data into typed Filter objects.
140+
141+
This method handles the conversion of dictionary-based filter data into proper
142+
typed Filter objects (ComparisonFilter or CompoundFilter), supporting arbitrary
143+
nesting levels and pre-typed filter objects.
144+
145+
Args:
146+
filter_data: Filter data as dict, ComparisonFilter, or CompoundFilter
147+
148+
Returns:
149+
Typed filter object (ComparisonFilter or CompoundFilter)
150+
151+
Raises:
152+
ValueError: If filter data is invalid or has unsupported structure
153+
"""
154+
from llama_stack_api.filters import (
155+
ALL_FILTER_TYPES,
156+
COMPARISON_FILTER_TYPES,
157+
COMPOUND_FILTER_TYPES,
158+
ComparisonFilter,
159+
CompoundFilter,
160+
)
161+
162+
# Handle pre-typed filter objects - pass through unchanged
163+
if isinstance(filter_data, ComparisonFilter | CompoundFilter):
164+
return filter_data
165+
166+
# Validate that filter_data is a dictionary
167+
if not isinstance(filter_data, dict):
168+
raise ValueError("Filter must be a dict or typed Filter object")
169+
170+
filter_type = filter_data.get("type")
171+
if not filter_type:
172+
raise ValueError("Filter must have a 'type' field")
173+
174+
# Handle comparison filters
175+
if filter_type in COMPARISON_FILTER_TYPES:
176+
# Validate required fields for comparison filters
177+
if "key" not in filter_data or "value" not in filter_data:
178+
raise ValueError(f"Comparison filter '{filter_type}' must have 'key' and 'value' fields")
179+
return ComparisonFilter(**filter_data)
180+
181+
# Handle compound filters
182+
elif filter_type in COMPOUND_FILTER_TYPES:
183+
sub_filters_data = filter_data.get("filters", [])
184+
if not isinstance(sub_filters_data, list):
185+
raise ValueError(f"Compound filter '{filter_type}' must have a 'filters' list")
186+
187+
# Recursively parse all sub-filters
188+
parsed_sub_filters = []
189+
for sub_filter in sub_filters_data:
190+
parsed_sub_filters.append(self._parse_filter(sub_filter))
191+
192+
return CompoundFilter(type=filter_type, filters=parsed_sub_filters)
193+
194+
else:
195+
# Use the constants in error message for consistency
196+
supported_types = ", ".join(sorted(ALL_FILTER_TYPES))
197+
raise ValueError(f"Invalid filter type: '{filter_type}'. Supported types: {supported_types}")
198+
138199
async def query_chunks(
139200
self,
140201
vector_store_id: str,
@@ -143,26 +204,23 @@ async def query_chunks(
143204
) -> QueryChunksResponse:
144205
logger.debug(f"VectorIORouter.query_chunks: {vector_store_id}")
145206

146-
# Extract filters from params for internal processing
147-
filters = None
148-
if params and "filters" in params:
149-
from llama_stack_api.filters import ComparisonFilter, CompoundFilter
150-
151-
filter_data = params.pop("filters") # Remove from params to avoid duplication
152-
if isinstance(filter_data, dict):
153-
# Convert dict to typed filter
154-
if filter_data.get("type") in ["eq", "ne", "gt", "gte", "lt", "lte", "in", "nin"]:
155-
filters = ComparisonFilter(**filter_data)
156-
elif filter_data.get("type") in ["and", "or"]:
157-
# Handle nested filters recursively
158-
converted_sub_filters = []
159-
for sub_filter in filter_data.get("filters", []):
160-
if sub_filter.get("type") in ["eq", "ne", "gt", "gte", "lt", "lte", "in", "nin"]:
161-
converted_sub_filters.append(ComparisonFilter(**sub_filter))
162-
# Add more nesting support if needed
163-
filters = CompoundFilter(type=filter_data["type"], filters=converted_sub_filters)
164-
165-
return await self.routing_table.query_chunks(vector_store_id, query, params, filters)
207+
# Handle the no-filters case early
208+
if not params or "filters" not in params:
209+
return await self.routing_table.query_chunks(vector_store_id, query, params, None)
210+
211+
# Extract and parse filters from params
212+
# Create a shallow copy to avoid mutating the caller's dictionary
213+
params_copy = dict(params)
214+
filter_data = params_copy.pop("filters")
215+
216+
try:
217+
# Parse filter data with full recursive support
218+
parsed_filters = self._parse_filter(filter_data)
219+
except ValueError as e:
220+
logger.error(f"Invalid filter data: {e}")
221+
raise ValueError(f"Invalid filter: {e}") from e
222+
223+
return await self.routing_table.query_chunks(vector_store_id, query, params_copy, parsed_filters)
166224

167225
# OpenAI Vector Stores API endpoints
168226
async def openai_create_vector_store(

src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py

Lines changed: 93 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import time
1111
import uuid
1212
from abc import ABC, abstractmethod
13+
from collections.abc import Callable
1314
from typing import Annotated, Any
1415

1516
from fastapi import Body
@@ -63,12 +64,28 @@
6364
RetrieveFileContentRequest,
6465
RetrieveFileRequest,
6566
)
67+
from llama_stack_api.filters import (
68+
COMPARISON_FILTER_TYPES,
69+
COMPOUND_FILTER_TYPES,
70+
)
6671
from llama_stack_api.internal.kvstore import KVStore
6772

6873
EMBEDDING_DIMENSION = 768
6974

7075
logger = get_logger(name=__name__, category="providers::utils")
7176

77+
# Comparison operators for filter matching (dispatch table)
78+
COMPARISON_OPERATORS: dict[str, Callable[[Any, Any], bool]] = {
79+
"eq": lambda mv, fv: bool(mv == fv),
80+
"ne": lambda mv, fv: bool(mv != fv),
81+
"gt": lambda mv, fv: bool(mv > fv),
82+
"gte": lambda mv, fv: bool(mv >= fv),
83+
"lt": lambda mv, fv: bool(mv < fv),
84+
"lte": lambda mv, fv: bool(mv <= fv),
85+
"in": lambda mv, fv: bool(isinstance(fv, list) and mv in fv),
86+
"nin": lambda mv, fv: bool(isinstance(fv, list) and mv not in fv),
87+
}
88+
7289
# Constants for OpenAI vector stores
7390

7491
VERSION = "v3"
@@ -803,64 +820,102 @@ def _build_reranker_params(
803820

804821
return params
805822

823+
def _matches_comparison_filter(self, metadata_value: Any, filter_type: str, filter_value: Any) -> bool:
824+
"""Check if a metadata value matches a comparison filter.
825+
826+
Args:
827+
metadata_value: The value from the metadata to test
828+
filter_type: The comparison operator (eq, ne, gt, etc.)
829+
filter_value: The value to compare against
830+
831+
Returns:
832+
bool: True if the comparison matches, False otherwise
833+
"""
834+
if filter_type not in COMPARISON_OPERATORS:
835+
raise ValueError(f"Unsupported comparison filter type: {filter_type}")
836+
837+
return COMPARISON_OPERATORS[filter_type](metadata_value, filter_value)
838+
839+
def _matches_legacy_filter(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
840+
"""Handle legacy filter format (direct key-value pairs without type field).
841+
842+
Args:
843+
metadata: The metadata to test against
844+
filters: Dict of key-value pairs to match
845+
846+
Returns:
847+
bool: True if all key-value pairs match the metadata
848+
"""
849+
for key, value in filters.items():
850+
if key not in metadata or metadata[key] != value:
851+
return False
852+
return True
853+
854+
def _matches_compound_filter(
855+
self, metadata: dict[str, Any], filter_type: str, sub_filters: list[dict[str, Any]]
856+
) -> bool:
857+
"""Handle compound filters (and/or logic).
858+
859+
Args:
860+
metadata: The metadata to test against
861+
filter_type: Either "and" or "or"
862+
sub_filters: List of filters to combine
863+
864+
Returns:
865+
bool: True if the compound filter matches
866+
"""
867+
if filter_type == "and":
868+
return all(self._matches_filters(metadata, f) for f in sub_filters)
869+
elif filter_type == "or":
870+
return any(self._matches_filters(metadata, f) for f in sub_filters)
871+
else:
872+
raise ValueError(f"Unsupported compound filter type: {filter_type}")
873+
806874
def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
807-
"""Check if metadata matches the provided filters."""
875+
"""Check if metadata matches the provided filters.
876+
877+
This method supports multiple filter formats:
878+
- OpenAI-style typed filters with "type" field
879+
- Legacy direct key-value filters
880+
- Compound filters with nested logic
881+
882+
Args:
883+
metadata: The metadata to test against
884+
filters: The filter specification
885+
886+
Returns:
887+
bool: True if the metadata matches the filter criteria
888+
"""
808889
if not filters:
809890
return True
810891

811892
filter_type = filters.get("type")
812893

894+
# Handle legacy format (no type field)
813895
if filter_type is None:
814896
if "key" not in filters and "value" not in filters and "filters" not in filters:
815-
for key, value in filters.items():
816-
if key not in metadata:
817-
return False
818-
if metadata[key] != value:
819-
return False
820-
return True
897+
return self._matches_legacy_filter(metadata, filters)
821898
else:
822899
raise ValueError("Unsupported filter structure: missing 'type' field")
823900

824-
if filter_type in ["eq", "ne", "gt", "gte", "lt", "lte"]:
825-
# Comparison filter
901+
# Handle comparison filters
902+
if filter_type in COMPARISON_FILTER_TYPES:
826903
filter_key = filters.get("key")
827-
value = filters.get("value")
904+
filter_value = filters.get("value")
828905

829-
if filter_key is None or not isinstance(filter_key, str):
830-
return False
831-
832-
if filter_key not in metadata:
906+
# Validate filter structure
907+
if not isinstance(filter_key, str) or filter_key not in metadata:
833908
return False
834909

835910
metadata_value = metadata[filter_key]
911+
return self._matches_comparison_filter(metadata_value, filter_type, filter_value)
836912

837-
if filter_type == "eq":
838-
return bool(metadata_value == value)
839-
elif filter_type == "ne":
840-
return bool(metadata_value != value)
841-
elif filter_type == "gt":
842-
return bool(metadata_value > value)
843-
elif filter_type == "gte":
844-
return bool(metadata_value >= value)
845-
elif filter_type == "lt":
846-
return bool(metadata_value < value)
847-
elif filter_type == "lte":
848-
return bool(metadata_value <= value)
849-
else:
850-
raise ValueError(f"Unsupported filter type: {filter_type}")
851-
852-
elif filter_type == "and":
853-
# All filters must match
913+
# Handle compound filters
914+
elif filter_type in COMPOUND_FILTER_TYPES:
854915
sub_filters = filters.get("filters", [])
855-
return all(self._matches_filters(metadata, f) for f in sub_filters)
856-
857-
elif filter_type == "or":
858-
# At least one filter must match
859-
sub_filters = filters.get("filters", [])
860-
return any(self._matches_filters(metadata, f) for f in sub_filters)
916+
return self._matches_compound_filter(metadata, filter_type, sub_filters)
861917

862918
else:
863-
# Unknown filter type, default to no match
864919
raise ValueError(f"Unsupported filter type: {filter_type}")
865920

866921
def _chunk_to_vector_store_content(

src/llama_stack_api/filters.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@
2020

2121
from .schema_utils import json_schema_type
2222

23+
# Constants for filter type validation (using sets for O(1) membership testing)
24+
COMPARISON_FILTER_TYPES = frozenset(["eq", "ne", "gt", "gte", "lt", "lte", "in", "nin"])
25+
COMPOUND_FILTER_TYPES = frozenset(["and", "or"])
26+
ALL_FILTER_TYPES = COMPARISON_FILTER_TYPES | COMPOUND_FILTER_TYPES
27+
2328

2429
@json_schema_type
2530
class ComparisonFilter(BaseModel):

0 commit comments

Comments
 (0)