diff --git a/test/datasets/common.py b/test/datasets/common.py index 79cc9b89c0..10071bd73b 100644 --- a/test/datasets/common.py +++ b/test/datasets/common.py @@ -3,6 +3,7 @@ from parameterized import parameterized from torch.utils.data.graph import traverse from torch.utils.data.graph_settings import get_all_graph_pipes +from torchdata.dataloader2.linter import _check_shuffle_before_sharding from torchdata.datapipes.iter import Shuffler, ShardingFilter from torchtext.datasets import DATASETS @@ -34,6 +35,8 @@ def test_shuffle_shard_wrapper(self, dataset_fn): dp = [dp] for dp_split in dp: + _check_shuffle_before_sharding(dp_split) + dp_graph = get_all_graph_pipes(traverse(dp_split)) for annotation_dp_type in [Shuffler, ShardingFilter]: if not any(isinstance(dp, annotation_dp_type) for dp in dp_graph):