|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
| 7 | +import os |
| 8 | + |
7 | 9 | import pytest |
| 10 | +from PIL import Image |
| 11 | + |
| 12 | +from tests.common import ASSETS |
8 | 13 | from torchtune.data import ( |
| 14 | + format_content_with_images, |
9 | 15 | Message, |
10 | 16 | PromptTemplate, |
11 | | - split_text_by_image_tag, |
12 | 17 | truncate, |
13 | 18 | validate_messages, |
14 | 19 | ) |
15 | | -from torchtune.data._utils import _get_prompt_template |
| 20 | +from torchtune.data._utils import _get_prompt_template, load_image |
16 | 21 | from torchtune.models.llama2 import Llama2ChatTemplate |
17 | 22 |
|
18 | 23 |
|
@@ -98,47 +103,136 @@ def test_validate_messages(): |
98 | 103 | validate_messages(messages) |
99 | 104 |
|
100 | 105 |
|
101 | | -def test_split_text_by_image_tag(): |
| 106 | +def test_format_content_with_images(): |
| 107 | + test_image_1 = Image.new(mode="RGB", size=(4, 4)) |
| 108 | + test_image_2 = Image.new(mode="RGB", size=(4, 4)) |
| 109 | + test_image_3 = Image.new(mode="RGB", size=(4, 4)) |
| 110 | + |
102 | 111 | # Test single image tag in the middle |
103 | 112 | text = "hello <image>world" |
104 | | - assert split_text_by_image_tag(text, "<image>") == [ |
| 113 | + assert format_content_with_images( |
| 114 | + text, |
| 115 | + image_tag="<image>", |
| 116 | + images=[test_image_1], |
| 117 | + ) == [ |
105 | 118 | {"type": "text", "content": "hello "}, |
106 | | - {"type": "image"}, |
| 119 | + {"type": "image", "content": test_image_1}, |
107 | 120 | {"type": "text", "content": "world"}, |
108 | 121 | ] |
109 | 122 |
|
110 | 123 | # Test multiple image tags, including an image tag at the beginning |
111 | 124 | text = "[image]hello [image]world" |
112 | | - assert split_text_by_image_tag(text, "[image]") == [ |
113 | | - {"type": "image"}, |
| 125 | + assert format_content_with_images( |
| 126 | + text, |
| 127 | + image_tag="[image]", |
| 128 | + images=[test_image_1, test_image_2], |
| 129 | + ) == [ |
| 130 | + {"type": "image", "content": test_image_1}, |
114 | 131 | {"type": "text", "content": "hello "}, |
115 | | - {"type": "image"}, |
| 132 | + {"type": "image", "content": test_image_2}, |
116 | 133 | {"type": "text", "content": "world"}, |
117 | 134 | ] |
118 | 135 |
|
119 | 136 | # Test an image tag that is not present in the text |
120 | 137 | text = "hello world" |
121 | | - assert split_text_by_image_tag(text, "asdfghjkl;") == [ |
| 138 | + assert format_content_with_images(text, image_tag="asdfghjkl;", images=[]) == [ |
122 | 139 | {"type": "text", "content": "hello world"} |
123 | 140 | ] |
124 | 141 |
|
125 | 142 | # Test consecutive image tags |
126 | 143 | text = "<image><image>hello <image>world" |
127 | | - assert split_text_by_image_tag(text, "<image>") == [ |
128 | | - {"type": "image"}, |
129 | | - {"type": "image"}, |
| 144 | + assert format_content_with_images( |
| 145 | + text, |
| 146 | + image_tag="<image>", |
| 147 | + images=[test_image_1, test_image_2, test_image_3], |
| 148 | + ) == [ |
| 149 | + {"type": "image", "content": test_image_1}, |
| 150 | + {"type": "image", "content": test_image_2}, |
130 | 151 | {"type": "text", "content": "hello "}, |
131 | | - {"type": "image"}, |
| 152 | + {"type": "image", "content": test_image_3}, |
132 | 153 | {"type": "text", "content": "world"}, |
133 | 154 | ] |
134 | 155 |
|
135 | 156 | # Test image tag at the end |
136 | 157 | text = "hello <image>" |
137 | | - assert split_text_by_image_tag(text, "<image>") == [ |
| 158 | + assert format_content_with_images( |
| 159 | + text, |
| 160 | + image_tag="<image>", |
| 161 | + images=[test_image_1], |
| 162 | + ) == [ |
138 | 163 | {"type": "text", "content": "hello "}, |
139 | | - {"type": "image"}, |
| 164 | + {"type": "image", "content": test_image_1}, |
140 | 165 | ] |
141 | 166 |
|
| 167 | + # Test errors when the number of images does not match the number of image tags |
| 168 | + text = "hello <image>world" |
| 169 | + with pytest.raises( |
| 170 | + ValueError, |
| 171 | + match="does not match number of image tags", |
| 172 | + ): |
| 173 | + format_content_with_images( |
| 174 | + text, image_tag="<image>", images=[test_image_1, test_image_2] |
| 175 | + ) |
| 176 | + |
| 177 | + |
| 178 | +def test_load_image(monkeypatch, tmp_path): |
| 179 | + tmp_image = str(ASSETS / "dog_on_skateboard.jpg") |
| 180 | + |
| 181 | + # Test loading from local file |
| 182 | + image = load_image(tmp_image) |
| 183 | + assert isinstance(image, Image.Image) |
| 184 | + assert image.size == (580, 403) |
| 185 | + |
| 186 | + # Test loading from remote file |
| 187 | + # Mock the urlopen function to return a binary file object for the local test image |
| 188 | + def mock_urlopen(url): |
| 189 | + return open(tmp_image, "rb") |
| 190 | + |
| 191 | + monkeypatch.setattr("urllib.request.urlopen", mock_urlopen) |
| 192 | + image = load_image("http://example.com/test_image.jpg") |
| 193 | + assert isinstance(image, Image.Image) |
| 194 | + assert image.size == (580, 403) |
| 195 | + |
| 196 | + # Test that a ValueError is raised when the image path is invalid |
| 197 | + with pytest.raises(ValueError, match="Failed to open image as PIL.Image"): |
| 198 | + load_image("invalid_path") |
| 199 | + |
| 200 | + # Test a temporary file with invalid image data |
| 201 | + image_path = tmp_path / "test_image.jpg" |
| 202 | + with open(image_path, "w") as f: |
| 203 | + f.write("Invalid image data") |
| 204 | + |
| 205 | + # Test that a ValueError is raised when the image data is invalid |
| 206 | + with pytest.raises(ValueError, match="Failed to open image as PIL.Image"): |
| 207 | + load_image(str(image_path)) |
| 208 | + |
| 209 | + # Test that a ValueError is raised when there is an HTTP error |
| 210 | + # Mock the urlopen function to raise an exception |
| 211 | + def mock_urlopen(url): |
| 212 | + raise Exception("Failed to load image") |
| 213 | + |
| 214 | + monkeypatch.setattr("urllib.request.urlopen", mock_urlopen) |
| 215 | + with pytest.raises(ValueError, match="Failed to load image"): |
| 216 | + load_image("http://example.com/test_image.jpg") |
| 217 | + |
| 218 | + # Test that a ValueError is raised when there is an IO error |
| 219 | + # Create a temporary file that cannot be read |
| 220 | + image_path = tmp_path / "test_image.jpg" |
| 221 | + with open(image_path, "w") as f: |
| 222 | + f.write("Test data") |
| 223 | + os.chmod(image_path, 0o000) # Remove read permissions |
| 224 | + with pytest.raises(ValueError, match="Failed to open image as PIL.Image"): |
| 225 | + load_image(str(image_path)) |
| 226 | + os.chmod(image_path, 0o644) # Restore read permissions |
| 227 | + |
| 228 | + # Test that a ValueError is raised when invalid image data is read |
| 229 | + # Create a temporary file with invalid image data |
| 230 | + image_path = tmp_path / "test_image.jpg" |
| 231 | + with open(image_path, "wb") as f: |
| 232 | + f.write(b"Invalid image data") |
| 233 | + with pytest.raises(ValueError, match="Failed to open image as PIL.Image"): |
| 234 | + load_image(str(image_path)) |
| 235 | + |
142 | 236 |
|
143 | 237 | def test_get_prompt_template(): |
144 | 238 | template = _get_prompt_template("torchtune.models.llama2.Llama2ChatTemplate") |
|
0 commit comments