Commit f8073ed

krammnic authored
Add vqa_dataset, update docs (#1820)
Co-authored-by: Rafi Ayub <[email protected]>
Co-authored-by: krammnic <[email protected]>
Co-authored-by: Joe Cummings <[email protected]>
Co-authored-by: Salman Mohammadi <[email protected]>
1 parent 7d29c21 commit f8073ed

File tree

6 files changed: 239 additions, 9 deletions

tests/assets/rgb_pytorch.png

575 Bytes

tests/assets/vqa_tiny.json

Lines changed: 7 additions & 0 deletions
```json
[
    {
        "input": "What is presented on image?",
        "output": "PyTorch logo.",
        "image": "tests/assets/rgb_pytorch.png"
    }
]
```
Lines changed: 49 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
from PIL.PngImagePlugin import PngImageFile
from tests.common import ASSETS
from tests.test_utils import DummyTokenizer

from torchtune.datasets.multimodal import vqa_dataset


class TestMultimodalInstructDataset:
    @pytest.fixture
    def tokenizer(self):
        return DummyTokenizer()

    def test_get_item(self, tokenizer):
        system_prompt = "follow this prompt"

        dataset = vqa_dataset(
            model_transform=tokenizer,
            source="json",
            data_files=str(ASSETS / "vqa_tiny.json"),
            split="train",
            new_system_prompt=system_prompt,
        )

        expected_tokens = [
            [0, 6, 4, 6, -2, 4, 2, 9, 2, 6, 7, 5, -1],
        ]

        expected_labels = [
            [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 7, 5, -1]
        ]

        assert len(dataset) == 1

        for i in range(len(dataset)):
            prompt, label, image = (
                dataset[i]["tokens"],
                dataset[i]["labels"],
                dataset[i]["images"],
            )
            assert prompt == expected_tokens[i]
            assert label == expected_labels[i]
            assert isinstance(image[0], PngImageFile)
```
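A note on the fixture values: the expected token IDs line up if ``DummyTokenizer`` encodes each word as its character count, uses 0/-1 as BOS/EOS, and -2 as the image placeholder. This is an inference about the test helper, not something stated in the diff:

```python
# Hypothetical decoding of expected_tokens, assuming DummyTokenizer maps each
# word to its character count (inferred behavior, not shown in this diff):
#   0              -> BOS
#   6, 4, 6        -> "follow" "this" "prompt"   (the new_system_prompt)
#   -2             -> image placeholder
#   4, 2, 9, 2, 6  -> "What" "is" "presented" "on" "image?"
#   7, 5           -> "PyTorch" "logo."
#   -1             -> EOS
# expected_labels masks the first ten positions with -100, so the loss is
# computed only on the answer tokens (7, 5) and the EOS.
```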

torchtune/data/_messages.py

Lines changed: 43 additions & 9 deletions
```diff
@@ -161,6 +161,11 @@ class InputOutputToMessages(Transform):
             keeping the default "input" and "output" column names.
         new_system_prompt (Optional[str]): if specified, prepend a system message. This can
             serve as instructions to guide the model response. Default is None.
+        image_dir (Optional[Path]): path to the directory containing the images that is prepended to all image
+            paths in the dataset. For example, if ``image_dir="/home/user/dataset/"`` and the sample image path
+            was ``"images/1.jpg"``, the final image path that will be loaded is ``"/home/user/dataset/images/1.jpg"``.
+            If None, assume images are available in the current working directory or are located
+            at a remote URL. For text-only datasets, leave as None. Default is None.

     Raises:
         ValueError: If ``column_map`` is provided and ``input`` not in ``column_map``, or
@@ -172,33 +177,62 @@ def __init__(
         train_on_input: bool = False,
         column_map: Optional[Dict[str, str]] = None,
         new_system_prompt: Optional[str] = None,
+        image_dir: Optional[Path] = None,
     ):
         self.train_on_input = train_on_input
         self.new_system_prompt = new_system_prompt
-        if column_map:
-            if "input" not in column_map:
+
+        self.column_map = column_map
+
+        if self.column_map is not None:
+            if "input" not in self.column_map:
                 raise ValueError(
-                    f"Expected a key of 'input' in column_map but found {column_map.keys()}."
+                    f"Expected a key of 'input' in column_map but found {self.column_map.keys()}."
                 )
-            if "output" not in column_map:
+            if "output" not in self.column_map:
                 raise ValueError(
-                    f"Expected a key of 'output' in column_map but found {column_map.keys()}."
+                    f"Expected a key of 'output' in column_map but found {self.column_map.keys()}."
                 )
-            self._column_map = column_map
         else:
-            self._column_map = {"input": "input", "output": "output"}
+            self.column_map = {"input": "input", "output": "output", "image": "image"}
+
+        self.image_dir = image_dir

     def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
+        is_multimodal = "image" in sample or (
+            "image" in self.column_map and self.column_map["image"] in sample
+        )
+
+        if is_multimodal:
+            image_path = sample[self.column_map["image"]]
+            if isinstance(image_path, str):
+                if self.image_dir is not None:
+                    image_path = self.image_dir / image_path
+                # Load if not loaded
+                pil_image = load_image(image_path)
+            else:
+                pil_image = image_path
+            content = [
+                {"type": "image", "content": pil_image},
+                {"type": "text", "content": sample[self.column_map["input"]]},
+            ]
+        else:
+            content = [{"type": "text", "content": sample[self.column_map["input"]]}]
+
+        output_content = [
+            {"type": "text", "content": sample[self.column_map["output"]]}
+        ]
+
         messages = [
             Message(
                 role="user",
-                content=sample[self._column_map["input"]],
+                content=content,
                 masked=not self.train_on_input,
                 eot=True,
             ),
             Message(
                 role="assistant",
-                content=sample[self._column_map["output"]],
+                content=output_content,
                 masked=False,
                 eot=True,
             ),
```
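To make the new multimodal path concrete, here is a minimal sketch of the transform applied to a hand-written sample. The column names and file paths are hypothetical, and the ``"messages"`` return key is assumed from how torchtune message transforms are typically consumed:

```python
from pathlib import Path

from torchtune.data import InputOutputToMessages

# Hypothetical column names and paths, purely for illustration.
transform = InputOutputToMessages(
    column_map={"input": "question", "output": "answer", "image": "picture"},
    image_dir=Path("/home/user/dataset"),  # joined to the sample path via "/"
)

sample = {
    "question": "What is presented on the image?",
    "answer": "PyTorch logo.",
    "picture": "images/1.jpg",  # resolved to /home/user/dataset/images/1.jpg
}

# The user message now carries two content items, the loaded PIL image
# followed by the question text, while the assistant message carries the
# answer. The "messages" key is an assumption about the return value.
messages = transform(sample)["messages"]
```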

torchtune/datasets/multimodal/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -7,9 +7,11 @@
 from ._llava_instruct import llava_instruct_dataset
 from ._multimodal import multimodal_chat_dataset
 from ._the_cauldron import the_cauldron_dataset
+from ._vqa import vqa_dataset

 __all__ = [
     "the_cauldron_dataset",
     "llava_instruct_dataset",
     "multimodal_chat_dataset",
+    "vqa_dataset",
 ]
```
torchtune/datasets/multimodal/_vqa.py

Lines changed: 138 additions & 0 deletions

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Dict, Optional

from torchtune.data import InputOutputToMessages
from torchtune.datasets._sft import SFTDataset
from torchtune.modules.transforms import Transform


def vqa_dataset(
    model_transform: Transform,
    *,
    source: str,
    image_dir: Optional[str] = None,
    column_map: Optional[Dict[str, str]] = None,
    new_system_prompt: Optional[str] = None,
    filter_fn: Optional[Callable] = None,
    split: str = "train",
    **load_dataset_kwargs: Dict[str, Any],
) -> SFTDataset:
    """
    Configure a custom visual question answering dataset with separate columns
    for the user question, image, and model response.

    This builder function can be used to configure a custom visual question answering
    dataset directly from the yaml config as an alternative to
    :class:`~torchtune.datasets.SFTDataset`, as it is made to be config friendly.

    The dataset should follow this format:

    .. code-block:: text

        |  input         |  image        |  output          |
        |----------------|---------------|------------------|
        | "user prompt"  | images/1.jpg  | "model response" |

    If your column names are different, you can use the ``column_map`` parameter to change
    the expected column names. For example, if your dataset has columns ``"question"``,
    ``"answer"`` and ``"picture"`` you can use:

        column_map = {"input": "question", "output": "answer", "image": "picture"}

    Args:
        model_transform (Transform): callable that applies model-specific pre-processing to the sample.
            This includes tokenization and any modality-specific transforms. It is expected to return at
            minimum ``"tokens"`` and ``"mask"`` keys.
        source (str): path to dataset repository on Hugging Face. For local datasets,
            define source as the data file type (e.g. "json", "csv", "text"), pass
            in the filepath in ``data_files``, and set ``split="train"``. See `Hugging Face's
            <https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path>`_
            ``load_dataset`` for more details.
        image_dir (Optional[str]): path to the directory containing the images that is prepended to all image
            paths in the dataset. For example, if ``image_dir="/home/user/dataset/"`` and the sample image path
            was ``"images/1.jpg"``, the final image path that will be loaded is ``"/home/user/dataset/images/1.jpg"``.
            If None, assume images are available in the current working directory or are located
            at a remote URL. For text-only datasets, leave as None. Default is None.
        column_map (Optional[Dict[str, str]]): a mapping to change the expected "input",
            "output", and "image" column names to the actual column names in the dataset. Keys should be
            "input", "output", and "image", and values should be the actual column names.
            Default is None, keeping the default "input", "output", and "image" column names.
        new_system_prompt (Optional[str]): if specified, prepend a system message. This can
            serve as instructions to guide the model response. Setting this will OVERRIDE any system
            messages already present in the dataset. Default is None.
        filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
            the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
            details.
        split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
            of a given split, e.g. ``split="train[:10%]"``. Default is "train".
        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``,
            such as ``data_files`` or ``split``.

    Examples:

    ::

        my_dataset.json
        [
            {
                "question": "What is presented on the image?",
                "answer": "PyTorch logo.",
                "picture": "rgb_pytorch.png"
            },
            {
                ...
            },
            ...,
        ]

    ::

        >>> from torchtune.datasets.multimodal import vqa_dataset
        >>> dataset = vqa_dataset(
        ...     model_transform=model_transform,
        ...     source="json",
        ...     data_files="my_dataset.json",
        ...     column_map={
        ...         "input": "question",
        ...         "output": "answer",
        ...         "image": "picture"
        ...     },
        ...     split="train",
        ... )
        >>> tokens = dataset[0]["tokens"]
        >>> model_transform.decode(tokens)
        "What is presented on the image?PyTorch logo."

    This can also be accomplished via the yaml config:

    .. code-block:: yaml

        dataset:
          _component_: torchtune.datasets.multimodal.vqa_dataset
          source: json
          data_files: my_dataset.json
          column_map:
            input: question
            output: answer
            image: picture
          split: train

    Returns:
        SFTDataset: the configured :class:`~torchtune.datasets.SFTDataset`
    """
    message_transform = InputOutputToMessages(
        column_map=column_map, new_system_prompt=new_system_prompt, image_dir=image_dir
    )

    ds = SFTDataset(
        source=source,
        message_transform=message_transform,
        model_transform=model_transform,
        filter_fn=filter_fn,
        split=split,
        **load_dataset_kwargs,
    )
    return ds
```
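The docstring examples above cover datasets whose image paths resolve on their own; a sketch of the local case with ``image_dir`` follows. The file layout and ``model_transform`` are assumptions, and the directory is passed as a ``pathlib.Path`` because ``InputOutputToMessages`` joins it to each sample path with the ``/`` operator:

```python
from pathlib import Path

from torchtune.datasets.multimodal import vqa_dataset

# Hypothetical layout: /home/user/dataset/my_dataset.json plus an images/
# subdirectory; model_transform is assumed to be a multimodal model transform
# that returns "tokens" and "mask" keys.
ds = vqa_dataset(
    model_transform=model_transform,
    source="json",
    data_files="/home/user/dataset/my_dataset.json",
    image_dir=Path("/home/user/dataset"),
    column_map={"input": "question", "output": "answer", "image": "picture"},
    split="train",
)
sample = ds[0]  # dict with "tokens", "labels", and "images" entries
```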

0 commit comments
