Skip to content

Commit a6abaaf

Browse files
author
Val Brodsky
committed
Migrate mmc classes
1 parent cde93dc commit a6abaaf

File tree

3 files changed

+29
-26
lines changed
  • libs/labelbox

3 files changed

+29
-26
lines changed

libs/labelbox/src/labelbox/data/annotation_types/mmc.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from abc import ABC
22
from typing import ClassVar, List, Union
33

4-
from labelbox import pydantic_compat
4+
from pydantic import field_validator
5+
56
from labelbox.utils import _CamelCaseMixin
67
from labelbox.data.annotation_types.annotation import BaseAnnotation
78

@@ -33,12 +34,15 @@ class MessageRankingTask(_BaseMessageEvaluationTask):
3334
format: ClassVar[str] = "message-ranking"
3435
ranked_messages: List[OrderedMessageInfo]
3536

36-
@pydantic_compat.validator("ranked_messages")
37+
@field_validator("ranked_messages")
3738
def _validate_ranked_messages(cls, v: List[OrderedMessageInfo]):
3839
if not {msg.order for msg in v} == set(range(1, len(v) + 1)):
39-
raise ValueError("Messages must be ordered by unique and consecutive natural numbers starting from 1")
40+
raise ValueError(
41+
"Messages must be ordered by unique and consecutive natural numbers starting from 1"
42+
)
4043
return v
4144

4245

4346
class MessageEvaluationTaskAnnotation(BaseAnnotation):
44-
value: Union[MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask]
47+
value: Union[MessageSingleSelectionTask, MessageMultiSelectionTask,
48+
MessageRankingTask]

libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@
22

33
from labelbox.utils import _CamelCaseMixin
44

5-
from .base import DataRow, NDAnnotation
5+
from .base import _SubclassRegistryBase, DataRow, NDAnnotation
66
from ...annotation_types.types import Cuid
77
from ...annotation_types.mmc import MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask, MessageEvaluationTaskAnnotation
88

99

1010
class MessageTaskData(_CamelCaseMixin):
1111
format: str
12-
data: Union[MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask]
12+
data: Union[MessageSingleSelectionTask, MessageMultiSelectionTask,
13+
MessageRankingTask]
1314

1415

15-
class NDMessageTask(NDAnnotation):
16+
class NDMessageTask(NDAnnotation, _SubclassRegistryBase):
1617

1718
message_evaluation_task: MessageTaskData
1819

@@ -26,17 +27,13 @@ def to_common(self) -> MessageEvaluationTaskAnnotation:
2627

2728
@classmethod
2829
def from_common(
29-
cls,
30-
annotation: MessageEvaluationTaskAnnotation,
31-
data: Any#Union[ImageData, TextData],
30+
cls,
31+
annotation: MessageEvaluationTaskAnnotation,
32+
data: Any #Union[ImageData, TextData],
3233
) -> "NDMessageTask":
33-
return cls(
34-
uuid=str(annotation._uuid),
35-
name=annotation.name,
36-
schema_id=annotation.feature_schema_id,
37-
data_row=DataRow(id=data.uid, global_key=data.global_key),
38-
message_evaluation_task=MessageTaskData(
39-
format=annotation.value.format,
40-
data=annotation.value
41-
)
42-
)
34+
return cls(uuid=str(annotation._uuid),
35+
name=annotation.name,
36+
schema_id=annotation.feature_schema_id,
37+
data_row=DataRow(id=data.uid, global_key=data.global_key),
38+
message_evaluation_task=MessageTaskData(
39+
format=annotation.value.format, data=annotation.value))

libs/labelbox/tests/data/serialization/ndjson/test_mmc.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33
import pytest
44

55
from labelbox.data.serialization import NDJsonConverter
6-
from labelbox.pydantic_compat import ValidationError
76

87

98
def test_message_task_annotation_serialization():
109
with open('tests/data/assets/ndjson/mmc_import.json', 'r') as file:
1110
data = json.load(file)
12-
11+
1312
deserialized = list(NDJsonConverter.deserialize(data))
1413
reserialized = list(NDJsonConverter.serialize(deserialized))
1514

@@ -19,9 +18,12 @@ def test_message_task_annotation_serialization():
1918
def test_mesage_ranking_task_wrong_order_serialization():
2019
with open('tests/data/assets/ndjson/mmc_import.json', 'r') as file:
2120
data = json.load(file)
22-
23-
some_ranking_task = next(task for task in data if task["messageEvaluationTask"]["format"] == "message-ranking")
24-
some_ranking_task["messageEvaluationTask"]["data"]["rankedMessages"][0]["order"] = 3
2521

26-
with pytest.raises(ValidationError):
22+
some_ranking_task = next(
23+
task for task in data
24+
if task["messageEvaluationTask"]["format"] == "message-ranking")
25+
some_ranking_task["messageEvaluationTask"]["data"]["rankedMessages"][0][
26+
"order"] = 3
27+
28+
with pytest.raises(ValueError):
2729
list(NDJsonConverter.deserialize([some_ranking_task]))

0 commit comments

Comments
 (0)