@@ -80,7 +80,7 @@ def test_invalid_index_type(self, datasets):
8080 with pytest .raises (TypeError ):
8181 multi_dataset ["invalid_type" ] # Non-integer index
8282
83- def test_packed_dataset (self , torch_datasets ):
83+ def test_single_packed_dataset (self , torch_datasets ):
8484 torch_datasets [0 ] = PackedDataset (
8585 torch_datasets [0 ],
8686 max_seq_len = 25 ,
@@ -90,3 +90,33 @@ def test_packed_dataset(self, torch_datasets):
9090
9191 with pytest .raises (ValueError ):
9292 concated_dataset = ConcatDataset (torch_datasets )
93+
def test_all_packed_datasets(self, torch_datasets):
    """Check that concatenating only packed datasets yields a packed dataset.

    Every dataset supplied by the ``torch_datasets`` fixture is wrapped in a
    ``PackedDataset`` in place, the results are concatenated, and the test
    then verifies the ``packed`` flag, the first sequence length found at the
    index where each constituent dataset begins, and the overall length.
    """
    packing_options = {
        "max_seq_len": 2000,
        "max_packs": 16,
        "split_across_pack": True,
    }
    # Replace each entry in place so the fixture list itself ends up packed.
    for position, unpacked in enumerate(torch_datasets):
        torch_datasets[position] = PackedDataset(unpacked, **packing_options)

    concated_dataset = ConcatDataset(torch_datasets)
    assert concated_dataset.packed

    # Each pack holds 2k tokens. Token totals per dataset (from the fixture,
    # which is defined outside this test): 4k, 8k, 15k, 16k, 23k, 42k.
    # That gives 2, 4, 8, 8, 12 and 16 packs respectively (the last one is
    # capped by max_packs), so the datasets begin at the indices below.
    first_seq_len_at_start = (
        (0, 4),    # 1st packed ds starts at idx 0
        (2, 8),    # 2nd packed ds starts at idx 2
        (6, 15),   # 3rd packed ds starts at idx 6
        (14, 16),  # 4th packed ds starts at idx 14
        (22, 23),  # 5th packed ds starts at idx 22
        (34, 42),  # 6th packed ds starts at idx 34
    )
    for start_idx, expected_seq_len in first_seq_len_at_start:
        assert concated_dataset[start_idx]["seq_lens"][0] == expected_seq_len

    # Total length is 2 + 4 + 8 + 8 + 12 + 16 (because of max_packs) = 50
    assert len(concated_dataset) == 50
0 commit comments