@@ -56,33 +56,41 @@ def test_batch_pad_sequence(self):
 
 
 class TestPaddedCollateTiledImagesAndMask:
+    img_shape = 1, 1, 1
+    tokens_per_tile = 5
+
     @pytest.fixture
     def batch(self):
+        c, h, w = self.img_shape
+        s = self.tokens_per_tile
         return [
             {
                 "tokens": [1, 2, 1, 3],
                 "labels": [4, 5, 6, 7],
                 "encoder_input": {
-                    "images": [torch.ones(2, 1, 1, 1), torch.ones(3, 1, 1, 1)],
+                    "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
                     "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
                 },
-                "encoder_mask": [torch.ones(4, 5 * 2), torch.ones(4, 5 * 3)],
+                "encoder_mask": [torch.ones(4, s * 2), torch.ones(4, s * 3)],
             },
             {
                 "tokens": [1, 4],
                 "labels": [8, 9],
                 "encoder_input": {
-                    "images": [torch.ones(4, 1, 1, 1)],
+                    "images": [torch.ones(4, c, h, w)],
                     "aspect_ratio": [torch.tensor([2, 2])],
                 },
-                "encoder_mask": [torch.ones(2, 5 * 4)],
+                "encoder_mask": [torch.ones(2, s * 4)],
             },
         ]
 
     def test_right_pad_sequence(self, batch):
         actual = padded_collate_tiled_images_and_mask(
             batch=batch, padding_idx=0, ignore_idx=-100, pad_direction="right"
         )
+        imgs, tiles = actual["encoder_input"]["images"].shape[1:3]
+        seq_len = actual["encoder_mask"].shape[-1]
+        assert imgs * tiles * self.tokens_per_tile == seq_len
 
         mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 10)], dim=1)
         mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5)], dim=1)
@@ -126,28 +134,36 @@ def test_left_pad_sequence(self, batch):
             ignore_idx=-100,
             pad_direction="left",
             pad_max_images=4,
+            pad_max_tiles=5,
         )
+        imgs, tiles = actual["encoder_input"]["images"].shape[1:3]
+        seq_len = actual["encoder_mask"].shape[-1]
+        assert 5 * 4 * self.tokens_per_tile == seq_len
 
-        mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 10)], dim=1)
-        mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5)], dim=1)
+        # pad 3 extra tiles
+        mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 5 * 3)], dim=1)
+        # pad 2 extra tiles
+        mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5 * 2)], dim=1)
+        # Left pad text tokens
         mask_3 = torch.concat([torch.zeros(2, 20), torch.ones(2, 5 * 4)], dim=0)
+        mask_3 = F.pad(mask_3, (0, 5), value=0)  # pad 5th tile
         sample_1 = torch.stack([mask_1, mask_2])
-        sample_2 = torch.stack([mask_3, torch.zeros(4, 20)])
+        sample_2 = torch.stack([mask_3, torch.zeros(4, 25)])
         expected_mask = torch.stack([sample_1, sample_2]).view(2, 4, -1)
-        expected_mask = F.pad(expected_mask, (0, 40), value=0)
+        expected_mask = F.pad(expected_mask, (0, 50), value=0)
 
         expected = {
             "tokens": torch.tensor([[1, 2, 1, 3], [0, 0, 1, 4]]),
             "encoder_input": {
                 "images": torch.tensor(
                     [
                         [
-                            [[[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]]],
-                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]]],
+                            [[[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
+                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]]],
                         ],
                         [
-                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[1.0]]]],
-                            [[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
+                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]]],
+                            [[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
                         ],
                     ]
                 ),
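For reference, a minimal sketch of the invariant the new assertions encode: after collation, the last dimension of encoder_mask equals the padded image count times the padded tile count times tokens_per_tile. The import path torchtune.data and the single-sample batch below are assumptions for illustration, not part of the diff.

# Illustrative sketch, not part of the diff.
import torch
from torchtune.data import padded_collate_tiled_images_and_mask  # assumed import path

tokens_per_tile = 5
sample = {
    "tokens": [1, 2, 1, 3],
    "labels": [4, 5, 6, 7],
    "encoder_input": {
        "images": [torch.ones(2, 1, 1, 1)],  # one image with 2 tiles of shape (1, 1, 1)
        "aspect_ratio": [torch.tensor([1, 2])],
    },
    "encoder_mask": [torch.ones(4, tokens_per_tile * 2)],  # 4 text tokens x 10 image tokens
}
out = padded_collate_tiled_images_and_mask(
    batch=[sample], padding_idx=0, ignore_idx=-100, pad_direction="right"
)
# Collated images have shape (bsz, num_images, num_tiles, c, h, w); the mask's
# last dim must cover every tile's tokens for every image.
n_imgs, n_tiles = out["encoder_input"]["images"].shape[1:3]
assert n_imgs * n_tiles * tokens_per_tile == out["encoder_mask"].shape[-1]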