Skip to content

Commit fcdc8b9

Browse files
Add tests for out of order with checkpointing (#1428)
* Add tests for out-of-order iteration with checkpointing
* Add warning logs back
* Update test cases
1 parent cad6dbe commit fcdc8b9

File tree

1 file changed

+131
-0
lines changed

1 file changed

+131
-0
lines changed

test/stateful_dataloader/test_state_dict.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import itertools
88
import json
9+
import math
10+
import time
911
import unittest
1012
from copy import deepcopy
1113

@@ -1632,5 +1634,134 @@ def test_mp(self):
16321634
self._run_test(2, CountIterCallsIter(100))
16331635

16341636

1637+
class _TestSlowIndexDataset(torch.utils.data.Dataset):
1638+
def __init__(self, end: int, slow_index: int):
1639+
self.end = end
1640+
self.slow_index = slow_index
1641+
self._worker_id = None
1642+
1643+
def __getitem__(self, idx):
1644+
if idx == self.slow_index:
1645+
time.sleep(1.0)
1646+
return idx
1647+
1648+
def __len__(self):
1649+
return self.end
1650+
1651+
1652+
class _TestSlowIterableDataset(torch.utils.data.IterableDataset):
1653+
def __init__(self, start: int, end: int):
1654+
self.start = start
1655+
self.end = end
1656+
self.mid = math.ceil((self.end - self.start) / 2)
1657+
1658+
def give_data(self, iter_start, iter_end):
1659+
for i in range(iter_start, iter_end):
1660+
if i == self.mid:
1661+
time.sleep(1.0)
1662+
yield i
1663+
1664+
def __iter__(self):
1665+
worker_info = torch.utils.data.get_worker_info()
1666+
per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
1667+
worker_id = worker_info.id
1668+
iter_start = self.start + worker_id * per_worker
1669+
iter_end = min(iter_start + per_worker, self.end)
1670+
return self.give_data(iter_start, iter_end)
1671+
1672+
1673+
class TestOutOfOrderWithCheckpointing(TestCase):
    """Checkpoint/restore behavior of StatefulDataLoader with ``in_order=False``.

    Each test iterates part of an epoch, snapshots via ``state_dict()``,
    restores into a fresh loader via ``load_state_dict()``, and verifies that
    every sample is produced exactly once across the pause/resume boundary
    (possibly out of the natural order).
    """

    def test_out_of_order_index_ds(self):
        # Map-style dataset: index 0 takes ~1s, so worker 0 stalls while
        # worker 1 keeps producing from its shard.
        dataset = _TestSlowIndexDataset(end=10, slow_index=0)
        dataloader = StatefulDataLoader(
            dataset,
            num_workers=2,
            in_order=False,
        )

        # worker_id = 0 gets 'stuck' on 0 and also has 2 in its queue
        # due to prefetch_factor being 2
        output = []
        for i, data in enumerate(dataloader):
            output.append(data)
            if i == 3:
                # Snapshot mid-epoch, before the slow sample has arrived.
                state_dict = dataloader.state_dict()
                break

        # 0 is the slow index, assert it isn't in the output before the pause
        self.assertNotIn(0, output)

        # Resume in a fresh loader; the remaining samples (including the
        # slow index 0) must still be delivered after the restore.
        new_dataloader = StatefulDataLoader(dataset, num_workers=2, in_order=False)
        new_dataloader.load_state_dict(state_dict)
        for i, data in enumerate(new_dataloader):
            output.append(data)

        # All 10 samples seen exactly once, but not in sequential order.
        self.assertEqual(len(output), 10)
        self.assertNotEqual(output, list(range(10)))
        self.assertEqual(sorted(output), list(range(10)))

    def test_out_of_order_iterable_ds_one_completed_worker(self):
        # Iterable dataset sharded across 2 workers; worker 1's shard contains
        # the slow midpoint value, so worker 0 finishes its shard first.
        dataset = _TestSlowIterableDataset(start=0, end=10)
        dataloader = StatefulDataLoader(
            dataset,
            num_workers=2,
            prefetch_factor=2,
            in_order=False,
        )

        # break later on, as one of the workers will be finished
        output = []
        for i, data in enumerate(dataloader):
            output.append(data)
            if i == 7:
                state_dict = dataloader.state_dict()
                break

        # The snapshot records per-worker fetcher completion: worker 0 has
        # exhausted its shard, worker 1 is still stuck on the slow value.
        worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["fetcher_ended"]
        worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"]["fetcher_state"]["fetcher_ended"]
        self.assertTrue(worker_0_ended)
        self.assertFalse(worker_1_ended)

        # NOTE(review): the restored loader passes batch_size=1 while the
        # original used the default — presumably equivalent here; confirm.
        new_dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, in_order=False)
        new_dataloader.load_state_dict(state_dict)
        for i, data in enumerate(new_dataloader):
            output.append(data)

        # Every sample delivered once; with this timing the combined output
        # happens to be sequential, and must differ from a strict
        # round-robin interleaving of the two shards.
        self.assertEqual(len(output), 10)
        self.assertEqual(output, list(range(10)))
        self.assertNotEqual(output, [0, 5, 1, 6, 2, 7, 3, 8, 4, 9])

    def test_out_of_order_iterable_ds_no_completed_workers(self):
        # Same sharded iterable dataset, but pause early enough that neither
        # worker has exhausted its shard — both must resume after restore.
        dataset = _TestSlowIterableDataset(start=0, end=10)
        dataloader = StatefulDataLoader(
            dataset,
            num_workers=2,
            prefetch_factor=2,
            in_order=False,
        )

        # break early - both workers will resume
        output = []
        for i, data in enumerate(dataloader):
            output.append(data)
            if i == 3:
                state_dict = dataloader.state_dict()
                break

        # Neither worker's fetcher should be marked finished in the snapshot.
        worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["fetcher_ended"]
        worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"]["fetcher_state"]["fetcher_ended"]
        self.assertFalse(worker_0_ended)
        self.assertFalse(worker_1_ended)

        # NOTE(review): batch_size=1 here vs. default in the first loader —
        # presumably equivalent; confirm against StatefulDataLoader defaults.
        new_dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, in_order=False)
        new_dataloader.load_state_dict(state_dict)
        for i, data in enumerate(new_dataloader):
            output.append(data)

        # Every sample delivered once, sequential overall, and not a strict
        # round-robin interleaving of the two shards.
        self.assertEqual(len(output), 10)
        self.assertEqual(output, list(range(10)))
        self.assertNotEqual(output, [0, 5, 1, 6, 2, 7, 3, 8, 4, 9])
16351766
# Allow running this test file directly: `python test_state_dict.py`.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)