Skip to content

Fixing HeaderIterDP's __len__ function #166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions test/test_datapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,24 @@ def test_header_iterdatapipe(self) -> None:
# __len__ Test: returns the limit when it is less than the length of source
self.assertEqual(5, len(header_dp))

# TODO(123): __len__ Test: returns the length of source when it is less than the limit
# header_dp = source_dp.header(30)
# self.assertEqual(20, len(header_dp))
# __len__ Test: returns the length of source when it is less than the limit
Copy link
Contributor Author

@NivekT NivekT Jan 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These four tests here show what behaviors are expected

header_dp = source_dp.header(30)
self.assertEqual(20, len(header_dp))

# __len__ Test: returns limit if source doesn't have length
source_dp_NoLen = IDP_NoLen(list(range(20)))
header_dp = source_dp_NoLen.header(30)
with warnings.catch_warnings(record=True) as wa:
self.assertEqual(30, len(header_dp))
self.assertEqual(len(wa), 1)
self.assertRegex(
str(wa[0].message), r"length of this HeaderIterDataPipe is inferred to be equal to its limit"
)

# __len__ Test: returns the actual length (not the limit) if source doesn't have length, but it has been iterated through once
for _ in header_dp:
pass
self.assertEqual(20, len(header_dp))

def test_enumerator_iterdatapipe(self) -> None:
letters = "abcde"
Expand Down
23 changes: 19 additions & 4 deletions torchdata/datapipes/iter/util/header.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from typing import Iterator, TypeVar
from warnings import warn

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
Expand All @@ -20,14 +21,28 @@ class HeaderIterDataPipe(IterDataPipe[T_co]):
def __init__(self, source_datapipe: IterDataPipe[T_co], limit: int = 10) -> None:
self.source_datapipe: IterDataPipe[T_co] = source_datapipe
self.limit: int = limit
self.length: int = -1

def __iter__(self) -> Iterator[T_co]:
for i, value in enumerate(self.source_datapipe):
if i < self.limit:
i: int = 0
for value in self.source_datapipe:
i += 1
if i <= self.limit:
yield value
else:
break
self.length = min(i, self.limit) # We know length with certainty when we reach here

# TODO(134): Fix the case that the length of source_datapipe is shorter than limit
def __len__(self) -> int:
return self.limit
if self.length != -1:
return self.length
try:
source_len = len(self.source_datapipe)
self.length = min(source_len, self.limit)
return self.length
except TypeError:
warn(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is raising a warning and returning the best guess (i.e. self.limit) the right decision here? It can be triggered by things like list(header_dp) if header_dp hasn't been fully traversed once yet, which might confuse the users. The alternative is to raise an exception and not return anything.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me.

"The length of this HeaderIterDataPipe is inferred to be equal to its limit."
"The actual value may be smaller if the actual length of source_datapipe is smaller than the limit."
)
return self.limit