
Commit 2cf1f20

NivekT authored and facebook-github-bot committed
Adding usage examples to all IterDataPipes (#249)
Summary: Pull Request resolved: #249
Test Plan: Imported from OSS
Reviewed By: ejguan
Differential Revision: D34433825
Pulled By: NivekT
fbshipit-source-id: c2bd60eb2ea957f064486064baf5978e7dcb3441
1 parent ebee4ca commit 2cf1f20

26 files changed (+375, −8)

docs/source/conf.py

Lines changed: 2 additions & 0 deletions
@@ -95,6 +95,8 @@
     "torch.utils.data.datapipes.map.grouping.T": "T",
     "torch.utils.data.datapipes.map.combining.T_co": "T_co",
     "torch.utils.data.datapipes.map.combinatorics.T_co": "T_co",
+    "torchdata.datapipes.iter.util.cycler.T_co": "T_co",
+    "torchdata.datapipes.iter.util.paragraphaggregator.T_co": "T_co",
     "typing.": "",
 }

docs/source/torchdata.datapipes.iter.rst

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ These DataPipes help opening and decompressing archive files of different format
    :toctree: generated/
    :template: datapipe.rst

-   Extractor
+   Decompressor
    RarArchiveLoader
    TarArchiveLoader
    XzFileLoader

torchdata/datapipes/iter/load/fsspec.py

Lines changed: 18 additions & 0 deletions
@@ -36,6 +36,10 @@ class FSSpecFileListerIterDataPipe(IterDataPipe[str]):
     Args:
         root: The root `fsspec` path directory to list files from
         masks: Unix style filter string or string list for filtering file name(s)
+
+    Example:
+        >>> from torchdata.datapipes.iter import FSSpecFileLister
+        >>> datapipe = FSSpecFileLister(root=dir_path)
     """

     def __init__(
@@ -82,6 +86,11 @@ class FSSpecFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
     Args:
         source_datapipe: Iterable DataPipe that provides the pathnames or URLs
         mode: An optional string that specifies the mode in which the file is opened (``"r"`` by default)
+
+    Example:
+        >>> from torchdata.datapipes.iter import FSSpecFileLister
+        >>> datapipe = FSSpecFileLister(root=dir_path)
+        >>> file_dp = datapipe.open_file_by_fsspec()
     """

     def __init__(self, source_datapipe: IterDataPipe[str], mode: str = "r") -> None:
@@ -111,6 +120,15 @@ class FSSpecSaverIterDataPipe(IterDataPipe[str]):
         source_datapipe: Iterable DataPipe with tuples of metadata and data
         mode: Mode in which the file will be opened to write the data (``"w"`` by default)
         filepath_fn: Function that takes in metadata and returns the target path of the new file
+
+    Example:
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> def filepath_fn(name: str) -> str:
+        >>>     return dir_path + name
+        >>> name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
+        >>> source_dp = IterableWrapper(sorted(name_to_data.items()))
+        >>> fsspec_saver_dp = source_dp.save_by_fsspec(filepath_fn=filepath_fn, mode="wb")
+        >>> res_file_paths = list(fsspec_saver_dp)
     """

     def __init__(
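
The three snippets above chain into one pipeline. A minimal end-to-end sketch, assuming ``fsspec`` is installed and with a temporary local directory standing in for the docstrings' placeholder ``dir_path``:

import os
import tempfile

from torchdata.datapipes.iter import FSSpecFileLister, IterableWrapper

dir_path = tempfile.mkdtemp()  # stands in for the docstrings' dir_path

# Write a few files via FSSpecSaver (functional form: .save_by_fsspec).
name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
saver_dp = IterableWrapper(sorted(name_to_data.items())).save_by_fsspec(
    filepath_fn=lambda name: os.path.join(dir_path, name), mode="wb"
)
saved_paths = list(saver_dp)  # saving happens as the pipe is iterated

# List the files back and reopen them with FSSpecFileLister + .open_file_by_fsspec.
file_dp = FSSpecFileLister(root=dir_path).open_file_by_fsspec(mode="rb")
for path, stream in file_dp:
    print(path, stream.read())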

torchdata/datapipes/iter/load/iopath.py

Lines changed: 18 additions & 0 deletions
@@ -47,6 +47,10 @@ class IoPathFileListerIterDataPipe(IterDataPipe[str]):
     Note:
         Default ``PathManager`` currently supports local file path, normal HTTP URL and OneDrive URL.
         S3 URL is supported only with ``iopath``>=0.1.9.
+
+    Example:
+        >>> from torchdata.datapipes.iter import IoPathFileLister
+        >>> datapipe = IoPathFileLister(root=S3URL)
     """

     def __init__(
@@ -93,6 +97,11 @@ class IoPathFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
     Note:
         Default ``PathManager`` currently supports local file path, normal HTTP URL and OneDrive URL.
         S3 URL is supported only with ``iopath``>=0.1.9.
+
+    Example:
+        >>> from torchdata.datapipes.iter import IoPathFileLister
+        >>> datapipe = IoPathFileLister(root=S3URL)
+        >>> file_dp = datapipe.open_file_by_iopath()
     """

     def __init__(self, source_datapipe: IterDataPipe[str], mode: str = "r", pathmgr=None) -> None:
@@ -135,6 +144,15 @@ class IoPathSaverIterDataPipe(IterDataPipe[str]):
     Note:
         Default ``PathManager`` currently supports local file path, normal HTTP URL and OneDrive URL.
         S3 URL is supported only with ``iopath``>=0.1.9.
+
+    Example:
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> def filepath_fn(name: str) -> str:
+        >>>     return S3URL + name
+        >>> name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
+        >>> source_dp = IterableWrapper(sorted(name_to_data.items()))
+        >>> iopath_saver_dp = source_dp.save_by_iopath(filepath_fn=filepath_fn, mode="wb")
+        >>> res_file_paths = list(iopath_saver_dp)
     """

     def __init__(

torchdata/datapipes/iter/load/online.py

Lines changed: 36 additions & 0 deletions
@@ -33,6 +33,18 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
     Args:
         source_datapipe: a DataPipe that contains URLs
         timeout: timeout in seconds for HTTP request
+
+    Example:
+        >>> from torchdata.datapipes.iter import IterableWrapper, HttpReader
+        >>> file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
+        >>> http_reader_dp = HttpReader(IterableWrapper([file_url]))
+        >>> reader_dp = http_reader_dp.readlines()
+        >>> it = iter(reader_dp)
+        >>> path, line = next(it)
+        >>> path
+        https://raw.githubusercontent.com/pytorch/data/main/LICENSE
+        >>> line
+        b'BSD 3-Clause License'
     """

     def __init__(self, source_datapipe: IterDataPipe[str], timeout: Optional[float] = None) -> None:
@@ -85,6 +97,18 @@ class GDriveReaderDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
     Args:
         source_datapipe: a DataPipe that contains URLs to GDrive files
         timeout: timeout in seconds for HTTP request
+
+    Example:
+        >>> from torchdata.datapipes.iter import IterableWrapper, GDriveReader
+        >>> gdrive_file_url = "https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile"
+        >>> gdrive_reader_dp = GDriveReader(IterableWrapper([gdrive_file_url]))
+        >>> reader_dp = gdrive_reader_dp.readlines()
+        >>> it = iter(reader_dp)
+        >>> path, line = next(it)
+        >>> path
+        https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile
+        >>> line
+        <First line from the GDrive File>
     """
     source_datapipe: IterDataPipe[str]

@@ -108,6 +132,18 @@ class OnlineReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
     Args:
         source_datapipe: a DataPipe that contains URLs
         timeout: timeout in seconds for HTTP request
+
+    Example:
+        >>> from torchdata.datapipes.iter import IterableWrapper, OnlineReader
+        >>> file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
+        >>> online_reader_dp = OnlineReader(IterableWrapper([file_url]))
+        >>> reader_dp = online_reader_dp.readlines()
+        >>> it = iter(reader_dp)
+        >>> path, line = next(it)
+        >>> path
+        https://raw.githubusercontent.com/pytorch/data/main/LICENSE
+        >>> line
+        b'BSD 3-Clause License'
     """
     source_datapipe: IterDataPipe[str]
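
A condensed, runnable form of the HttpReader example above; it performs a real HTTP request, so it assumes network access to raw.githubusercontent.com:

from torchdata.datapipes.iter import HttpReader, IterableWrapper

file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
# timeout is the optional per-request limit in seconds documented above
reader_dp = HttpReader(IterableWrapper([file_url]), timeout=30).readlines()
path, line = next(iter(reader_dp))
print(path)  # the URL that produced the line
print(line)  # b'BSD 3-Clause License'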

torchdata/datapipes/iter/transform/bucketbatcher.py

Lines changed: 22 additions & 3 deletions
@@ -23,14 +23,33 @@ class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
     dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``,
     or ``length % batch_size`` for the last batch if ``drop_last`` is set to ``False``.

+    The purpose of this DataPipe is to batch samples with some similarity according to the sorting function
+    being passed. In the text domain, for example, it may batch samples with a similar number of tokens
+    to minimize padding and to increase throughput.
+
     Args:
         datapipe: Iterable DataPipe being batched
         batch_size: The size of each batch
         drop_last: Option to drop the last batch if it's not full
         batch_num: Number of batches within a bucket (i.e. `bucket_size = batch_size * batch_num`)
         bucket_num: Number of buckets to consist a pool for shuffling (i.e. `pool_size = bucket_size * bucket_num`)
-        sort_key: Callable to specify the comparison key for sorting within bucket
-        in_batch_shuffle: Option to do in-batch shuffle or buffer shuffle
+        sort_key: Callable to sort a bucket (list)
+        in_batch_shuffle: If True, do in-batch shuffle; if False, buffer shuffle
+
+    Example:
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> source_dp = IterableWrapper(range(10))
+        >>> batch_dp = source_dp.bucketbatch(batch_size=3, drop_last=True)
+        >>> list(batch_dp)
+        [[5, 6, 7], [9, 0, 1], [4, 3, 2]]
+        >>> def sort_bucket(bucket):
+        >>>     return sorted(bucket)
+        >>> batch_dp = source_dp.bucketbatch(
+        >>>     batch_size=3, drop_last=True, batch_num=100,
+        >>>     bucket_num=1, in_batch_shuffle=False, sort_key=sort_bucket
+        >>> )
+        >>> list(batch_dp)
+        [[3, 4, 5], [6, 7, 8], [0, 1, 2]]
     """
     datapipe: IterDataPipe[T_co]
     batch_size: int
@@ -71,7 +90,7 @@ def __new__(
         datapipe = datapipe.batch(batch_size, drop_last=drop_last)
         # Shuffle the batched data
         if sort_key is not None:
-            # In-batch shuffle each bucket seems not that useful
+            # In-batch shuffling each bucket seems not that useful; it can be misleading since .batch is called prior.
             if in_batch_shuffle:
                 datapipe = datapipe.batch(batch_size=bucket_num, drop_last=False).map(fn=_in_batch_shuffle_fn).unbatch()
             else:
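
To make the padding use case above concrete, a sketch that buckets variable-length toy "sentences" by token count; the data and the ``len``-based sort key are illustrative assumptions, not part of this diff:

from torchdata.datapipes.iter import IterableWrapper

# Toy "sentences" of token length 1..6.
samples = [[0] * n for n in (5, 1, 4, 2, 3, 6)]

batch_dp = IterableWrapper(samples).bucketbatch(
    batch_size=2,
    drop_last=True,
    batch_num=3,   # bucket_size = batch_size * batch_num = 6
    bucket_num=1,  # pool_size = bucket_size * bucket_num = 6
    in_batch_shuffle=False,
    sort_key=lambda bucket: sorted(bucket, key=len),
)
# Each batch holds samples of similar length; batch order itself is shuffled.
for batch in batch_dp:
    print([len(s) for s in batch])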

torchdata/datapipes/iter/transform/flatmap.py

Lines changed: 9 additions & 0 deletions
@@ -24,6 +24,15 @@ class FlatMapperIterDataPipe(IterDataPipe[T_co]):
     Args:
         datapipe: Source IterDataPipe
         fn: the function to be applied to each element in the DataPipe; the output must be a Sequence
+
+    Example:
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> def fn(e):
+        >>>     return [e, e * 10]
+        >>> source_dp = IterableWrapper(list(range(5)))
+        >>> flatmapped_dp = source_dp.flatmap(fn)
+        >>> list(flatmapped_dp)
+        [0, 0, 1, 10, 2, 20, 3, 30, 4, 40]
     """

     def __init__(self, datapipe: IterDataPipe, fn: Callable):
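
Another common use of ``flatmap`` is flattening tokenized lines, since ``str.split`` returns a list and so satisfies the Sequence requirement; a small sketch:

from torchdata.datapipes.iter import IterableWrapper

lines_dp = IterableWrapper(["hello world", "flat map demo"])
tokens_dp = lines_dp.flatmap(str.split)  # each list of tokens is flattened out
print(list(tokens_dp))  # ['hello', 'world', 'flat', 'map', 'demo']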

torchdata/datapipes/iter/util/cacheholder.py

Lines changed: 21 additions & 1 deletion
@@ -32,6 +32,13 @@ class InMemoryCacheHolderIterDataPipe(IterDataPipe[T_co]):
     Args:
         source_dp: source DataPipe from which elements are read and stored in memory
         size: The maximum size (in megabytes) that this DataPipe can hold in memory. This defaults to unlimited.
+
+    Example:
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> source_dp = IterableWrapper(range(10))
+        >>> cache_dp = source_dp.in_memory_cache(size=5)
+        >>> list(cache_dp)
+        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
     """
     size: Optional[int] = None
     idx: int
@@ -125,13 +132,14 @@ class OnDiskCacheHolderIterDataPipe(IterDataPipe):
     the given file path from ``filepath_fn``.

     Example:
+        >>> from torchdata.datapipes.iter import IterableWrapper, HttpReader
         >>> url = IterableWrapper(["https://path/to/filename", ])
         >>> def _filepath_fn(url):
         >>>     temp_dir = tempfile.gettempdir()
         >>>     return os.path.join(temp_dir, os.path.basename(url))
         >>> hash_dict = {"expected_filepath": expected_MD5_hash}
         >>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=hash_dict, hash_type="md5")
-        You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
+        >>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
         >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
     """

@@ -235,6 +243,18 @@ class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
         same_filepath_fn: Set to ``True`` to use the same ``filepath_fn`` from the ``OnDiskCacheHolder``.
         skip_read: Boolean value to skip reading the file handle from ``datapipe``.
             By default, reading is enabled and the reading function is created based on the ``mode``.
+
+    Example:
+        >>> from torchdata.datapipes.iter import IterableWrapper, HttpReader
+        >>> url = IterableWrapper(["https://path/to/filename", ])
+        >>> def _filepath_fn(url):
+        >>>     temp_dir = tempfile.gettempdir()
+        >>>     return os.path.join(temp_dir, os.path.basename(url))
+        >>> hash_dict = {"expected_filepath": expected_MD5_hash}
+        >>> # You must call ``.on_disk_cache`` at some point before ``.end_caching``
+        >>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=hash_dict, hash_type="md5")
+        >>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
+        >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
     """

     def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=False, skip_read=False):
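
A self-contained variant of the on-disk-cache pattern above, dropping the hash check so it runs as-is; the license URL is an assumption standing in for the docstrings' placeholder ``https://path/to/filename``:

import os
import tempfile

from torchdata.datapipes.iter import HttpReader, IterableWrapper

def _filepath_fn(url):
    return os.path.join(tempfile.gettempdir(), os.path.basename(url))

url_dp = IterableWrapper(["https://raw.githubusercontent.com/pytorch/data/main/LICENSE"])
cache_dp = url_dp.on_disk_cache(filepath_fn=_filepath_fn)
# end_caching stops the tracing started by on_disk_cache and writes the bytes to disk.
cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
print(list(cache_dp))  # local file paths; cached files are reused on later runs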

torchdata/datapipes/iter/util/combining.py

Lines changed: 24 additions & 0 deletions
@@ -30,6 +30,18 @@ class IterKeyZipperIterDataPipe(IterDataPipe[T_co]):
             If it's specified as ``None``, the buffer size is set as infinite.
         merge_fn: Function that combines the item from ``source_datapipe`` and the item from ``ref_datapipe``,
             by default a tuple is created
+
+    Example:
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> from operator import itemgetter
+        >>> def merge_fn(t1, t2):
+        >>>     return t1[1] + t2[1]
+        >>> dp1 = IterableWrapper([('a', 100), ('b', 200), ('c', 300)])
+        >>> dp2 = IterableWrapper([('a', 1), ('b', 2), ('c', 3), ('d', 4)])
+        >>> res_dp = dp1.zip_with_iter(dp2, key_fn=itemgetter(0),
+        >>>     ref_key_fn=itemgetter(0), keep_key=True, merge_fn=merge_fn)
+        >>> list(res_dp)
+        [('a', 101), ('b', 202), ('c', 303)]
     """

     def __init__(
@@ -105,6 +117,18 @@ class MapKeyZipperIterDataPipe(IterDataPipe[T_co]):
         key_fn: Function that maps each item from ``source_iterdatapipe`` to a key that exists in ``map_datapipe``
         merge_fn: Function that combines the item from ``source_iterdatapipe`` and the matching item
             from ``map_datapipe``, by default a tuple is created
+
+    Example:
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> from torchdata.datapipes.map import SequenceWrapper
+        >>> from operator import itemgetter
+        >>> def merge_fn(tuple_from_iter, value_from_map):
+        >>>     return tuple_from_iter[0], tuple_from_iter[1] + value_from_map
+        >>> dp1 = IterableWrapper([('a', 1), ('b', 2), ('c', 3)])
+        >>> mapdp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
+        >>> res_dp = dp1.zip_with_map(map_datapipe=mapdp, key_fn=itemgetter(0), merge_fn=merge_fn)
+        >>> list(res_dp)
+        [('a', 101), ('b', 202), ('c', 303)]
     """

     def __init__(
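
The two zippers differ mainly in their reference side: ``IterKeyZipper`` scans another IterDataPipe through a buffer, while ``MapKeyZipper`` indexes a MapDataPipe directly. A sketch contrasting them on the same data:

from operator import itemgetter

from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.map import SequenceWrapper

pairs = [("a", 1), ("b", 2), ("c", 3)]

# IterKeyZipper: the reference side is iterated and buffered until keys match.
iter_zipped = IterableWrapper(pairs).zip_with_iter(
    IterableWrapper([("a", 100), ("b", 200), ("c", 300)]),
    key_fn=itemgetter(0),
    ref_key_fn=itemgetter(0),
    keep_key=False,
    merge_fn=lambda s, r: (s[0], s[1] + r[1]),
)
print(list(iter_zipped))  # [('a', 101), ('b', 202), ('c', 303)]

# MapKeyZipper: the reference side is a map-style DataPipe, looked up by key.
map_zipped = IterableWrapper(pairs).zip_with_map(
    map_datapipe=SequenceWrapper({"a": 100, "b": 200, "c": 300}),
    key_fn=itemgetter(0),
    merge_fn=lambda s, v: (s[0], s[1] + v),
)
print(list(map_zipped))  # [('a', 101), ('b', 202), ('c', 303)]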

torchdata/datapipes/iter/util/cycler.py

Lines changed: 7 additions & 0 deletions
@@ -16,6 +16,13 @@ class CyclerIterDataPipe(IterDataPipe[T_co]):
     Args:
         source_datapipe: source DataPipe that will be cycled through
         count: the number of times to read through ``source_datapipe`` (if ``None``, it will cycle in perpetuity)
+
+    Example:
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> dp = IterableWrapper(range(3))
+        >>> dp = dp.cycle(2)
+        >>> list(dp)
+        [0, 1, 2, 0, 1, 2]
     """

     def __init__(self, source_datapipe: IterDataPipe[T_co], count: Optional[int] = None) -> None:
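
One case the docstring does not show: with ``count=None`` (the default), ``cycle`` repeats forever, so the sketch below caps iteration with ``itertools.islice``:

from itertools import islice

from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(3)).cycle()  # count=None: cycles in perpetuity
print(list(islice(dp, 7)))  # [0, 1, 2, 0, 1, 2, 0]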
