@@ -77,9 +77,29 @@ def _get_expected_mask_and_input_pos(

         return mask[:max_seq_len, :max_seq_len], torch.tensor(input_pos[:max_seq_len])

+    def _calculate_num_packs(
+        self, dataset_size, max_seq_len, sample_size, split_across_pack, max_packs
+    ):
+        # First see how many samples we can fit in a single pack
+        num_samples_per_pack, remainder = divmod(max_seq_len, sample_size)
+
+        # If we split across packs (and the samples don't fit perfectly in max_seq_len), we can fit more
+        if split_across_pack and remainder > 0:
+            # Now we need the fractional count to see how many samples we can partially fit in each pack
+            num_samples_per_pack = max_seq_len / sample_size
+
+        # If we don't split across packs, we will need more packs
+        num_packs, remainder = divmod(dataset_size, num_samples_per_pack)
+
+        # If there's a leftover, we need to add one more pack
+        if remainder > 0:
+            num_packs += 1
+
+        return num_packs if num_packs < max_packs else max_packs
+
     @pytest.mark.parametrize("max_seq_len", [25])
     @pytest.mark.parametrize("sample_size", [2, 5])
-    @pytest.mark.parametrize("max_packs", [5])
+    @pytest.mark.parametrize("max_packs", [5, 200])
     @pytest.mark.parametrize("split_across_pack", [True, False])
     def test_packed_dataset(
         self, max_seq_len, sample_size, max_packs, split_across_pack
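As a sanity check on the helper above, here is a standalone restatement with a couple of worked values. The `dataset_size` of 20 is illustrative, not taken from the test, and `math.ceil` replaces the divmod-plus-remainder logic, which it is equivalent to for positive values:

```python
import math

# Standalone restatement of _calculate_num_packs, for illustration only.
def calculate_num_packs(dataset_size, max_seq_len, sample_size, split_across_pack, max_packs):
    samples_per_pack, remainder = divmod(max_seq_len, sample_size)
    if split_across_pack and remainder > 0:
        # Partial samples count fractionally when splitting is allowed
        samples_per_pack = max_seq_len / sample_size
    return min(math.ceil(dataset_size / samples_per_pack), max_packs)

# max_seq_len=25, sample_size=2, no splitting: 12 whole samples fit per
# pack, so 20 samples need ceil(20 / 12) = 2 packs.
assert calculate_num_packs(20, 25, 2, False, 200) == 2
# With splitting, 25 / 2 = 12.5 samples fit per pack: still 2 packs,
# unless max_packs caps the result first.
assert calculate_num_packs(20, 25, 2, True, 1) == 1
```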
@@ -91,8 +111,13 @@ def test_packed_dataset(
             max_packs=max_packs,
             split_across_pack=split_across_pack,
         )
+
         # Check we get right number of packs
-        assert len(packed) == max_packs
+        correct_num_packs = self._calculate_num_packs(
+            len(dataset), max_seq_len, sample_size, split_across_pack, max_packs
+        )
+        assert len(packed) == correct_num_packs
+
         # Check all fields are same length
         assert (
             len(packed[0]["tokens"])
@@ -105,15 +130,15 @@ def test_packed_dataset(
         if split_across_pack:
             # If we split samples, we'll know how many samples by taking the
             # full length and dividing by sample size
-            last_index, remainder = divmod(max_packs * max_seq_len, sample_size)
+            last_index, remainder = divmod(len(packed) * max_seq_len, sample_size)
             # Account for remaining sample that didn't fit in window
             last_index = last_index if remainder > 0 else last_index - 1
         else:
             # If we don't split samples, we know how many samples by taking
             # how much fits in a single window and multiplying by max rows.
             # If there is a remainder, this will end up being a pad token.
             last_index = (
-                (max_seq_len // sample_size) * max_packs - 1
+                (max_seq_len // sample_size) * len(packed) - 1
                 if max_seq_len % sample_size == 0
                 else 0
             )
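To make the two branches above concrete, a small worked example; the parametrized values `max_seq_len=25` and `sample_size=2` come from the test, while `num_packs=2` is assumed for illustration:

```python
max_seq_len, sample_size, num_packs = 25, 2, 2  # num_packs assumed

# Splitting branch: 2 packs of 25 positions hold 50 tokens, i.e. 25
# complete samples of size 2. divmod(50, 2) == (25, 0): the division is
# exact, so no partial sample spills over and the last index is 25 - 1.
last_index, remainder = divmod(num_packs * max_seq_len, sample_size)
last_index = last_index if remainder > 0 else last_index - 1
assert last_index == 24

# Non-splitting branch: 25 % 2 != 0, so the final position of each pack
# holds a pad token and the check falls back to index 0.
assert (max_seq_len % sample_size == 0) is False
```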
@@ -207,11 +232,11 @@ def test_packed_dataset_real_data(self):

     def test_pad_pack(self):
         padding_idx = -8
-        ignore_idx = -9
+        ignore_idx = -100  # Same as CROSS_ENTROPY_IGNORE_IDX
         pack = {
             "tokens": [2, 5],
             "labels": [3, 7],
-            "mask": torch.tensor([[True, False], [True, True]]),
+            "seq_lens": [1, 1],
             # Let the first token be the end of the previous sample (pos 8),
             # and the second token the start of the next sample (pos 0). Collate
             # should continue from 0 -> 1, 2, ...
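The `mask` to `seq_lens` swap above implies the block-causal attention mask is now derived from per-sample lengths rather than stored per pack. A minimal sketch of that derivation, assuming a hypothetical helper (this is not torchtune's API):

```python
import torch

# Hypothetical helper: rebuild a block-causal mask from per-sample lengths.
def block_causal_mask(seq_lens):
    total = sum(seq_lens)
    mask = torch.zeros(total, total, dtype=torch.bool)
    start = 0
    for n in seq_lens:
        # Each sample attends causally within its own block only
        mask[start : start + n, start : start + n] = torch.tril(
            torch.ones(n, n, dtype=torch.bool)
        )
        start += n
    return mask

# Two length-1 samples, as in the pack above: each token attends only to itself.
assert torch.equal(block_causal_mask([1, 1]), torch.eye(2, dtype=torch.bool))
```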
@@ -224,11 +249,11 @@ def test_pad_pack(self):
             max_seq_len=4,
         )

-        padded = packed._pad_pack(pack, padding_idx=padding_idx, ignore_idx=ignore_idx)
+        pack = packed._convert_to_tensors(pack)
+        padded = packed._pad_pack(pack, padding_idx=padding_idx)

         padded_input = padded["tokens"]
         padded_label = padded["labels"]
-        padded_mask = padded["mask"]
         padded_input_pos = padded["input_pos"]

         torch.testing.assert_close(
@@ -237,15 +262,12 @@ def test_pad_pack(self):
         torch.testing.assert_close(
             padded_label, torch.tensor([3, 7, ignore_idx, ignore_idx])
         )
-        assert torch.equal(
-            padded_mask,
-            torch.tensor(
-                [
-                    [True, False, False, False],
-                    [True, True, False, False],
-                    [False, False, True, False],
-                    [False, False, False, True],
-                ]
-            ),
-        )
         torch.testing.assert_close(padded_input_pos, torch.tensor([8, 0, 1, 2]))
+
+    def test_pack_errors_if_sample_too_long(self):
+        dataset = DummyDataset(8)
+        with pytest.raises(ValueError, match="Dataset sample is too long"):
+            PackedDataset(
+                dataset,
+                max_seq_len=4,
+            )
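The new test asserts that `PackedDataset` raises when a sample cannot fit in `max_seq_len`. A rough sketch of the kind of check it exercises; only the error-message prefix is taken from the test, everything else (the function name, the split-condition, the remedy text) is an assumption:

```python
import pytest

# Hypothetical sketch of the validation exercised by the new test; the
# actual torchtune implementation may differ.
def check_sample_fits(sample_len, max_seq_len, split_across_pack):
    if sample_len > max_seq_len and not split_across_pack:
        raise ValueError(
            f"Dataset sample is too long ({sample_len} > {max_seq_len}). "
            "Set split_across_pack=True or increase max_seq_len."
        )

# Mirrors the test above: an 8-token sample cannot fit in a 4-token pack.
with pytest.raises(ValueError, match="Dataset sample is too long"):
    check_sample_fits(8, 4, split_across_pack=False)
```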