diff --git a/test/test_datapipe.py b/test/test_datapipe.py index 32512fedb..580845245 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -307,9 +307,24 @@ def test_header_iterdatapipe(self) -> None: # __len__ Test: returns the limit when it is less than the length of source self.assertEqual(5, len(header_dp)) - # TODO(123): __len__ Test: returns the length of source when it is less than the limit - # header_dp = source_dp.header(30) - # self.assertEqual(20, len(header_dp)) + # __len__ Test: returns the length of source when it is less than the limit + header_dp = source_dp.header(30) + self.assertEqual(20, len(header_dp)) + + # __len__ Test: returns limit if source doesn't have length + source_dp_NoLen = IDP_NoLen(list(range(20))) + header_dp = source_dp_NoLen.header(30) + with warnings.catch_warnings(record=True) as wa: + self.assertEqual(30, len(header_dp)) + self.assertEqual(len(wa), 1) + self.assertRegex( + str(wa[0].message), r"length of this HeaderIterDataPipe is inferred to be equal to its limit" + ) + + # __len__ Test: returns the actual length (not the limit) once it has been iterated through, even if source doesn't have length + for _ in header_dp: + pass + self.assertEqual(20, len(header_dp)) def test_enumerator_iterdatapipe(self) -> None: letters = "abcde" diff --git a/torchdata/datapipes/iter/util/header.py b/torchdata/datapipes/iter/util/header.py index affc0f6a5..0262fd06c 100644 --- a/torchdata/datapipes/iter/util/header.py +++ b/torchdata/datapipes/iter/util/header.py @@ -1,5 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. 
from typing import Iterator, TypeVar
+from warnings import warn
 
 from torchdata.datapipes import functional_datapipe
 from torchdata.datapipes.iter import IterDataPipe
@@ -20,14 +21,38 @@ class HeaderIterDataPipe(IterDataPipe[T_co]):
     def __init__(self, source_datapipe: IterDataPipe[T_co], limit: int = 10) -> None:
         self.source_datapipe: IterDataPipe[T_co] = source_datapipe
         self.limit: int = limit
+        # Cached length; -1 means it is not known yet.
+        self.length: int = -1
 
     def __iter__(self) -> Iterator[T_co]:
-        for i, value in enumerate(self.source_datapipe):
-            if i < self.limit:
-                yield value
-            else:
-                break
+        """Yield at most ``limit`` items from the source, then cache the true length."""
+        count: int = 0
+        if self.limit > 0:
+            for value in self.source_datapipe:
+                count += 1
+                yield value
+                # Stop at the limit without pulling an extra item from the source.
+                if count >= self.limit:
+                    break
+        # Reached only when iteration ran to completion, so the length is certain.
+        self.length = count
 
-    # TODO(134): Fix the case that the length of source_datapipe is shorter than limit
     def __len__(self) -> int:
-        return self.limit
+        """Return the number of items this DataPipe yields.
+
+        The result is ``min(len(source), limit)`` when known. If the source has no
+        ``__len__`` and this DataPipe has not been fully iterated yet, ``limit`` is
+        returned as an upper bound and a warning is issued.
+        """
+        if self.length != -1:
+            return self.length
+        try:
+            source_len = len(self.source_datapipe)
+        except TypeError:
+            warn(
+                "The length of this HeaderIterDataPipe is inferred to be equal to its limit. "
+                "The actual value may be smaller if the actual length of source_datapipe is smaller than the limit."
+            )
+            return self.limit
+        self.length = min(source_len, self.limit)
+        return self.length