Commit df4f01a

Felipe Mello (felipemello1) authored and committed
Resizable image positional embeddings (#1695)
Co-authored-by: Felipe Mello <[email protected]>
1 parent ff2a85f · commit df4f01a

3 files changed: +787 −27 lines changed


docs/source/api_ref_models.rst

Lines changed: 10 additions & 10 deletions
@@ -320,16 +320,16 @@ To download the Gemma 7B model:
     gemma.gemma_tokenizer


-.. clip
-.. -----
+clip
+-----

-.. Vision components to support multimodality using `CLIP encoder <https://arxiv.org/abs/2103.00020>`_.
+Vision components to support multimodality using `CLIP encoder <https://arxiv.org/abs/2103.00020>`_.

-.. .. autosummary::
-..     :toctree: generated/
-..     :nosignatures:
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:

-..     clip.clip_vision_encoder
-..     clip.TokenPositionalEmbedding
-..     clip.TiledTokenPositionalEmbedding
-..     clip.TilePositionalEmbedding
+    clip.clip_vision_encoder
+    clip.TokenPositionalEmbedding
+    clip.TiledTokenPositionalEmbedding
+    clip.TilePositionalEmbedding
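
For orientation, the newly documented modules can be constructed the same way the new tests below do. This is a minimal sketch, assuming only the keyword arguments visible in those tests (max_num_tiles, embed_dim, tile_size, patch_size); the numeric values are illustrative, not taken from any config:

from torchtune.models.clip._position_embeddings import (
    TiledTokenPositionalEmbedding,
    TilePositionalEmbedding,
)

# Token-level positional embeddings for tiled images. Per the tests, each tile
# carries (tile_size / patch_size) ** 2 + 1 tokens (one extra leading token).
token_pos_emb = TiledTokenPositionalEmbedding(
    max_num_tiles=4,
    embed_dim=1280,
    tile_size=448,
    patch_size=14,
)

# Tile-level positional embedding: one embedding per (tile_x, tile_y) slot.
tile_pos_emb = TilePositionalEmbedding(max_num_tiles=4, embed_dim=1280)

The _load_state_dict_hook exercised in the new tests resizes the embeddings stored in a checkpoint so they match whatever max_num_tiles and token-grid sizes the module was instantiated with.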
Lines changed: 349 additions & 0 deletions
@@ -0,0 +1,349 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math

import pytest
import torch

from tests.test_utils import assert_expected

from torchtune.models.clip._position_embeddings import (
    TiledTokenPositionalEmbedding,
    TilePositionalEmbedding,
)

# generated comparing vs fairinternal/internal-llama-models
tile_pos_emb_test_cases = [
    {
        "tgt_max_num_tiles": 1,
        "input_tensor": torch.tensor(
            [[[[0.0, 1.0]], [[2.0, 3.0]]], [[[4.0, 5.0]], [[6.0, 7.0]]]]
        ),
        "expected_output": torch.tensor([[[[0.0, 1.0]]]]),
    },
    {
        "tgt_max_num_tiles": 3,
        "input_tensor": torch.tensor([[[[0.0]]]]),
        "expected_output": torch.tensor(
            [
                [[[0.0]], [[0.0]], [[0.0]]],
                [[[0.0]], [[0.0]], [[0.0]]],
                [[[0.0]], [[0.0]], [[0.0]]],
            ]
        ),
    },
    {
        "tgt_max_num_tiles": 2,
        "input_tensor": torch.tensor(
            [
                [[[0.0, 1.0]], [[2.0, 3.0]], [[4.0, 5.0]]],
                [[[6.0, 7.0]], [[8.0, 9.0]], [[10.0, 11.0]]],
                [[[12.0, 13.0]], [[14.0, 15.0]], [[16.0, 17.0]]],
            ]
        ),
        "expected_output": torch.tensor(
            [[[[0.0, 1.0]], [[4.0, 5.0]]], [[[12.0, 13.0]], [[16.0, 17.0]]]]
        ),
    },
]

local_pos_emb_test_cases = [
    {
        "tgt_patch_grid_size": 2,
        "expected_shape": torch.Size([5, 2]),
        "input_tensor": torch.tensor(
            [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]]
        ),
        "expected_output": torch.tensor(
            [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]]
        ),
    },
    {
        "tgt_patch_grid_size": 1,
        "expected_shape": torch.Size([2, 1]),
        "input_tensor": torch.tensor([[0.0], [1.0], [2.0], [3.0], [4.0]]),
        "expected_output": torch.tensor([[0.0], [1.0]]),
    },
    {
        "tgt_patch_grid_size": 2,
        "expected_shape": torch.Size([5, 2]),
        "input_tensor": torch.tensor([[0.0, 1.0], [2.0, 3.0]]),
        "expected_output": torch.tensor(
            [[0.0, 1.0], [2.0, 3.0], [2.0, 3.0], [2.0, 3.0], [2.0, 3.0]]
        ),
    },
]

global_pos_emb_test_cases = [
    {
        "tgt_max_num_tiles": 1,
        "tgt_patch_grid_size": 2,
        "input_tensor": torch.tensor(
            [
                [
                    [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]],
                    [
                        [10.0, 11.0],
                        [12.0, 13.0],
                        [14.0, 15.0],
                        [16.0, 17.0],
                        [18.0, 19.0],
                    ],
                ],
                [
                    [
                        [20.0, 21.0],
                        [22.0, 23.0],
                        [24.0, 25.0],
                        [26.0, 27.0],
                        [28.0, 29.0],
                    ],
                    [
                        [30.0, 31.0],
                        [32.0, 33.0],
                        [34.0, 35.0],
                        [36.0, 37.0],
                        [38.0, 39.0],
                    ],
                ],
            ]
        ),
        "expected_output": torch.tensor(
            [[[[0.0, 1.0], [2.0, 3.0], [14.0, 15.0], [26.0, 27.0], [38.0, 39.0]]]]
        ),
    },
    {
        "tgt_max_num_tiles": 3,
        "tgt_patch_grid_size": 1,
        "input_tensor": torch.tensor([[[[0.0], [1.0], [2.0], [3.0], [4.0]]]]),
        "expected_output": torch.tensor(
            [
                [[[0.0000], [1.0000]], [[0.0000], [1.5000]], [[0.0000], [2.0000]]],
                [[[0.0000], [2.0000]], [[0.0000], [2.5000]], [[0.0000], [3.0000]]],
                [[[0.0000], [3.0000]], [[0.0000], [3.5000]], [[0.0000], [4.0000]]],
            ]
        ),
    },
    {
        "tgt_max_num_tiles": 2,
        "tgt_patch_grid_size": 2,
        "input_tensor": torch.tensor(
            [
                [
                    [[0.0, 1.0], [2.0, 3.0]],
                    [[4.0, 5.0], [6.0, 7.0]],
                    [[8.0, 9.0], [10.0, 11.0]],
                ],
                [
                    [[12.0, 13.0], [14.0, 15.0]],
                    [[16.0, 17.0], [18.0, 19.0]],
                    [[20.0, 21.0], [22.0, 23.0]],
                ],
                [
                    [[24.0, 25.0], [26.0, 27.0]],
                    [[28.0, 29.0], [30.0, 31.0]],
                    [[32.0, 33.0], [34.0, 35.0]],
                ],
            ]
        ),
        "expected_output": torch.tensor(
            [
                [
                    [
                        [0.0000, 1.0000],
                        [2.0000, 3.0000],
                        [4.6667, 5.6667],
                        [10.0000, 11.0000],
                        [12.6667, 13.6667],
                    ],
                    [
                        [8.0000, 9.0000],
                        [7.3333, 8.3333],
                        [10.0000, 11.0000],
                        [15.3333, 16.3333],
                        [18.0000, 19.0000],
                    ],
                ],
                [
                    [
                        [24.0000, 25.0000],
                        [18.0000, 19.0000],
                        [20.6667, 21.6667],
                        [26.0000, 27.0000],
                        [28.6667, 29.6667],
                    ],
                    [
                        [32.0000, 33.0000],
                        [23.3333, 24.3333],
                        [26.0000, 27.0000],
                        [31.3333, 32.3333],
                        [34.0000, 35.0000],
                    ],
                ],
            ]
        ),
    },
]


class TestPositionalEmbeddingsInterpolation:
    @pytest.mark.parametrize("params", tile_pos_emb_test_cases)
    def test_tile_resize_position_embedding(self, params):
        tgt_max_num_tiles = params["tgt_max_num_tiles"]
        expected_output = params["expected_output"]
        embedding = params["input_tensor"]

        resized_pos_embed = TilePositionalEmbedding._resize_position_embedding(
            embedding, tgt_max_num_tiles
        )

        assert_expected(resized_pos_embed, expected_output, atol=1e-3, rtol=1e-4)

    @pytest.mark.parametrize("params", local_pos_emb_test_cases)
    def test_resize_local_position_embedding(self, params):
        input_tensor = params["input_tensor"]
        tgt_patch_grid_size = params["tgt_patch_grid_size"]
        expected_output = params["expected_output"]

        resized_pos_embed = (
            TiledTokenPositionalEmbedding._resize_local_position_embedding(
                input_tensor, tgt_patch_grid_size
            )
        )

        assert_expected(resized_pos_embed, expected_output, atol=1e-3, rtol=1e-4)

    @pytest.mark.parametrize("params", global_pos_emb_test_cases)
    def test_resize_global_position_embedding(self, params):
        input_tensor = params["input_tensor"]
        tgt_max_num_tiles = params["tgt_max_num_tiles"]
        tgt_patch_grid_size = params["tgt_patch_grid_size"]
        expected_output = params["expected_output"]

        resized_pos_embed = (
            TiledTokenPositionalEmbedding._resize_global_position_embedding(
                input_tensor, tgt_max_num_tiles, tgt_patch_grid_size
            )
        )

        assert_expected(resized_pos_embed, expected_output, atol=1e-3, rtol=1e-4)

    @pytest.mark.parametrize(
        "local_params, global_params",
        zip(local_pos_emb_test_cases, global_pos_emb_test_cases),
    )
    def test_load_state_dict_hook_tiled_token(self, local_params, global_params):
        # Corrected parameters for instantiation
        global_max_num_tiles = global_params["expected_output"].shape[0]
        global_embed_dim = global_params["expected_output"].shape[-1]
        n_tokens_per_tile = local_params["expected_output"].shape[
            0
        ]  # Assuming first dimension is tokens per tile
        patch_grid_size = int(math.sqrt(n_tokens_per_tile - 1))
        tile_size = patch_grid_size * 1  # Assuming patch_size is 1 for simplicity
        patch_size = 1

        # Instantiate the model
        model = TiledTokenPositionalEmbedding(
            max_num_tiles=global_max_num_tiles,
            embed_dim=global_embed_dim,
            tile_size=tile_size,
            patch_size=patch_size,
        )

        # Create state_dict mimicking loading scenario
        state_dict = {
            "model.local_token_positional_embedding": local_params["input_tensor"],
            "model.global_token_positional_embedding": global_params["input_tensor"],
        }

        # Call the hook directly (simulating loading process)
        state_dict_copy = state_dict.copy()
        model._load_state_dict_hook(state_dict_copy, "model.")

        # Assert expected outputs
        assert_expected(
            state_dict_copy["model.local_token_positional_embedding"],
            local_params["expected_output"],
            atol=1e-3,
            rtol=1e-4,
        )
        assert_expected(
            state_dict_copy["model.global_token_positional_embedding"],
            global_params["expected_output"],
            atol=1e-3,
            rtol=1e-4,
        )

        # Check for errors with non-square token grid sizes
        with pytest.raises(ValueError):
            bad_state_dict = state_dict.copy()

            # add +1 to num_token dimension to make it non-square
            local_pos_emb = bad_state_dict["model.local_token_positional_embedding"]
            bad_local_pos_emb = torch.cat(
                (local_pos_emb, local_pos_emb[0].unsqueeze(0)), dim=0
            )
            bad_state_dict["model.local_token_positional_embedding"] = bad_local_pos_emb

            # call
            model._load_state_dict_hook(bad_state_dict, "model.")

        # Check for errors with non-square tile grid sizes
        with pytest.raises(ValueError):
            bad_state_dict = state_dict.copy()

            # add +1 to num_token dimension to make it non-square
            global_pos_emb = bad_state_dict["model.global_token_positional_embedding"]
            bad_global_pos_emb = torch.cat(
                (global_pos_emb, global_pos_emb[:, :, [0]]), dim=2
            )
            bad_state_dict[
                "model.global_token_positional_embedding"
            ] = bad_global_pos_emb

            # call
            model._load_state_dict_hook(bad_state_dict, "model.")

    @pytest.mark.parametrize("params", tile_pos_emb_test_cases)
    def test_load_state_dict_hook_tile(self, params):

        # Extract parameters for instantiation
        max_num_tiles = params["expected_output"].shape[0]
        embed_dim = params["expected_output"].shape[-1]

        # Instantiate the model
        model = TilePositionalEmbedding(
            max_num_tiles=max_num_tiles,
            embed_dim=embed_dim,
        )
        # Create state_dict mimicking loading scenario
        state_dict = {
            "model.embedding": params["input_tensor"],
        }

        # Call the hook
        state_dict_copy = state_dict.copy()
        model._load_state_dict_hook(state_dict_copy, "model.")

        # Assert expected outputs
        assert_expected(
            state_dict_copy["model.embedding"],
            params["expected_output"],
            atol=1e-3,
            rtol=1e-4,
        )

        # Check for errors with non-square tile grid sizes
        with pytest.raises(ValueError):
            bad_state_dict = state_dict.copy()
            # Manipulate the tensor to have non-equal max_num_tiles_x and max_num_tiles_y
            bad_tensor = torch.cat(
                (params["input_tensor"], params["input_tensor"][:, [0], :, :]), dim=1
            )
            bad_state_dict["model.embedding"] = bad_tensor
            model._load_state_dict_hook(bad_state_dict, "model.")
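
The resize helpers exercised above (_resize_position_embedding, _resize_local_position_embedding, _resize_global_position_embedding) live in torchtune.models.clip._position_embeddings, whose diff is not part of this excerpt. As a rough mental model only, and consistent with the expected outputs in the local test cases, the per-tile resize behaves like the usual ViT positional-embedding interpolation: split off the leading extra token, reshape the remaining tokens into a square patch grid, bilinearly interpolate to the target grid, and flatten back. The sketch below is that generic recipe with assumed details (align_corners=True, a single leading token), not torchtune's actual implementation:

# Illustrative sketch only -- a generic ViT-style positional-embedding resize,
# not the code in torchtune.models.clip._position_embeddings.
import math

import torch
import torch.nn.functional as F


def resize_local_pos_embed(pos_embed: torch.Tensor, tgt_patch_grid_size: int) -> torch.Tensor:
    # pos_embed: (n_tokens_per_tile, embed_dim), with n_tokens_per_tile = grid ** 2 + 1
    extra_token, patch_tokens = pos_embed[:1], pos_embed[1:]
    src_grid_size = int(math.sqrt(patch_tokens.shape[0]))
    embed_dim = pos_embed.shape[-1]

    # (grid * grid, dim) -> (1, dim, grid, grid) so F.interpolate sees a 2D "image"
    grid = (
        patch_tokens.reshape(src_grid_size, src_grid_size, embed_dim)
        .permute(2, 0, 1)
        .unsqueeze(0)
    )
    grid = F.interpolate(
        grid,
        size=(tgt_patch_grid_size, tgt_patch_grid_size),
        mode="bilinear",
        align_corners=True,
    )
    # back to (tgt_grid * tgt_grid, dim) and re-attach the leading token
    patch_tokens = grid.squeeze(0).permute(1, 2, 0).reshape(-1, embed_dim)
    return torch.cat([extra_token, patch_tokens], dim=0)


# Example: expand a (5, 2) embedding (2x2 patch grid + 1 extra token) to a 3x3 grid.
resized = resize_local_pos_embed(torch.randn(5, 2), tgt_patch_grid_size=3)
print(resized.shape)  # torch.Size([10, 2])

Under those assumptions, the first and third local test cases above fall out directly: same-size interpolation is the identity, and upsampling a 1x1 patch grid simply repeats the single patch embedding.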
