|
10 | 10 | import time |
11 | 11 | import uuid |
12 | 12 | from abc import ABC, abstractmethod |
| 13 | +from collections.abc import Callable |
13 | 14 | from typing import Annotated, Any |
14 | 15 |
|
15 | 16 | from fastapi import Body |
|
63 | 64 | RetrieveFileContentRequest, |
64 | 65 | RetrieveFileRequest, |
65 | 66 | ) |
| 67 | +from llama_stack_api.filters import ( |
| 68 | + COMPARISON_FILTER_TYPES, |
| 69 | + COMPOUND_FILTER_TYPES, |
| 70 | +) |
66 | 71 | from llama_stack_api.internal.kvstore import KVStore |
67 | 72 |
|
68 | 73 | EMBEDDING_DIMENSION = 768 |
69 | 74 |
|
70 | 75 | logger = get_logger(name=__name__, category="providers::utils") |
71 | 76 |
|
def _tolerant_ordering(compare: Callable[[Any, Any], Any]) -> Callable[[Any, Any], bool]:
    """Wrap an ordering comparison so that incomparable operands never match.

    Python raises ``TypeError`` when two values cannot be ordered (for example
    ``None`` or a string compared against a number). Metadata values are
    arbitrary, so such a pair is reported as "no match" instead of letting the
    exception abort the whole filter evaluation.

    Args:
        compare: Two-argument callable performing the raw ordering comparison.

    Returns:
        A callable with the same arguments that always returns a plain bool.
    """

    def _compare(mv: Any, fv: Any) -> bool:
        try:
            return bool(compare(mv, fv))
        except TypeError:
            return False

    return _compare


# Comparison operators for filter matching (dispatch table).
# Every entry takes (metadata_value, filter_value) and returns a plain bool.
COMPARISON_OPERATORS: dict[str, Callable[[Any, Any], bool]] = {
    "eq": lambda mv, fv: bool(mv == fv),
    "ne": lambda mv, fv: bool(mv != fv),
    # Ordering operators tolerate operands that cannot be ordered against each other.
    "gt": _tolerant_ordering(lambda mv, fv: mv > fv),
    "gte": _tolerant_ordering(lambda mv, fv: mv >= fv),
    "lt": _tolerant_ordering(lambda mv, fv: mv < fv),
    "lte": _tolerant_ordering(lambda mv, fv: mv <= fv),
    # Membership operators only accept a list as the filter value; any other
    # filter value makes both "in" and "nin" evaluate to False.
    "in": lambda mv, fv: bool(isinstance(fv, list) and mv in fv),
    "nin": lambda mv, fv: bool(isinstance(fv, list) and mv not in fv),
}
| 88 | + |
72 | 89 | # Constants for OpenAI vector stores |
73 | 90 |
|
74 | 91 | VERSION = "v3" |
@@ -803,64 +820,102 @@ def _build_reranker_params( |
803 | 820 |
|
804 | 821 | return params |
805 | 822 |
|
def _matches_comparison_filter(self, metadata_value: Any, filter_type: str, filter_value: Any) -> bool:
    """Evaluate a single comparison operator against one metadata value.

    The operator is looked up by name in the module-level
    ``COMPARISON_OPERATORS`` dispatch table and called with the metadata
    value first and the filter value second.

    Args:
        metadata_value: The value taken from the metadata being tested
        filter_type: Name of the comparison operator (eq, ne, gt, etc.)
        filter_value: The value the metadata value is compared against

    Returns:
        bool: The result of the selected operator

    Raises:
        ValueError: If ``filter_type`` is not a key of ``COMPARISON_OPERATORS``
    """
    if filter_type in COMPARISON_OPERATORS:
        compare = COMPARISON_OPERATORS[filter_type]
        return compare(metadata_value, filter_value)

    raise ValueError(f"Unsupported comparison filter type: {filter_type}")
| 838 | + |
| 839 | + def _matches_legacy_filter(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool: |
| 840 | + """Handle legacy filter format (direct key-value pairs without type field). |
| 841 | +
|
| 842 | + Args: |
| 843 | + metadata: The metadata to test against |
| 844 | + filters: Dict of key-value pairs to match |
| 845 | +
|
| 846 | + Returns: |
| 847 | + bool: True if all key-value pairs match the metadata |
| 848 | + """ |
| 849 | + for key, value in filters.items(): |
| 850 | + if key not in metadata or metadata[key] != value: |
| 851 | + return False |
| 852 | + return True |
| 853 | + |
| 854 | + def _matches_compound_filter( |
| 855 | + self, metadata: dict[str, Any], filter_type: str, sub_filters: list[dict[str, Any]] |
| 856 | + ) -> bool: |
| 857 | + """Handle compound filters (and/or logic). |
| 858 | +
|
| 859 | + Args: |
| 860 | + metadata: The metadata to test against |
| 861 | + filter_type: Either "and" or "or" |
| 862 | + sub_filters: List of filters to combine |
| 863 | +
|
| 864 | + Returns: |
| 865 | + bool: True if the compound filter matches |
| 866 | + """ |
| 867 | + if filter_type == "and": |
| 868 | + return all(self._matches_filters(metadata, f) for f in sub_filters) |
| 869 | + elif filter_type == "or": |
| 870 | + return any(self._matches_filters(metadata, f) for f in sub_filters) |
| 871 | + else: |
| 872 | + raise ValueError(f"Unsupported compound filter type: {filter_type}") |
| 873 | + |
806 | 874 | def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool: |
807 | | - """Check if metadata matches the provided filters.""" |
| 875 | + """Check if metadata matches the provided filters. |
| 876 | +
|
| 877 | + This method supports multiple filter formats: |
| 878 | + - OpenAI-style typed filters with "type" field |
| 879 | + - Legacy direct key-value filters |
| 880 | + - Compound filters with nested logic |
| 881 | +
|
| 882 | + Args: |
| 883 | + metadata: The metadata to test against |
| 884 | + filters: The filter specification |
| 885 | +
|
| 886 | + Returns: |
| 887 | + bool: True if the metadata matches the filter criteria |
| 888 | + """ |
808 | 889 | if not filters: |
809 | 890 | return True |
810 | 891 |
|
811 | 892 | filter_type = filters.get("type") |
812 | 893 |
|
| 894 | + # Handle legacy format (no type field) |
813 | 895 | if filter_type is None: |
814 | 896 | if "key" not in filters and "value" not in filters and "filters" not in filters: |
815 | | - for key, value in filters.items(): |
816 | | - if key not in metadata: |
817 | | - return False |
818 | | - if metadata[key] != value: |
819 | | - return False |
820 | | - return True |
| 897 | + return self._matches_legacy_filter(metadata, filters) |
821 | 898 | else: |
822 | 899 | raise ValueError("Unsupported filter structure: missing 'type' field") |
823 | 900 |
|
824 | | - if filter_type in ["eq", "ne", "gt", "gte", "lt", "lte"]: |
825 | | - # Comparison filter |
| 901 | + # Handle comparison filters |
| 902 | + if filter_type in COMPARISON_FILTER_TYPES: |
826 | 903 | filter_key = filters.get("key") |
827 | | - value = filters.get("value") |
| 904 | + filter_value = filters.get("value") |
828 | 905 |
|
829 | | - if filter_key is None or not isinstance(filter_key, str): |
830 | | - return False |
831 | | - |
832 | | - if filter_key not in metadata: |
| 906 | + # Validate filter structure |
| 907 | + if not isinstance(filter_key, str) or filter_key not in metadata: |
833 | 908 | return False |
834 | 909 |
|
835 | 910 | metadata_value = metadata[filter_key] |
| 911 | + return self._matches_comparison_filter(metadata_value, filter_type, filter_value) |
836 | 912 |
|
837 | | - if filter_type == "eq": |
838 | | - return bool(metadata_value == value) |
839 | | - elif filter_type == "ne": |
840 | | - return bool(metadata_value != value) |
841 | | - elif filter_type == "gt": |
842 | | - return bool(metadata_value > value) |
843 | | - elif filter_type == "gte": |
844 | | - return bool(metadata_value >= value) |
845 | | - elif filter_type == "lt": |
846 | | - return bool(metadata_value < value) |
847 | | - elif filter_type == "lte": |
848 | | - return bool(metadata_value <= value) |
849 | | - else: |
850 | | - raise ValueError(f"Unsupported filter type: {filter_type}") |
851 | | - |
852 | | - elif filter_type == "and": |
853 | | - # All filters must match |
| 913 | + # Handle compound filters |
| 914 | + elif filter_type in COMPOUND_FILTER_TYPES: |
854 | 915 | sub_filters = filters.get("filters", []) |
855 | | - return all(self._matches_filters(metadata, f) for f in sub_filters) |
856 | | - |
857 | | - elif filter_type == "or": |
858 | | - # At least one filter must match |
859 | | - sub_filters = filters.get("filters", []) |
860 | | - return any(self._matches_filters(metadata, f) for f in sub_filters) |
| 916 | + return self._matches_compound_filter(metadata, filter_type, sub_filters) |
861 | 917 |
|
862 | 918 | else: |
863 | | - # Unknown filter type, default to no match |
864 | 919 | raise ValueError(f"Unsupported filter type: {filter_type}") |
865 | 920 |
|
866 | 921 | def _chunk_to_vector_store_content( |
|
0 commit comments