Open
Description
Describe the bug
Resuming an `IterableDataset` from a checkpoint (`state_dict()` / `load_state_dict()`) skips samples if `.map` is applied.
Maybe related: #7538
Steps to reproduce the bug
```python
from datasets import Dataset
from datasets.distributed import split_dataset_by_node


# Create dataset with map transformation
def create_dataset():
    ds = Dataset.from_dict({"id": list(range(100))})
    ds = ds.to_iterable_dataset(num_shards=4)
    ds = ds.map(lambda x: x)  # comment this line out to get the expected behavior
    ds = split_dataset_by_node(ds, rank=0, world_size=2)
    return ds


ds = create_dataset()

# Iterate and save checkpoint after 10 samples
it = iter(ds)
for idx, sample in enumerate(it):
    if idx == 9:  # Checkpoint after 10 samples
        checkpoint = ds.state_dict()
        print(f"Checkpoint saved at sample: {sample['id']}")
        break

# Continue with original iterator
original_next_samples = []
for idx, sample in enumerate(it):
    original_next_samples.append(sample["id"])
    if idx >= 4:
        break

# Resume from checkpoint
ds_new = create_dataset()
ds_new.load_state_dict(checkpoint)

# Get samples from resumed iterator
it_new = iter(ds_new)
resumed_next_samples = []
for idx, sample in enumerate(it_new):
    resumed_next_samples.append(sample["id"])
    if idx >= 4:
        break

print(f"\nExpected next samples: {original_next_samples}")
print(f"Actual next samples: {resumed_next_samples}")
print(
    f"\n❌ BUG: {resumed_next_samples[0] - original_next_samples[0]} samples were skipped!"
)
```
With `map`:

```
Checkpoint saved at sample: 9
Expected next samples: [10, 11, 12, 13, 14]
Actual next samples: [50, 51, 52, 53, 54]
❌ BUG: 40 samples were skipped!
```
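To see where id 50 comes from, the stdlib-only sketch below models the sample stream that rank 0 iterates over. It does not call `datasets`. It assumes that `to_iterable_dataset(num_shards=4)` splits ids 0 to 99 into four contiguous shards of 25, and that `split_dataset_by_node` assigns shard `i` to rank `i % world_size` when the shard count is divisible by the world size. The helpers `make_shards` and `rank_stream` are hypothetical names used only for illustration.

```python
# Toy model of rank 0's stream in the reproduction; does NOT use `datasets`.
# Assumptions: 4 contiguous shards of 25 ids, and round-robin shard
# assignment (shard i -> rank i % world_size).

NUM_ROWS = 100
NUM_SHARDS = 4
WORLD_SIZE = 2
RANK = 0


def make_shards(num_rows: int, num_shards: int) -> list[list[int]]:
    """Split ids 0..num_rows-1 into contiguous, equally sized shards."""
    size = num_rows // num_shards
    return [list(range(i * size, (i + 1) * size)) for i in range(num_shards)]


def rank_stream(shards: list[list[int]], rank: int, world_size: int) -> list[int]:
    """Concatenate the shards assigned round-robin to `rank` (assumption)."""
    stream: list[int] = []
    for shard in shards[rank::world_size]:
        stream.extend(shard)
    return stream


stream = rank_stream(make_shards(NUM_ROWS, NUM_SHARDS), RANK, WORLD_SIZE)

consumed = 10                     # checkpoint taken after ids 0..9
expected_next = stream[consumed]  # id a correct resume should yield first
observed_next = 50                # first id reported by the resumed iterator

# Position of the observed id inside rank 0's own stream.
observed_pos = stream.index(observed_next)
skipped = stream[consumed:observed_pos]

print(expected_next)                          # 10
print(observed_pos)                           # 25 (start of rank 0's second shard)
print(len(skipped), skipped[0], skipped[-1])  # 15 10 24
```

Under these assumptions, 50 is the first sample of rank 0's second shard, so the resumed iterator behaves as if the partially consumed first shard were already exhausted. Rank 0 then loses 15 samples (ids 10 to 24); the 40 printed by the script is the id gap 50 - 10, which also counts ids 25 to 49 that belong to rank 1.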
Expected behavior
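Without `map`, the resumed iterator continues where the original one left off (the script prints its final message unconditionally):

```
Expected next samples: [10, 11, 12, 13, 14]
Actual next samples: [10, 11, 12, 13, 14]
❌ BUG: 0 samples were skipped!
```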
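The expected contract can be illustrated with a toy model rather than the actual `datasets` implementation. In this model a checkpoint records the current shard index and the position inside that shard, so resuming after 10 samples continues at id 10. The classes `ToyShardedIterable` and `ShardState`, their field names, and the round-robin shard layout for rank 0 are all assumptions made for illustration. Forcing the saved shard index past the partially consumed shard reproduces the reported `[50, 51, 52, 53, 54]`. This is one plausible but unconfirmed explanation of what goes wrong when `.map` is applied.

```python
from dataclasses import dataclass
from itertools import islice


@dataclass
class ShardState:
    """Hypothetical checkpoint: current shard and position inside it."""
    shard_idx: int = 0
    shard_example_idx: int = 0


class ToyShardedIterable:
    """Toy resumable iterable over a list of shards (NOT the datasets API)."""

    def __init__(self, shards: list[list[int]]):
        self.shards = shards
        self.state = ShardState()

    def state_dict(self) -> dict:
        return {"shard_idx": self.state.shard_idx,
                "shard_example_idx": self.state.shard_example_idx}

    def load_state_dict(self, state: dict) -> None:
        self.state = ShardState(**state)

    def __iter__(self):
        # Capture the resume point once, before progress starts being recorded.
        start_shard = self.state.shard_idx
        start_example = self.state.shard_example_idx
        for shard_idx in range(start_shard, len(self.shards)):
            shard = self.shards[shard_idx]
            first = start_example if shard_idx == start_shard else 0
            for example_idx in range(first, len(shard)):
                # Record progress before yielding, so a checkpoint taken after
                # receiving this example resumes at the next one.
                self.state = ShardState(shard_idx, example_idx + 1)
                yield shard[example_idx]


# Rank 0's shards under the round-robin assumption: ids 0-24 and 50-74.
rank0_shards = [list(range(0, 25)), list(range(50, 75))]

ds = ToyShardedIterable(rank0_shards)
it = iter(ds)
first_ten = list(islice(it, 10))
checkpoint = ds.state_dict()
original_next = list(islice(it, 5))

resumed = ToyShardedIterable(rank0_shards)
resumed.load_state_dict(checkpoint)
resumed_next = list(islice(resumed, 5))

# Hypothetical failure mode: the saved shard index points past the
# partially consumed shard, as if shard 0 were finished.
broken = ToyShardedIterable(rank0_shards)
broken.load_state_dict({"shard_idx": 1, "shard_example_idx": 0})
broken_next = list(islice(broken, 5))

print(checkpoint)     # {'shard_idx': 0, 'shard_example_idx': 10}
print(original_next)  # [10, 11, 12, 13, 14]
print(resumed_next)   # [10, 11, 12, 13, 14]
print(broken_next)    # [50, 51, 52, 53, 54]
```

Recording progress before each `yield` is the design choice that makes a checkpoint taken right after receiving sample 9 resume at sample 10 rather than repeating or skipping it.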
Environment info
datasets == 3.6.0