Skip to content

Commit 65e27aa

Browse files
author
Val Brodsky
committed
Fix search filters
1 parent f811af2 commit 65e27aa

File tree

2 files changed

+59
-51
lines changed

2 files changed

+59
-51
lines changed

libs/labelbox/src/labelbox/schema/search_filters.py

Lines changed: 51 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,25 @@
11
import datetime
22
from enum import Enum
3-
from typing import List, Literal, Union
3+
from typing import List, Union
4+
from pydantic import PlainSerializer, BaseModel, Field
5+
6+
from typing_extensions import Annotated
47

58
from pydantic import BaseModel, Field, field_validator
69
from labelbox.schema.labeling_service_status import LabelingServiceStatus
710
from labelbox.utils import format_iso_datetime
8-
from pydantic.config import ConfigDict
911

1012

1113
class BaseSearchFilter(BaseModel):
1214
"""
1315
Shared code for all search filters
1416
"""
1517

16-
model_config = ConfigDict(use_enum_values=True)
17-
18-
def dict(self, *args, **kwargs):
19-
res = super().dict(*args, **kwargs)
20-
# go through all the keys and convert date to string
21-
for key in res:
22-
if isinstance(res[key], datetime.datetime):
23-
res[key] = format_iso_datetime(res[key])
24-
return res
18+
class Config:
19+
use_enum_values = True
2520

2621

27-
class OperationType(Enum):
22+
class OperationTypeEnum(Enum):
2823
"""
2924
Supported search entity types
3025
Each type corresponds to a different filter class
@@ -40,6 +35,13 @@ class OperationType(Enum):
4035
TaskRemainingCount = 'task_remaining_count'
4136

4237

38+
OperationType = Annotated[OperationTypeEnum,
39+
PlainSerializer(lambda x: x.value, return_type=str)]
40+
41+
IsoDatetimeType = Annotated[datetime.datetime,
42+
PlainSerializer(format_iso_datetime)]
43+
44+
4345
class IdOperator(Enum):
4446
"""
4547
Supported operators for ids like org ids, workspace ids, etc
@@ -75,7 +77,8 @@ class OrganizationFilter(BaseSearchFilter):
7577
"""
7678
Filter for organization to which projects belong
7779
"""
78-
operation: Literal[OperationType.Organization] = OperationType.Organization
80+
operation: OperationType = Field(default=OperationTypeEnum.Organization,
81+
serialization_alias='type')
7982
operator: IdOperator
8083
values: List[str]
8184

@@ -84,9 +87,10 @@ class SharedWithOrganizationFilter(BaseSearchFilter):
8487
"""
8588
Find project shared with the organization (i.e. not having this organization as a tenantId)
8689
"""
87-
operation: Literal[
88-
OperationType.
89-
SharedWithOrganization] = OperationType.SharedWithOrganization
90+
91+
operation: OperationType = Field(
92+
default=OperationTypeEnum.SharedWithOrganization,
93+
serialization_alias='type')
9094
operator: IdOperator
9195
values: List[str]
9296

@@ -95,7 +99,8 @@ class WorkspaceFilter(BaseSearchFilter):
9599
"""
96100
Filter for workspace
97101
"""
98-
operation: Literal[OperationType.Workspace] = OperationType.Workspace
102+
operation: OperationType = Field(default=OperationTypeEnum.Workspace,
103+
serialization_alias='type')
99104
operator: IdOperator
100105
values: List[str]
101106

@@ -105,7 +110,8 @@ class TagFilter(BaseSearchFilter):
105110
Filter for project tags
106111
values are tag ids
107112
"""
108-
operation: Literal[OperationType.Tag] = OperationType.Tag
113+
operation: OperationType = Field(default=OperationTypeEnum.Tag,
114+
serialization_alias='type')
109115
operator: IdOperator
110116
values: List[str]
111117

@@ -115,11 +121,12 @@ class ProjectStageFilter(BaseSearchFilter):
115121
Filter labelbox service / aka project stages
116122
Stages are: requested, in_progress, completed etc. as described by LabelingServiceStatus
117123
"""
118-
operation: Literal[OperationType.Stage] = OperationType.Stage
124+
operation: OperationType = Field(default=OperationTypeEnum.Stage,
125+
serialization_alias='type')
119126
operator: IdOperator
120127
values: List[LabelingServiceStatus]
121128

122-
@field_validator('values')
129+
@field_validator('values', mode='before')
123130
def validate_values(cls, values):
124131
disallowed_values = [LabelingServiceStatus.Missing]
125132
for value in values:
@@ -143,7 +150,7 @@ class DateValue(BaseSearchFilter):
143150
while the same string in EST will get converted to '2024-01-01T05:00:00Z'
144151
"""
145152
operator: RangeDateTimeOperatorWithSingleValue
146-
value: datetime.datetime
153+
value: IsoDatetimeType
147154

148155

149156
class IntegerValue(BaseSearchFilter):
@@ -155,28 +162,28 @@ class WorkforceStageUpdatedFilter(BaseSearchFilter):
155162
"""
156163
Filter for workforce stage updated date
157164
"""
158-
operation: Literal[
159-
OperationType.
160-
WorkforceStageUpdatedDate] = OperationType.WorkforceStageUpdatedDate
165+
operation: OperationType = Field(
166+
default=OperationTypeEnum.WorkforceStageUpdatedDate,
167+
serialization_alias='type')
161168
value: DateValue
162169

163170

164171
class WorkforceRequestedDateFilter(BaseSearchFilter):
165172
"""
166173
Filter for workforce requested date
167174
"""
168-
operation: Literal[
169-
OperationType.
170-
WorforceRequestedDate] = OperationType.WorforceRequestedDate
175+
operation: OperationType = Field(
176+
default=OperationTypeEnum.WorforceRequestedDate,
177+
serialization_alias='type')
171178
value: DateValue
172179

173180

174181
class DateRange(BaseSearchFilter):
175182
"""
176183
Date range for a search filter
177184
"""
178-
min: datetime.datetime
179-
max: datetime.datetime
185+
min: IsoDatetimeType
186+
max: IsoDatetimeType
180187

181188

182189
class DateRangeValue(BaseSearchFilter):
@@ -191,19 +198,19 @@ class WorkforceRequestedDateRangeFilter(BaseSearchFilter):
191198
"""
192199
Filter for workforce requested date range
193200
"""
194-
operation: Literal[
195-
OperationType.
196-
WorforceRequestedDate] = OperationType.WorforceRequestedDate
201+
operation: OperationType = Field(
202+
default=OperationTypeEnum.WorforceRequestedDate,
203+
serialization_alias='type')
197204
value: DateRangeValue
198205

199206

200207
class WorkforceStageUpdatedRangeFilter(BaseSearchFilter):
201208
"""
202209
Filter for workforce stage updated date range
203210
"""
204-
operation: Literal[
205-
OperationType.
206-
WorkforceStageUpdatedDate] = OperationType.WorkforceStageUpdatedDate
211+
operation: OperationType = Field(
212+
default=OperationTypeEnum.WorkforceStageUpdatedDate,
213+
serialization_alias='type')
207214
value: DateRangeValue
208215

209216

@@ -212,20 +219,19 @@ class TaskCompletedCountFilter(BaseSearchFilter):
212219
Filter for completed tasks count
213220
A task maps to a data row. Task completed should map to a data row in a labeling queue DONE
214221
"""
215-
operation: Literal[
216-
OperationType.TaskCompletedCount] = Field(default=OperationType.TaskCompletedCount, serialization_alias='type')
222+
operation: OperationType = Field(
223+
default=OperationTypeEnum.TaskCompletedCount,
224+
serialization_alias='type')
217225
value: IntegerValue
218226

219227

220-
221-
222-
223228
class TaskRemainingCountFilter(BaseSearchFilter):
224229
"""
225230
Filter for remaining tasks count. Reverse of TaskCompletedCountFilter
226231
"""
227-
operation: Literal[
228-
OperationType.TaskRemainingCount] = Field(OperationType.TaskRemainingCount, serialization_alias='type')
232+
operation: OperationType = Field(
233+
default=OperationTypeEnum.TaskRemainingCount,
234+
serialization_alias='type')
229235
value: IntegerValue
230236

231237

@@ -253,5 +259,7 @@ def build_search_filter(filter: List[SearchFilter]):
253259
"""
254260
Converts a list of search filters to a graphql string
255261
"""
256-
filters = [_dict_to_graphql_string(f.model_dump(by_alias=True)) for f in filter]
262+
filters = [
263+
_dict_to_graphql_string(f.model_dump(by_alias=True)) for f in filter
264+
]
257265
return "[" + ", ".join(filters) + "]"

libs/labelbox/tests/unit/test_unit_search_filters.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_id_filters():
2020

2121
assert build_search_filter(
2222
filters
23-
) == '[{operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"], type: "organization_id"}, {operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"], type: "shared_with_organizations"}, {operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"], type: "workspace"}, {operator: "is", values: ["cls1vkrw401ab072vg2pq3t5d"], type: "tag"}, {operator: "is", values: ["REQUESTED"], type: "stage"}]'
23+
) == '[{type: "organization_id", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "shared_with_organizations", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "workspace", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "tag", operator: "is", values: ["cls1vkrw401ab072vg2pq3t5d"]}, {type: "stage", operator: "is", values: ["REQUESTED"]}]'
2424

2525

2626
def test_stage_filter_with_invalid_values():
@@ -49,7 +49,7 @@ def test_date_filters():
4949
expected_start = format_iso_datetime(local_time_start)
5050
expected_end = format_iso_datetime(local_time_end)
5151

52-
expected = '[{value: {operator: "GREATER_THAN_OR_EQUAL", value: "' + expected_start + '"}, type: "workforce_requested_at"}, {value: {operator: "LESS_THAN_OR_EQUAL", value: "' + expected_end + '"}, type: "workforce_stage_updated_at"}]'
52+
expected = '[{type: "workforce_requested_at", value: {operator: "GREATER_THAN_OR_EQUAL", value: "' + expected_start + '"}}, {type: "workforce_stage_updated_at", value: {operator: "LESS_THAN_OR_EQUAL", value: "' + expected_end + '"}}]'
5353
assert build_search_filter(filters) == expected
5454

5555

@@ -70,16 +70,16 @@ def test_date_range_filters():
7070
]
7171
assert build_search_filter(
7272
filters
73-
) == '[{value: {operator: "BETWEEN", value: {min: "2024-01-01T08:00:00Z", max: "2025-01-01T08:00:00Z"}}, type: "workforce_requested_at"}, {value: {operator: "BETWEEN", value: {min: "2024-01-01T08:00:00Z", max: "2025-01-01T08:00:00Z"}}, type: "workforce_stage_updated_at"}]'
73+
) == '[{type: "workforce_requested_at", value: {operator: "BETWEEN", value: {min: "2024-01-01T08:00:00Z", max: "2025-01-01T08:00:00Z"}}}, {type: "workforce_stage_updated_at", value: {operator: "BETWEEN", value: {min: "2024-01-01T08:00:00Z", max: "2025-01-01T08:00:00Z"}}}]'
7474

7575

7676
def test_task_count_filters():
7777
filters = [
78-
TaskCompletedCountFilter(value=IntegerValue(operator=RangeOperatorWithSingleValue.GreaterThanOrEqual, value=1)),
79-
# TaskRemainingCountFilter(value=IntegerValue(
80-
# operator=RangeOperatorWithSingleValue.LessThanOrEqual, value=10)),
78+
TaskCompletedCountFilter(value=IntegerValue(
79+
operator=RangeOperatorWithSingleValue.GreaterThanOrEqual, value=1)),
80+
TaskRemainingCountFilter(value=IntegerValue(
81+
operator=RangeOperatorWithSingleValue.LessThanOrEqual, value=10)),
8182
]
8283

83-
# expected = '[{value: {operator: "GREATER_THAN_OR_EQUAL", value: 1}, type: "task_completed_count"}, {value: {operator: "LESS_THAN_OR_EQUAL", value: 10}, type: "task_remaining_count"}]'
84-
expected = '[{value: {operator: "GREATER_THAN_OR_EQUAL", value: 1}, type: "task_completed_count"}, {value: {operator: "LESS_THAN_OR_EQUAL", value: 10}, type: "task_remaining_count"}]'
84+
expected = '[{type: "task_completed_count", value: {operator: "GREATER_THAN_OR_EQUAL", value: 1}}, {type: "task_remaining_count", value: {operator: "LESS_THAN_OR_EQUAL", value: 10}}]'
8585
assert build_search_filter(filters) == expected

0 commit comments

Comments
 (0)