
Commit 7abb164

adding aistore datapipe (#545)
Summary: Fixes #517

### Changes

- Added `aisio.py` (Iterable DataPipe for AIStore backends)
- Added unit tests in ~~`test/test_local_io.py`~~ `test/test_aistore.py`
- Added a GitHub Action for running AIStore in the ~~`CI.yml`~~ `.github/aistore_ci.yml` workflow

### Questions to maintainers

- We are unsure about the documentation generated on PyTorch to refer to in `README.md`, so I have tentatively added URLs similar to the s3io functions (see `torchdata/datapipes/iter/load/README.md`).

Signed-off-by: Abhishek Gaikwad <[email protected]>

Pull Request resolved: #545
Reviewed By: VitalyFedyunin
Differential Revision: D37620194
Pulled By: msaroufim
fbshipit-source-id: 9df099586dd39d47f8fdf2b760b17503f8a9822d
1 parent f629899 · commit 7abb164

File tree

9 files changed: +336 −6 lines changed


.github/workflows/aistore_ci.yml

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+name: Run AIStore Datapipe Test
+on:
+  push:
+    branches:
+      - main
+      - release/*
+    tags:
+  pull_request:
+    types: [opened, synchronize, reopened, labeled]
+    branches:
+      - main
+      # For PR created by ghstack
+      - gh/*/*/base
+      - release/*
+
+jobs:
+  test:
+    if: ${{ github.repository_owner == 'pytorch' }}
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os:
+          - macos-latest
+          - ubuntu-latest
+        python-version:
+          - 3.7
+          - 3.8
+          - 3.9
+    steps:
+      - name: Get PyTorch Channel
+        shell: bash
+        run: |
+          if [[ "${{ github.base_ref }}" == release/* ]] || [[ "${{ github.ref }}" == refs/heads/release/* ]] || [[ "${{ github.ref }}" == refs/tags/v* ]]; then
+            PT_CHANNEL="https://download.pytorch.org/whl/test/cpu/torch_test.html"
+          else
+            PT_CHANNEL="https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html"
+          fi
+          echo "::set-output name=value::$PT_CHANNEL"
+        id: pytorch_channel
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Check out source repository
+        uses: actions/checkout@v2
+      - name: Install dependencies
+        run: |
+          pip3 install -r requirements.txt
+          pip3 install --pre torch -f "${{ steps.pytorch_channel.outputs.value }}"
+      - name: Run AIStore local deployment
+        uses: NVIDIA/aistore@master
+      - name: Build TorchData
+        run: |
+          python setup.py install
+      - name: Install test requirements
+        run: pip3 install -r test/requirements_aistore.txt
+      - name: Run AIStore DataPipe tests with pytest
+        run: pytest --no-header -v test/test_aistore.py
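When debugging this workflow locally, a quick reachability probe against the AIS deployment step above can save a full pytest run. A minimal sketch (not part of the commit), mirroring the availability check in `test/test_aistore.py`; the `http://localhost:8080` endpoint is the same assumption that test makes:

```python
# Sketch: confirm the local AIStore deployment is reachable before testing.
# The endpoint mirrors AIS_CLUSTER_ENDPT in test/test_aistore.py.
from aistore.client.api import Client

client = Client("http://localhost:8080")
print("AIS cluster running:", client.is_aistore_running())
```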

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
@@ -80,18 +80,18 @@ jobs:
         if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
         run:
           pytest --no-header -v test --ignore=test/test_period.py --ignore=test/test_text_examples.py
-          --ignore=test/test_audio_examples.py
+          --ignore=test/test_audio_examples.py --ignore=test/test_aistore.py
       - name: Run DataPipes tests with pytest (including slow tests)
         if: ${{ contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
         run:
           pytest --no-header -v test --ignore=test/test_period.py --ignore=test/test_text_examples.py
-          --ignore=test/test_audio_examples.py
+          --ignore=test/test_audio_examples.py --ignore=test/test_aistore.py
         env:
           PYTORCH_TEST_WITH_SLOW: 1
       - name: Run DataPipes period tests with pytest
         if: ${{ contains(github.event.pull_request.labels.*.name, 'ciflow/period') }}
         run:
           pytest --no-header -v test/test_period.py --ignore=test/test_text_examples.py
-          --ignore=test/test_audio_examples.py
+          --ignore=test/test_audio_examples.py --ignore=test/test_aistore.py
         env:
           PYTORCH_TEST_WITH_SLOW: 1

docs/source/torchdata.datapipes.iter.rst

Lines changed: 2 additions & 0 deletions
@@ -130,6 +130,8 @@ saving files, and listing the files in directories).
    :toctree: generated/
    :template: datapipe.rst

+   AISFileLister
+   AISFileLoader
    FSSpecFileLister
    FSSpecFileOpener
    FSSpecSaver

mypy.ini

Lines changed: 3 additions & 0 deletions
@@ -47,3 +47,6 @@ ignore_missing_imports = True

 [mypy-graphviz.*]
 ignore_missing_imports = True
+
+[mypy-aistore.*]
+ignore_missing_imports = True

test/requirements_aistore.txt

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+pytest
+aistore > 0.9.1

test/test_aistore.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+import string
+import tempfile
+import unittest
+
+from torchdata.datapipes.iter import AISFileLister, AISFileLoader
+
+try:
+    from aistore.client.api import Client
+    from aistore.client.errors import AISError, ErrBckNotFound
+
+    AIS_CLUSTER_ENDPT = "http://localhost:8080"
+
+    HAS_AIS = Client(AIS_CLUSTER_ENDPT).is_aistore_running()
+except (ImportError, ConnectionError):
+    HAS_AIS = False
+skipIfNoAIS = unittest.skipIf(not HAS_AIS, "AIS not running or library not installed")
+
+
+@skipIfNoAIS
+class TestDataPipeLocalIO(unittest.TestCase):
+    def setUp(self):
+        # initialize client and create new bucket
+        self.client = Client(AIS_CLUSTER_ENDPT)
+        letters = string.ascii_lowercase
+        self.bck_name = "".join(random.choice(letters) for _ in range(10))
+        self.client.create_bucket(self.bck_name)
+        # create temp files
+        num_objs = 10
+
+        # create 10 objects in the `/temp` dir
+        for i in range(num_objs):
+            object_body = "test string" * random.randrange(1, 10)
+            content = object_body.encode("utf-8")
+            obj_name = f"temp/obj{ i }"
+            with tempfile.NamedTemporaryFile() as file:
+                file.write(content)
+                file.flush()
+                self.client.put_object(self.bck_name, obj_name, file.name)
+
+        # create 10 objects in the `/` dir
+        for i in range(num_objs):
+            object_body = "test string" * random.randrange(1, 10)
+            content = object_body.encode("utf-8")
+            obj_name = f"obj{ i }"
+            with tempfile.NamedTemporaryFile() as file:
+                file.write(content)
+                file.flush()
+                self.client.put_object(self.bck_name, obj_name, file.name)
+
+    def tearDown(self):
+        # Try to destroy bucket and its items
+        try:
+            self.client.destroy_bucket(self.bck_name)
+        except ErrBckNotFound:
+            pass
+
+    def test_ais_io_iterdatapipe(self):
+
+        prefixes = [
+            ["ais://" + self.bck_name],
+            ["ais://" + self.bck_name + "/"],
+            ["ais://" + self.bck_name + "/temp/", "ais://" + self.bck_name + "/obj"],
+        ]
+
+        # check if the created files exist
+        for prefix in prefixes:
+            urls = AISFileLister(url=AIS_CLUSTER_ENDPT, source_datapipe=prefix)
+            ais_loader = AISFileLoader(url=AIS_CLUSTER_ENDPT, source_datapipe=urls)
+            with self.assertRaises(TypeError):
+                len(urls)
+            self.assertEqual(len(list(urls)), 20)
+            self.assertEqual(sum(1 for _ in ais_loader), 20)
+
+        # check for incorrect prefixes
+        prefixes = ["ais://asdasd"]
+
+        # AISFileLister: Bucket not found
+        try:
+            list(AISFileLister(url=AIS_CLUSTER_ENDPT, source_datapipe=prefixes))
+        except ErrBckNotFound as err:
+            self.assertEqual(err.status_code, 404)
+
+        # AISFileLoader: incorrect inputs
+        url_list = [[""], ["ais:"], ["ais://"], ["s3:///unkown-bucket"]]
+
+        for url in url_list:
+            with self.assertRaises(AISError):
+                file_loader = AISFileLoader(url=AIS_CLUSTER_ENDPT, source_datapipe=url)
+                for _ in file_loader:
+                    pass
+
+
+if __name__ == "__main__":
+    unittest.main()

torchdata/datapipes/iter/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -29,6 +29,10 @@
     UnBatcher,
     Zipper,
 )
+from torchdata.datapipes.iter.load.aisio import (
+    AISFileListerIterDataPipe as AISFileLister,
+    AISFileLoaderIterDataPipe as AISFileLoader,
+)

 ###############################################################################
 # TorchData
@@ -125,6 +129,8 @@
 from torchdata.datapipes.map.util.converter import MapToIterConverterIterDataPipe as MapToIterConverter

 __all__ = [
+    "AISFileLister",
+    "AISFileLoader",
     "BatchMapper",
     "Batcher",
     "BucketBatcher",

torchdata/datapipes/iter/load/README.md

Lines changed: 34 additions & 3 deletions
@@ -1,6 +1,8 @@
-# S3 IO Datapipe Documentation
+# Iterable Datapipes

-## Build from Source
+## S3 IO Datapipe Documentation
+
+### Build from Source

 `ninja` is required to link PyThon implementation to C++ source code.

@@ -44,8 +46,37 @@ It's recommended to set up a detailed configuration file with the `AWS_CONFIG_FILE`
 environment variables are also parsed: `HOME`, `S3_USE_HTTPS`, `S3_VERIFY_SSL`, `S3_ENDPOINT_URL`, `AWS_REGION` (would
 be overwritten by the `region` variable).

-## Troubleshooting
+### Troubleshooting

 If you get `Access Denied`, it's very possibly a
 [wrong region configuration](https://github.com/aws/aws-sdk-cpp/issues/1211) or an
 [accessing issue with `aws-sdk-cpp`](https://aws.amazon.com/premiumsupport/knowledge-center/s3-access-denied-aws-sdk/).
+
+## AIStore IO Datapipe
+
+[AIStore](https://github.com/NVIDIA/aistore) (AIS for short) is a highly available lightweight object storage system
+that specifically focuses on petascale deep learning. As a reliable redundant storage, AIS supports n-way mirroring and
+erasure coding. But it is not purely – or not only – a storage system: it’ll shuffle user datasets and run custom
+extract-transform-load workloads.
+
+AIS is an elastic cluster that can grow and shrink at runtime and can be ad-hoc deployed, with or without Kubernetes,
+anywhere from a single Linux machine to a bare-metal cluster of any size.
+
+AIS fully supports Amazon S3, Google Cloud, and Microsoft Azure backends, providing a unified namespace across multiple
+connected backends and/or other AIS clusters, and [more](https://github.com/NVIDIA/aistore#features). Getting started
+with AIS will take only a few minutes (prerequisites boil down to having a Linux with a disk) and can be done either by
+running a prebuilt all-in-one docker image or directly from the open-source.
+
+### Dependency
+
+The `AISFileLister` and `AISFileLoader` under [`aisio.py`](/torchdata/datapipes/iter/load/aisio.py) internally use the
+[Python SDK](https://github.com/NVIDIA/aistore/tree/master/sdk/python) for AIStore.
+
+Run `pip install aistore` or `conda install aistore` to install the [python package](https://pypi.org/project/aistore/).
+
+### Example
+
+Please refer to the documentation:
+
+- [`AISFileLister`](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.AISFileLister.html#aisfilelister)
+- [`AISFileLoader`](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.AISFileLoader.html#aisfileloader)
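Until the generated pages linked above are live, here is a minimal usage sketch (not part of the commit) assembled from the `aisio.py` docstrings; the endpoint and bucket name are illustrative placeholders:

```python
# Sketch: list objects under AIS URL prefixes, then stream each object.
# "http://localhost:8080" and "bucket-name" are placeholder values.
from torchdata.datapipes.iter import AISFileLister, AISFileLoader, IterableWrapper

ais_prefixes = IterableWrapper(["ais://bucket-name/folder/"])

# AISFileLister yields full object URLs; AISFileLoader yields (url, stream) pairs.
dp_ais_urls = AISFileLister(url="http://localhost:8080", source_datapipe=ais_prefixes)
dp_ais_files = AISFileLoader(url="http://localhost:8080", source_datapipe=dp_ais_urls)

for url, stream in dp_ais_files:
    data = stream.read()  # the stream wraps an in-memory BytesIO of the object
```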
torchdata/datapipes/iter/load/aisio.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from io import BytesIO
+from typing import Iterator, Tuple
+
+from torchdata.datapipes import functional_datapipe
+
+from torchdata.datapipes.iter import IterDataPipe
+from torchdata.datapipes.utils import StreamWrapper
+
+try:
+    from aistore.client import Client
+    from aistore.pytorch.utils import parse_url, unparse_url
+
+    HAS_AIS = True
+except ImportError:
+    HAS_AIS = False
+
+
+def _assert_aistore() -> None:
+    if not HAS_AIS:
+        raise ModuleNotFoundError(
+            "Package `aistore` is required to be installed to use this datapipe."
+            "Please run `pip install aistore` or `conda install aistore` to install the package"
+            "For more info visit: https://github.com/NVIDIA/aistore/blob/master/sdk/python/"
+        )
+
+
+@functional_datapipe("list_files_by_ais")
+class AISFileListerIterDataPipe(IterDataPipe[str]):
+    """
+    Iterable Datapipe that lists files from the AIStore backends with the given URL prefixes. (functional name: ``list_files_by_ais``).
+    Acceptable prefixes include but not limited to - `ais://bucket-name`, `ais://bucket-name/`
+
+    Note:
+    -   This function also supports files from multiple backends (`aws://..`, `gcp://..`, `hdfs://..`, etc)
+    -   Input must be a list and direct URLs are not supported.
+    -   length is -1 by default, all calls to len() are invalid as
+        not all items are iterated at the start.
+    -   This internally uses AIStore Python SDK.
+
+    Args:
+        source_datapipe(IterDataPipe[str]): a DataPipe that contains URLs/URL
+                                            prefixes to objects on AIS
+        url(str): AIStore endpoint
+        length(int): length of the datapipe
+
+    Example:
+        >>> from torchdata.datapipes.iter import IterableWrapper, AISFileLister
+        >>> ais_prefixes = IterableWrapper(['ais://bucket-name/folder/', 'aws:bucket-name/folder/', ...])
+        >>> dp_ais_urls = AISFileLister(url='localhost:8080', source_datapipe=prefix)
+        >>> for d in dp_ais_urls:
+        ...     pass
+        >>> # Functional API
+        >>> dp_ais_urls = dp_ais_urls.list_files_by_ais(url='localhost:8080')
+        >>> for d in dp_ais_urls:
+        ...     pass
+    """
+
+    def __init__(self, source_datapipe: IterDataPipe[str], url: str, length: int = -1) -> None:
+        _assert_aistore()
+        self.source_datapipe: IterDataPipe[str] = source_datapipe
+        self.length: int = length
+        self.client = Client(url)
+
+    def __iter__(self) -> Iterator[str]:
+        for prefix in self.source_datapipe:
+            provider, bck_name, prefix = parse_url(prefix)
+            obj_iter = self.client.list_objects_iter(bck_name=bck_name, provider=provider, prefix=prefix)
+            for entry in obj_iter:
+                yield unparse_url(provider=provider, bck_name=bck_name, obj_name=entry.name)
+
+    def __len__(self) -> int:
+        if self.length == -1:
+            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
+        return self.length
+
+
+@functional_datapipe("load_files_by_ais")
+class AISFileLoaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
+    """
+    Iterable DataPipe that loads files from AIStore with the given URLs (functional name: ``load_files_by_ais``).
+    Iterates all files in BytesIO format and returns a tuple (url, BytesIO).
+
+    Note:
+    -   This function also supports files from multiple backends (`aws://..`, `gcp://..`, etc)
+    -   Input must be a list and direct URLs are not supported.
+    -   This internally uses AIStore Python SDK.
+
+    Args:
+        source_datapipe(IterDataPipe[str]): a DataPipe that contains URLs/URL prefixes to objects
+        url(str): AIStore endpoint
+        length(int): length of the datapipe
+
+    Example:
+        >>> from torchdata.datapipes.iter import IterableWrapper, AISFileLister,AISFileLoader
+        >>> ais_prefixes = IterableWrapper(['ais://bucket-name/folder/', 'aws:bucket-name/folder/', ...])
+        >>> dp_ais_urls = AISFileLister(url='localhost:8080', source_datapipe=prefix)
+        >>> dp_s3_files = AISFileLoader(url='localhost:8080', source_datapipe=dp_ais_urls)
+        >>> for url, file in dp_ais_urls:
+        ...     pass
+        >>> # Functional API
+        >>> dp_ais_urls = dp_ais_urls.load_files_by_ais(url='localhost:8080')
+        >>> for url, file in dp_ais_urls:
+        ...     pass
+    """
+
+    def __init__(self, source_datapipe: IterDataPipe[str], url: str, length: int = -1) -> None:
+        _assert_aistore()
+        self.source_datapipe: IterDataPipe[str] = source_datapipe
+        self.length = length
+        self.client = Client(url)
+
+    def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
+        for url in self.source_datapipe:
+            provider, bck_name, obj_name = parse_url(url)
+            yield url, StreamWrapper(
+                BytesIO(self.client.get_object(bck_name=bck_name, provider=provider, obj_name=obj_name).read_all())
+            )
+
+    def __len__(self) -> int:
+        return len(self.source_datapipe)
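Since both classes are registered via `functional_datapipe`, the same pipeline can also be written with the functional API. A sketch (not part of the commit; the endpoint is a placeholder, and note that the loader is chained off the lister's output):

```python
# Sketch: functional-API equivalent of the class-based pipeline above.
# AIS_ENDPOINT is a placeholder; point it at your own AIS cluster.
from torchdata.datapipes.iter import IterableWrapper

AIS_ENDPOINT = "http://localhost:8080"

urls = IterableWrapper(["ais://bucket-name/"]).list_files_by_ais(url=AIS_ENDPOINT)
files = urls.load_files_by_ais(url=AIS_ENDPOINT)

for url, stream in files:
    payload = stream.read()  # object bytes via the wrapped BytesIO
```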
