|
6 | 6 |
|
7 | 7 | import itertools
|
8 | 8 | import json
|
| 9 | +import math |
| 10 | +import time |
9 | 11 | import unittest
|
10 | 12 | from copy import deepcopy
|
11 | 13 |
|
@@ -1632,5 +1634,134 @@ def test_mp(self):
|
1632 | 1634 | self._run_test(2, CountIterCallsIter(100))
|
1633 | 1635 |
|
1634 | 1636 |
|
| 1637 | +class _TestSlowIndexDataset(torch.utils.data.Dataset): |
| 1638 | + def __init__(self, end: int, slow_index: int): |
| 1639 | + self.end = end |
| 1640 | + self.slow_index = slow_index |
| 1641 | + self._worker_id = None |
| 1642 | + |
| 1643 | + def __getitem__(self, idx): |
| 1644 | + if idx == self.slow_index: |
| 1645 | + time.sleep(1.0) |
| 1646 | + return idx |
| 1647 | + |
| 1648 | + def __len__(self): |
| 1649 | + return self.end |
| 1650 | + |
| 1651 | + |
| 1652 | +class _TestSlowIterableDataset(torch.utils.data.IterableDataset): |
| 1653 | + def __init__(self, start: int, end: int): |
| 1654 | + self.start = start |
| 1655 | + self.end = end |
| 1656 | + self.mid = math.ceil((self.end - self.start) / 2) |
| 1657 | + |
| 1658 | + def give_data(self, iter_start, iter_end): |
| 1659 | + for i in range(iter_start, iter_end): |
| 1660 | + if i == self.mid: |
| 1661 | + time.sleep(1.0) |
| 1662 | + yield i |
| 1663 | + |
| 1664 | + def __iter__(self): |
| 1665 | + worker_info = torch.utils.data.get_worker_info() |
| 1666 | + per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) |
| 1667 | + worker_id = worker_info.id |
| 1668 | + iter_start = self.start + worker_id * per_worker |
| 1669 | + iter_end = min(iter_start + per_worker, self.end) |
| 1670 | + return self.give_data(iter_start, iter_end) |
| 1671 | + |
| 1672 | + |
class TestOutOfOrderWithCheckpointing(TestCase):
    """Checkpointing a StatefulDataLoader running with ``in_order=False`` and
    resuming from the saved state mid-epoch.

    Each test pauses iteration partway through, snapshots via ``state_dict()``,
    then restores into a freshly constructed loader and drains the remainder.
    """

    def test_out_of_order_index_ds(self):
        """Map-style dataset: the slow index is absent before the pause and is
        still delivered exactly once after restoring."""
        dataset = _TestSlowIndexDataset(end=10, slow_index=0)
        dataloader = StatefulDataLoader(
            dataset,
            num_workers=2,
            # Explicit (it is also the default) because the comment below and
            # the sibling tests rely on this value.
            prefetch_factor=2,
            in_order=False,
        )

        # worker_id = 0 gets 'stuck' on 0 and also has 2 in its queue
        # due to prefetch_factor being 2
        output = []
        for i, data in enumerate(dataloader):
            output.append(data)
            if i == 3:
                state_dict = dataloader.state_dict()
                break

        # 0 is the slow index, assert it isn't in the output before the pause
        self.assertNotIn(0, output)

        new_dataloader = StatefulDataLoader(dataset, num_workers=2, in_order=False)
        new_dataloader.load_state_dict(state_dict)
        for data in new_dataloader:
            output.append(data)

        # Every item is seen exactly once, but not in sequential order.
        self.assertEqual(len(output), 10)
        self.assertNotEqual(output, list(range(10)))
        self.assertEqual(sorted(output), list(range(10)))

    def test_out_of_order_iterable_ds_one_completed_worker(self):
        """Iterable dataset: snapshot taken after one worker has exhausted its
        shard; only the unfinished worker resumes."""
        dataset = _TestSlowIterableDataset(start=0, end=10)
        dataloader = StatefulDataLoader(
            dataset,
            num_workers=2,
            prefetch_factor=2,
            in_order=False,
        )

        # break later on, as one of the workers will be finished
        output = []
        for i, data in enumerate(dataloader):
            output.append(data)
            if i == 7:
                state_dict = dataloader.state_dict()
                break

        worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["fetcher_ended"]
        worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"]["fetcher_state"]["fetcher_ended"]
        self.assertTrue(worker_0_ended)
        self.assertFalse(worker_1_ended)

        # Restore with the same construction arguments as the original loader
        # (batch_size was left at its default there as well).
        new_dataloader = StatefulDataLoader(dataset, num_workers=2, in_order=False)
        new_dataloader.load_state_dict(state_dict)
        for data in new_dataloader:
            output.append(data)

        self.assertEqual(len(output), 10)
        self.assertEqual(output, list(range(10)))
        # Guard against the fully interleaved two-worker ordering.
        self.assertNotEqual(output, [0, 5, 1, 6, 2, 7, 3, 8, 4, 9])

    def test_out_of_order_iterable_ds_no_completed_workers(self):
        """Iterable dataset: snapshot taken while both workers still have data;
        both resume after restoring."""
        dataset = _TestSlowIterableDataset(start=0, end=10)
        dataloader = StatefulDataLoader(
            dataset,
            num_workers=2,
            prefetch_factor=2,
            in_order=False,
        )

        # break early - both workers will resume
        output = []
        for i, data in enumerate(dataloader):
            output.append(data)
            if i == 3:
                state_dict = dataloader.state_dict()
                break

        worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["fetcher_ended"]
        worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"]["fetcher_state"]["fetcher_ended"]
        self.assertFalse(worker_0_ended)
        self.assertFalse(worker_1_ended)

        # Restore with the same construction arguments as the original loader.
        new_dataloader = StatefulDataLoader(dataset, num_workers=2, in_order=False)
        new_dataloader.load_state_dict(state_dict)
        for data in new_dataloader:
            output.append(data)

        self.assertEqual(len(output), 10)
        self.assertEqual(output, list(range(10)))
        # Guard against the fully interleaved two-worker ordering.
        self.assertNotEqual(output, [0, 5, 1, 6, 2, 7, 3, 8, 4, 9])
| 1765 | + |
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
|
0 commit comments