Skip to content

Commit eb92658

Browse files
authored
Images in Messages (#1504)
1 parent 6deeda9 commit eb92658

File tree

15 files changed

+331
-93
lines changed

15 files changed

+331
-93
lines changed

docs/source/api_ref_data.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,5 @@ Miscellaneous helper functions used in modifying data.
8888

8989
validate_messages
9090
truncate
91+
load_image
92+
format_content_with_images

tests/assets/dog_on_skateboard.jpg

39.7 KB
Loading

tests/test_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,14 @@ def tokenize_messages(
145145
return tokenized_messages, mask
146146

147147
def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
148-
messages = sample.pop("messages")
148+
messages: List[Message] = sample.pop("messages")
149+
images = []
150+
for message in messages:
151+
images += message.get_media()
149152
tokens, mask = self.tokenize_messages(messages)
150153
sample["tokens"] = tokens
151154
sample["mask"] = mask
155+
sample["images"] = images
152156
return sample
153157

154158
@property

tests/torchtune/data/test_data_utils.py

Lines changed: 109 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,20 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import os
8+
79
import pytest
10+
from PIL import Image
11+
12+
from tests.common import ASSETS
813
from torchtune.data import (
14+
format_content_with_images,
915
Message,
1016
PromptTemplate,
11-
split_text_by_image_tag,
1217
truncate,
1318
validate_messages,
1419
)
15-
from torchtune.data._utils import _get_prompt_template
20+
from torchtune.data._utils import _get_prompt_template, load_image
1621
from torchtune.models.llama2 import Llama2ChatTemplate
1722

1823

@@ -98,47 +103,136 @@ def test_validate_messages():
98103
validate_messages(messages)
99104

100105

101-
def test_split_text_by_image_tag():
106+
def test_format_content_with_images():
107+
test_image_1 = Image.new(mode="RGB", size=(4, 4))
108+
test_image_2 = Image.new(mode="RGB", size=(4, 4))
109+
test_image_3 = Image.new(mode="RGB", size=(4, 4))
110+
102111
# Test single image tag in the middle
103112
text = "hello <image>world"
104-
assert split_text_by_image_tag(text, "<image>") == [
113+
assert format_content_with_images(
114+
text,
115+
image_tag="<image>",
116+
images=[test_image_1],
117+
) == [
105118
{"type": "text", "content": "hello "},
106-
{"type": "image"},
119+
{"type": "image", "content": test_image_1},
107120
{"type": "text", "content": "world"},
108121
]
109122

110123
# Test multiple image tags and image tag in beginning
111124
text = "[image]hello [image]world"
112-
assert split_text_by_image_tag(text, "[image]") == [
113-
{"type": "image"},
125+
assert format_content_with_images(
126+
text,
127+
image_tag="[image]",
128+
images=[test_image_1, test_image_2],
129+
) == [
130+
{"type": "image", "content": test_image_1},
114131
{"type": "text", "content": "hello "},
115-
{"type": "image"},
132+
{"type": "image", "content": test_image_2},
116133
{"type": "text", "content": "world"},
117134
]
118135

119136
# Test an image tag that is not present in the text
120137
text = "hello world"
121-
assert split_text_by_image_tag(text, "asdfghjkl;") == [
138+
assert format_content_with_images(text, image_tag="asdfghjkl;", images=[]) == [
122139
{"type": "text", "content": "hello world"}
123140
]
124141

125142
# Test consecutive image tags
126143
text = "<image><image>hello <image>world"
127-
assert split_text_by_image_tag(text, "<image>") == [
128-
{"type": "image"},
129-
{"type": "image"},
144+
assert format_content_with_images(
145+
text,
146+
image_tag="<image>",
147+
images=[test_image_1, test_image_2, test_image_3],
148+
) == [
149+
{"type": "image", "content": test_image_1},
150+
{"type": "image", "content": test_image_2},
130151
{"type": "text", "content": "hello "},
131-
{"type": "image"},
152+
{"type": "image", "content": test_image_3},
132153
{"type": "text", "content": "world"},
133154
]
134155

135156
# Test image tag at the end
136157
text = "hello <image>"
137-
assert split_text_by_image_tag(text, "<image>") == [
158+
assert format_content_with_images(
159+
text,
160+
image_tag="<image>",
161+
images=[test_image_1],
162+
) == [
138163
{"type": "text", "content": "hello "},
139-
{"type": "image"},
164+
{"type": "image", "content": test_image_1},
140165
]
141166

167+
# Test errors when the number of images does not match the number of image tags
168+
text = "hello <image>world"
169+
with pytest.raises(
170+
ValueError,
171+
match="does not match number of image tags",
172+
):
173+
format_content_with_images(
174+
text, image_tag="<image>", images=[test_image_1, test_image_2]
175+
)
176+
177+
178+
def test_load_image(monkeypatch, tmp_path):
179+
tmp_image = str(ASSETS / "dog_on_skateboard.jpg")
180+
181+
# Test loading from local file
182+
image = load_image(tmp_image)
183+
assert isinstance(image, Image.Image)
184+
assert image.size == (580, 403)
185+
186+
# Test loading from remote file
187+
# Mock the urlopen function to return a BytesIO object
188+
def mock_urlopen(url):
189+
return open(tmp_image, "rb")
190+
191+
monkeypatch.setattr("urllib.request.urlopen", mock_urlopen)
192+
image = load_image("http://example.com/test_image.jpg")
193+
assert isinstance(image, Image.Image)
194+
assert image.size == (580, 403)
195+
196+
# Test that a ValueError is raised when the image path is invalid
197+
with pytest.raises(ValueError, match="Failed to open image as PIL.Image"):
198+
load_image("invalid_path")
199+
200+
# Test a temporary file with invalid image data
201+
image_path = tmp_path / "test_image.jpg"
202+
with open(image_path, "w") as f:
203+
f.write("Invalid image data")
204+
205+
# Test that a ValueError is raised when the image data is invalid
206+
with pytest.raises(ValueError, match="Failed to open image as PIL.Image"):
207+
load_image(str(image_path))
208+
209+
# Test that a ValueError is raised when there is an HTTP error
210+
# Mock the urlopen function to raise an exception
211+
def mock_urlopen(url):
212+
raise Exception("Failed to load image")
213+
214+
monkeypatch.setattr("urllib.request.urlopen", mock_urlopen)
215+
with pytest.raises(ValueError, match="Failed to load image"):
216+
load_image("http://example.com/test_image.jpg")
217+
218+
# Test that a ValueError is raised when there is an IO error
219+
# Create a temporary file that cannot be read
220+
image_path = tmp_path / "test_image.jpg"
221+
with open(image_path, "w") as f:
222+
f.write("Test data")
223+
os.chmod(image_path, 0o000) # Remove read permissions
224+
with pytest.raises(ValueError, match="Failed to open image as PIL.Image"):
225+
load_image(str(image_path))
226+
os.chmod(image_path, 0o644) # Restore read permissions
227+
228+
# Test that a ValueError is raised with invalid image data is read
229+
# Create a temporary file with invalid image data
230+
image_path = tmp_path / "test_image.jpg"
231+
with open(image_path, "wb") as f:
232+
f.write(b"Invalid image data")
233+
with pytest.raises(ValueError, match="Failed to open image as PIL.Image"):
234+
load_image(str(image_path))
235+
142236

143237
def test_get_prompt_template():
144238
template = _get_prompt_template("torchtune.models.llama2.Llama2ChatTemplate")

tests/torchtune/data/test_messages.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import pytest
8+
9+
from PIL import Image
810
from tests.test_utils import (
911
assert_dialogue_equal,
1012
CHAT_SAMPLE,
@@ -26,17 +28,21 @@ def text_message(self):
2628
return Message(role="user", content="hello world")
2729

2830
@pytest.fixture
29-
def image_message(self):
31+
def test_image(self):
32+
return Image.new(mode="RGB", size=(4, 4))
33+
34+
@pytest.fixture
35+
def image_message(self, test_image):
3036
return Message(
3137
role="user",
3238
content=[
3339
{"type": "text", "content": "hello"},
34-
{"type": "image"},
40+
{"type": "image", "content": test_image},
3541
{"type": "text", "content": " world"},
3642
],
3743
)
3844

39-
def test_message_validation(self, text_message):
45+
def test_message_validation(self, text_message, test_image):
4046
message = text_message
4147
assert message.role == "user"
4248
assert message.content == [{"type": "text", "content": "hello world"}]
@@ -53,7 +59,7 @@ def test_message_validation(self, text_message):
5359
):
5460
message = Message(
5561
role="user",
56-
content=[{"type": "image"}],
62+
content=[{"type": "image", "content": test_image}],
5763
ipython=True,
5864
)
5965

@@ -69,6 +75,10 @@ def test_contains_media(self, text_message, image_message):
6975
assert not text_message.contains_media
7076
assert image_message.contains_media
7177

78+
def test_get_media(self, text_message, image_message, test_image):
79+
assert text_message.get_media() == []
80+
assert image_message.get_media() == [test_image]
81+
7282
def test_text_content(self, text_message, image_message):
7383
assert text_message.text_content == "hello world"
7484
assert image_message.text_content == "hello world"

tests/torchtune/datasets/multimodal/test_llava_instruct_dataset.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from collections import Counter
88
from unittest.mock import patch
99

10+
import PIL
11+
1012
import pytest
1113
from datasets import Dataset
1214

@@ -21,11 +23,22 @@ class TestLLaVAInstructDataset:
2123
def tokenizer(self):
2224
return DummyTokenizer()
2325

26+
@pytest.fixture
27+
def test_image_pil(self):
28+
return PIL.Image.new(mode="RGB", size=(4, 4))
29+
2430
@patch("torchtune.datasets._sft.load_dataset")
25-
def test_label_no_masking(self, load_dataset, tokenizer):
31+
@patch("torchtune.datasets.multimodal._llava_instruct.load_image")
32+
def test_label_no_masking(
33+
self, load_image, load_dataset, tokenizer, test_image_pil
34+
):
2635
"""
2736
Test whether the input and the labels are correctly created when the input is not masked.
37+
38+
WARNING: careful with these mocks, they are applied in bottom up order
2839
"""
40+
# mock the call to load_image
41+
load_image.return_value = test_image_pil
2942

3043
# mock the call to HF datasets
3144
load_dataset.return_value = Dataset.from_list(
@@ -55,6 +68,7 @@ def test_label_no_masking(self, load_dataset, tokenizer):
5568
model_transform=tokenizer,
5669
train_on_input=True,
5770
)
71+
5872
input, labels, images = ds[0]["tokens"], ds[0]["labels"], ds[0]["images"]
5973

6074
expected_count = {
@@ -76,13 +90,18 @@ def test_label_no_masking(self, load_dataset, tokenizer):
7690

7791
assert Counter(input) == expected_count
7892
assert Counter(labels) == expected_count
79-
assert images == "test_image.jpg"
93+
assert images == [test_image_pil]
8094

8195
@patch("torchtune.datasets._sft.load_dataset")
82-
def test_label_masking(self, load_dataset, tokenizer):
96+
@patch("torchtune.datasets.multimodal._llava_instruct.load_image")
97+
def test_label_masking(self, load_image, load_dataset, tokenizer, test_image_pil):
8398
"""
8499
Test whether the input and the labels are correctly created when the input is masked.
100+
101+
WARNING: careful with these mocks, they are applied in bottom up order
85102
"""
103+
# mock the call to load_image
104+
load_image.return_value = test_image_pil
86105

87106
# mock the call to HF datasets
88107
load_dataset.return_value = Dataset.from_list(
@@ -133,4 +152,4 @@ def test_label_masking(self, load_dataset, tokenizer):
133152

134153
assert Counter(input) == expected_count
135154
assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 11
136-
assert images == "test_image.jpg"
155+
assert images == [test_image_pil]

0 commit comments

Comments
 (0)