
Commit c3ff864

Model transform docs (#1665)
1 parent 18efc81 commit c3ff864

File tree

3 files changed: +172 -1 lines changed

docs/source/basics/model_transforms.rst

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
.. _model_transform_usage_label:

=====================
Multimodal Transforms
=====================

Multimodal model transforms apply model-specific data transforms to each modality and prepare :class:`~torchtune.data.Message`
objects to be input into the model. torchtune currently supports text + image model transforms.
These are intended to be drop-in replacements for tokenizers in multimodal datasets and support the standard
``encode``, ``decode``, and ``tokenize_messages`` methods.

.. code-block:: python

    # torchtune.models.flamingo.FlamingoTransform
    class FlamingoTransform(ModelTokenizer, Transform):
        def __init__(...):
            # Text transform - standard tokenization
            self.tokenizer = llama3_tokenizer(...)
            # Image transforms
            self.transform_image = CLIPImageTransform(...)
            self.xattn_mask = VisionCrossAttentionMask(...)

.. code-block:: python

    from torchtune.models.flamingo import FlamingoTransform
    from torchtune.data import Message
    from PIL import Image

    sample = {
        "messages": [
            Message(
                role="user",
                content=[
                    {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
                    {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
                    {"type": "text", "content": "What is common in these two images?"},
                ],
            ),
            Message(
                role="assistant",
                content="A robot is in both images.",
            ),
        ],
    }
    transform = FlamingoTransform(
        path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
        tile_size=224,
        patch_size=14,
    )
    tokenized_dict = transform(sample)
    print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
    # '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n<|image|><|image|>What is common in these two images?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nA robot is in both images.<|eot_id|>'
    print(tokenized_dict["encoder_input"]["images"][0].shape)  # (num_tiles, num_channels, tile_height, tile_width)
    # torch.Size([4, 3, 224, 224])

Using model transforms
----------------------

You can pass a model transform into any multimodal dataset builder just as you would a model tokenizer.

.. code-block:: python

    from torchtune.datasets.multimodal import the_cauldron_dataset
    from torchtune.models.flamingo import FlamingoTransform

    transform = FlamingoTransform(
        path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
        tile_size=224,
        patch_size=14,
    )
    ds = the_cauldron_dataset(
        model_transform=transform,
        subset="ai2d",
    )
    tokenized_dict = ds[0]
    print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
    # <|begin_of_text|><|start_header_id|>user<|end_header_id|>
    #
    # <|image|>Question: What do respiration and combustion give out
    # Choices:
    # A. Oxygen
    # B. Carbon dioxide
    # C. Nitrogen
    # D. Heat
    # Answer with the letter.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
    #
    # Answer: B<|eot_id|>
    print(tokenized_dict["encoder_input"]["images"][0].shape)  # (num_tiles, num_channels, tile_height, tile_width)
    # torch.Size([4, 3, 224, 224])

Creating model transforms
-------------------------

Model transforms are expected to process both text and images in the sample dictionary.
Both should be contained in the ``"messages"`` field of the sample.

The following methods are required on the model transform:

- ``tokenize_messages``
- ``__call__``

.. code-block:: python

    from typing import Any, List, Mapping, Tuple

    from PIL import Image

    from torchtune.data import Message
    from torchtune.modules.tokenizers import ModelTokenizer
    from torchtune.modules.transforms import Transform

    class MyMultimodalTransform(ModelTokenizer, Transform):
        def __init__(...):
            self.tokenizer = my_tokenizer_builder(...)
            self.transform_image = MyImageTransform(...)

        def tokenize_messages(
            self,
            messages: List[Message],
            add_eos: bool = True,
        ) -> Tuple[List[int], List[bool]]:
            # Any other custom logic here
            ...

            return self.tokenizer.tokenize_messages(
                messages=messages,
                add_eos=add_eos,
            )

        def __call__(
            self, sample: Mapping[str, Any], inference: bool = False
        ) -> Mapping[str, Any]:
            # Expected input parameters for vision encoder
            encoder_input = {"images": [], "aspect_ratio": []}
            messages = sample["messages"]

            # Transform all images in sample
            for message in messages:
                for image in message.get_media():
                    out = self.transform_image({"image": image}, inference=inference)
                    encoder_input["images"].append(out["image"])
                    encoder_input["aspect_ratio"].append(out["aspect_ratio"])
            sample["encoder_input"] = encoder_input

            # Transform all text - returns same dictionary with additional keys "tokens" and "mask"
            sample = self.tokenizer(sample, inference=inference)

            return sample

    transform = MyMultimodalTransform(...)
    sample = {
        "messages": [
            Message(
                role="user",
                content=[
                    {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
                    {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
                    {"type": "text", "content": "What is common in these two images?"},
                ],
            ),
            Message(
                role="assistant",
                content="A robot is in both images.",
            ),
        ],
    }
    tokenized_dict = transform(sample)
    print(tokenized_dict)
    # {'encoder_input': {'images': ..., 'aspect_ratio': ...}, 'tokens': ..., 'mask': ...}

Example model transforms
------------------------

- Flamingo: :class:`~torchtune.models.flamingo.FlamingoTransform`
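
Because these transforms implement the standard tokenizer methods mentioned above, they can also be used directly for text-only tokenization. The snippet below is a minimal sketch, assuming :class:`~torchtune.models.flamingo.FlamingoTransform` exposes the same ``tokenize_messages`` and ``decode`` signatures shown in the earlier examples; the tokenizer path is the same placeholder used above.

.. code-block:: python

    # Minimal sketch: treating the model transform as a plain tokenizer on
    # text-only messages (assumes the tokenize_messages/decode interface
    # shown above).
    from torchtune.data import Message
    from torchtune.models.flamingo import FlamingoTransform

    transform = FlamingoTransform(
        path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
        tile_size=224,
        patch_size=14,
    )
    messages = [
        Message(role="user", content="Describe this picture."),
        Message(role="assistant", content="It shows a robot."),
    ]
    # tokenize_messages returns token ids and a boolean mask
    tokens, mask = transform.tokenize_messages(messages)
    print(transform.decode(tokens, skip_special_tokens=True))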

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ torchtune tutorials.
   basics/tokenizers
   basics/prompt_templates
   basics/preference_datasets
+  basics/model_transforms

.. toctree::
   :glob:

torchtune/models/flamingo/_transform.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ class FlamingoTransform(ModelTokenizer, Transform):
    Args:
        path (str): Path to pretrained tiktoken tokenizer file.
-       tile_size (int): Size of the tiles to divide the image into. Default 224.
+       tile_size (int): Size of the tiles to divide the image into.
        patch_size (int): Size of the patches used in the CLIP vision transformer model. This is
            used to calculate the number of image embeddings per image.
        max_num_tiles (int): Only used if possible_resolutions is NOT given.
