Skip to content

Commit 37bd84f

Browse files
ejguan authored and facebook-github-bot committed
Fix OnDiskCacheHolder to list all files for decompressing operations (#203)
Summary: Pull Request resolved: #203 Add `FileLister` to make sure `OnDiskCacheHolder` can list all of files after any 1-to-N operations like decompression. Test Plan: Imported from OSS Reviewed By: VitalyFedyunin, NivekT Differential Revision: D34085743 Pulled By: ejguan fbshipit-source-id: 3f2461b0e77eb015ec4e8b5b5a936505380f5a76
1 parent 8e4dd06 commit 37bd84f

File tree

3 files changed

+33
-7
lines changed

3 files changed

+33
-7
lines changed

test/test_remote_io.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -137,6 +137,30 @@ def _read_and_decode(x):
137137
self.assertTrue(os.path.exists(expected_csv_path))
138138
self.assertEqual(expected_csv_path, csv_path)
139139

140+
# Cache decompressed archive but only check root directory
141+
root_dir = "temp"
142+
143+
file_cache_dp = OnDiskCacheHolder(
144+
tar_cache_dp, filepath_fn=lambda tar_path: os.path.join(os.path.dirname(tar_path), root_dir)
145+
)
146+
file_cache_dp = FileOpener(file_cache_dp, mode="rb").read_from_tar()
147+
file_cache_dp = file_cache_dp.end_caching(
148+
mode="wb",
149+
filepath_fn=lambda file_path: os.path.join(self.temp_dir.name, root_dir, os.path.basename(file_path)),
150+
)
151+
152+
cached_it = iter(file_cache_dp)
153+
for i in range(3):
154+
expected_csv_path = os.path.join(self.temp_dir.name, root_dir, f"{i}.csv")
155+
# File doesn't exist on disk
156+
self.assertFalse(os.path.exists(expected_csv_path))
157+
158+
csv_path = next(cached_it)
159+
160+
# File is cached to disk
161+
self.assertTrue(os.path.exists(expected_csv_path))
162+
self.assertEqual(expected_csv_path, csv_path)
163+
140164

141165
if __name__ == "__main__":
142166
unittest.main()

torchdata/datapipes/iter/util/cacheholder.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212

1313
from torch.utils.data.graph import traverse
1414
from torchdata.datapipes import functional_datapipe
15-
from torchdata.datapipes.iter import IterDataPipe
15+
from torchdata.datapipes.iter import FileLister, IterDataPipe
1616

1717
if DILL_AVAILABLE:
1818
import dill
@@ -135,7 +135,7 @@ def _filepath_fn(url):
135135
hash_dict = {"expected_filepaht": expected_MD5_hash}
136136
137137
cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=_hash_dict, hash_type="md5")
138-
cache_dp = HttpReader(cache_dp).end_caching(mode="wb". filepath_fn=_filepath_fn)
138+
cache_dp = HttpReader(cache_dp).end_caching(filepath_fn=_filepath_fn)
139139
"""
140140

141141
_temp_dict: Dict = {}
@@ -234,14 +234,15 @@ class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
234234
datapipe: IterDataPipe with at least one `OnDiskCacheHolder` in the graph.
235235
mode: Mode in which cached files are opened for write the data. This is needed
236236
to be aligned with the type of data or file handle from `datapipe`.
237+
``"wb"`` is used by default.
237238
filepath_fn: Optional function to extract filepath from the metadata from `datapipe`.
238239
As default, it would directly use the metadata as file path.
239240
same_filepath_fn: Set to `True` to use same `filepath_fn` from the `OnDiskCacheHolder`.
240241
skip_read: Boolean value to skip reading the file handle from `datapipe`.
241242
As default, reading is enabled and reading function is created based on the `mode`.
242243
"""
243244

244-
def __new__(cls, datapipe, mode="w", filepath_fn=None, *, same_filepath_fn=False, skip_read=False):
245+
def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=False, skip_read=False):
245246
if filepath_fn is not None and same_filepath_fn:
246247
raise ValueError("`filepath_fn` is mutually exclusive with `same_filepath_fn`")
247248

@@ -255,16 +256,17 @@ def __new__(cls, datapipe, mode="w", filepath_fn=None, *, same_filepath_fn=False
255256

256257
_filepath_fn, _hash_dict, _hash_type, _ = OnDiskCacheHolderIterDataPipe._temp_dict[cache_holder]
257258
cached_dp = cache_holder._end_caching()
259+
cached_dp = FileLister(cached_dp, recursive=True)
258260

259261
if same_filepath_fn:
260262
filepath_fn = _filepath_fn
261263

262264
todo_dp = datapipe
263265
if not skip_read:
264-
if "b" in mode:
265-
todo_dp = todo_dp.map(fn=_read_bytes, input_col=1)
266-
else:
266+
if "t" in mode:
267267
todo_dp = todo_dp.map(fn=_read_str, input_col=1)
268+
else:
269+
todo_dp = todo_dp.map(fn=_read_bytes, input_col=1)
268270

269271
if filepath_fn is not None:
270272
todo_dp = todo_dp.map(fn=filepath_fn, input_col=0)

torchdata/datapipes/iter/util/saver.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ def __init__(
2828
filepath_fn: Optional[Callable] = None,
2929
):
3030
self.source_datapipe: IterDataPipe[Tuple[Any, U]] = source_datapipe
31-
self.mode: str = mode
31+
self.mode: str = mode if "w" in mode else "w" + mode
3232
self.fn: Optional[Callable] = filepath_fn
3333

3434
def __iter__(self) -> Iterator[str]:

0 commit comments

Comments (0)