From b47cae04988226f3a1f2d0136ee52f78ac8c3299 Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Fri, 14 Jan 2022 15:26:03 -0500 Subject: [PATCH 1/3] Fixing HeaderIterDP's __len__ function [ghstack-poisoned] --- test/test_datapipe.py | 21 ++++++++++++++++++--- torchdata/datapipes/iter/util/header.py | 17 +++++++++++++++-- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/test/test_datapipe.py b/test/test_datapipe.py index 32512fedb..580845245 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -307,9 +307,24 @@ def test_header_iterdatapipe(self) -> None: # __len__ Test: returns the limit when it is less than the length of source self.assertEqual(5, len(header_dp)) - # TODO(123): __len__ Test: returns the length of source when it is less than the limit - # header_dp = source_dp.header(30) - # self.assertEqual(20, len(header_dp)) + # __len__ Test: returns the length of source when it is less than the limit + header_dp = source_dp.header(30) + self.assertEqual(20, len(header_dp)) + + # __len__ Test: returns limit if source doesn't have length + source_dp_NoLen = IDP_NoLen(list(range(20))) + header_dp = source_dp_NoLen.header(30) + with warnings.catch_warnings(record=True) as wa: + self.assertEqual(30, len(header_dp)) + self.assertEqual(len(wa), 1) + self.assertRegex( + str(wa[0].message), r"length of this HeaderIterDataPipe is inferred to be equal to its limit" + ) + + # __len__ Test: returns limit if source doesn't have length, but it has been iterated through once + for _ in header_dp: + pass + self.assertEqual(20, len(header_dp)) def test_enumerator_iterdatapipe(self) -> None: letters = "abcde" diff --git a/torchdata/datapipes/iter/util/header.py b/torchdata/datapipes/iter/util/header.py index affc0f6a5..6c0cabc35 100644 --- a/torchdata/datapipes/iter/util/header.py +++ b/torchdata/datapipes/iter/util/header.py @@ -1,5 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. 
from typing import Iterator, TypeVar +from warnings import warn from torchdata.datapipes import functional_datapipe from torchdata.datapipes.iter import IterDataPipe @@ -20,6 +21,7 @@ class HeaderIterDataPipe(IterDataPipe[T_co]): def __init__(self, source_datapipe: IterDataPipe[T_co], limit: int = 10) -> None: self.source_datapipe: IterDataPipe[T_co] = source_datapipe self.limit: int = limit + self.length = -1 def __iter__(self) -> Iterator[T_co]: for i, value in enumerate(self.source_datapipe): @@ -27,7 +29,18 @@ def __iter__(self) -> Iterator[T_co]: yield value else: break + self.length = min(i + 1, self.limit) # We know length with certainty when we reach here - # TODO(134): Fix the case that the length of source_datapipe is shorter than limit def __len__(self) -> int: - return self.limit + if self.length != -1: + return self.length + try: + source_len = len(self.source_datapipe) + self.length = min(source_len, self.limit) + return self.length + except TypeError: + warn( + "The length of this HeaderIterDataPipe is inferred to be equal to its limit." + "The actual value may be smaller if the actual length of source_datapipe is smaller than the limit." + ) + return self.limit From db58c2150ba529c9a6b144dd0439ca06d1e6cb7e Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Fri, 14 Jan 2022 15:43:14 -0500 Subject: [PATCH 2/3] Update on "Fixing HeaderIterDP's __len__ function" The previous implementation simply returned the limit as the length for `HeaderIterDataPipe`. This updated implementation takes it a step further to account for the different possible scenarios. 
Fixes #123 Fixes #134 [ghstack-poisoned] --- torchdata/datapipes/iter/util/header.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchdata/datapipes/iter/util/header.py b/torchdata/datapipes/iter/util/header.py index 6c0cabc35..dd9e718d8 100644 --- a/torchdata/datapipes/iter/util/header.py +++ b/torchdata/datapipes/iter/util/header.py @@ -24,6 +24,7 @@ def __init__(self, source_datapipe: IterDataPipe[T_co], limit: int = 10) -> None self.length = -1 def __iter__(self) -> Iterator[T_co]: + i = -1 for i, value in enumerate(self.source_datapipe): if i < self.limit: yield value From bcd7c11be128ead3bd2b0aac0a8171697c84d870 Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Fri, 14 Jan 2022 16:08:35 -0500 Subject: [PATCH 3/3] Update on "Fixing HeaderIterDP's __len__ function" The previous implementation simply returned the limit as the length for `HeaderIterDataPipe`. This updated implementation takes it a step further to account for the different possible scenarios. Fixes #123 Fixes #134 Differential Revision: [D33589168](https://our.internmc.facebook.com/intern/diff/D33589168) [ghstack-poisoned] --- torchdata/datapipes/iter/util/header.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/torchdata/datapipes/iter/util/header.py b/torchdata/datapipes/iter/util/header.py index dd9e718d8..0262fd06c 100644 --- a/torchdata/datapipes/iter/util/header.py +++ b/torchdata/datapipes/iter/util/header.py @@ -21,16 +21,17 @@ class HeaderIterDataPipe(IterDataPipe[T_co]): def __init__(self, source_datapipe: IterDataPipe[T_co], limit: int = 10) -> None: self.source_datapipe: IterDataPipe[T_co] = source_datapipe self.limit: int = limit - self.length = -1 + self.length: int = -1 def __iter__(self) -> Iterator[T_co]: - i = -1 - for i, value in enumerate(self.source_datapipe): - if i < self.limit: + i: int = 0 + for value in self.source_datapipe: + i += 1 + if i <= self.limit: yield value else: break - self.length = min(i + 1, self.limit) # We know length with 
certainty when we reach here + self.length = min(i, self.limit) # We know length with certainty when we reach here def __len__(self) -> int: if self.length != -1: