
Commit e805fa7

use tfds to shuffle and split dataset
1 parent f76dc39 commit e805fa7

File tree

1 file changed: +10 -12 lines

dataset/dataset_setup.py

Lines changed: 10 additions & 12 deletions
@@ -778,20 +778,18 @@ def add_eos_batched(seqs):
   else:
     tokenized_dataset = hf_datasets.load_from_disk(os.path.join(data_dir, "fwedu_10B_tokenized"))

-  tokenized_dataset.to_tf_dataset()
-  # Split in train and valid.
-  print(type(tokenized_dataset))
-  dataset_split_dict = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
-  train_dataset = dataset_split_dict['train']
-  val_dataset = dataset_split_dict['test']
-  print(type(train_dataset))
-
   # Convert to tensorflow_datasets.Dataset objects
-  train_dataset = train_dataset.to_tf_dataset()
-  val_dataset = train_dataset.to_tf_dataset()
+  tokenized_dataset = tokenized_dataset.to_tf_dataset()

-  # Save datasets
-  train_dataset.Save(os.path.join(data_dir, "train"))
+  # Shuffle dataset
+  dataset_size = tokenized_dataset.cardinality().numpy()
+  shuffled_dataset = tokenized_dataset.shuffle(dataset_size, seed=0)
+  train_size = int(0.9 * dataset_size)
+  train_dataset = shuffled_dataset.take(train_size)
+  val_dataset = shuffled_dataset.skip(train_size)
+
+  # Split in train and valid.
+  train_dataset.save(os.path.join(data_dir, "train"))
   val_dataset.save(os.path.join(data_dir, "val"))

   return
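For context, a minimal, self-contained sketch of the shuffle-and-split flow this commit moves to, using a small in-memory dataset as an illustrative stand-in for the tokenized fwedu_10B data (the toy dataset and the temporary data_dir are assumptions, not part of the commit). One small addition over the diff: reshuffle_each_iteration=False is passed to shuffle so the take/skip halves stay disjoint when each split is iterated separately for saving.

import os
import tempfile

import tensorflow as tf

# Hypothetical stand-in for the tokenized dataset; any tf.data.Dataset works.
tokenized_dataset = tf.data.Dataset.from_tensor_slices(tf.range(100))

# Shuffle the full dataset with a fixed seed so the split is reproducible.
# reshuffle_each_iteration=False keeps one fixed order across both save passes.
dataset_size = tokenized_dataset.cardinality().numpy()
shuffled_dataset = tokenized_dataset.shuffle(
    dataset_size, seed=0, reshuffle_each_iteration=False)

# 90/10 train/validation split via take/skip, mirroring the diff above.
train_size = int(0.9 * dataset_size)
train_dataset = shuffled_dataset.take(train_size)
val_dataset = shuffled_dataset.skip(train_size)

# Persist both splits with tf.data.Dataset.save (TF >= 2.10).
data_dir = tempfile.mkdtemp()
train_dataset.save(os.path.join(data_dir, "train"))
val_dataset.save(os.path.join(data_dir, "val"))

Reloading later is symmetric: tf.data.Dataset.load(os.path.join(data_dir, "train")) restores the saved split.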
