@@ -778,20 +778,18 @@ def add_eos_batched(seqs):
     else:
         tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized"))
 
-        tokenized_dataset.to_tf_dataset()
-        # Split in train and valid.
-        print(type(tokenized_dataset))
-        dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
-        train_dataset = dataset_split_dict['train']
-        val_dataset = dataset_split_dict['test']
-        print(type(train_dataset))
-
         # Convert to tensorflow_datasets.Dataset objects
-        train_dataset = train_dataset.to_tf_dataset()
-        val_dataset = train_dataset.to_tf_dataset()
+        tokenized_dataset = tokenized_dataset.to_tf_dataset()
 
-        # Save datasets
-        train_dataset.Save(os.path.join(data_dir, "train"))
+        # Shuffle, then take a 90/10 train/validation split
+        dataset_size = tokenized_dataset.cardinality().numpy()
+        shuffled_dataset = tokenized_dataset.shuffle(dataset_size, seed=0)
+        train_size = int(0.9 * dataset_size)
+        train_dataset = shuffled_dataset.take(train_size)
+        val_dataset = shuffled_dataset.skip(train_size)
+
+        # Save the train and validation splits
+        train_dataset.save(os.path.join(data_dir, "train"))
         val_dataset.save(os.path.join(data_dir, "val"))
 
     return
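
For reference, the shuffle/take/skip pattern the new code relies on can be exercised in isolation. The sketch below is a toy example, not part of the PR (the range dataset and the reshuffle_each_iteration=False argument are assumptions on top of the diff). It illustrates why pinning the shuffle order matters: with TensorFlow's default reshuffle_each_iteration=True, iterating take() and skip() separately reshuffles in between, so the two splits can overlap.

import tensorflow as tf

# Toy stand-in for the tokenized dataset (the real code loads it from disk).
dataset = tf.data.Dataset.range(100)

# A full-size buffer gives a true random permutation. Pinning the order
# with reshuffle_each_iteration=False means take() and skip() below see
# the same permutation, so the two splits are guaranteed disjoint.
dataset_size = dataset.cardinality().numpy()
shuffled = dataset.shuffle(dataset_size, seed=0, reshuffle_each_iteration=False)

# 90/10 split over the fixed shuffle order.
train_size = int(0.9 * dataset_size)
train_dataset = shuffled.take(train_size)
val_dataset = shuffled.skip(train_size)

# Cardinality propagates through shuffle/take/skip for a known-size dataset.
assert train_dataset.cardinality().numpy() == train_size
assert val_dataset.cardinality().numpy() == dataset_size - train_size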