Skip to content

Commit 8d25de7

Browse files
authored
replace requests with urllib (#4973)
1 parent 999ef25 commit 8d25de7

File tree

1 file changed

+8
-19
lines changed

1 file changed

+8
-19
lines changed

torchvision/datasets/utils.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
183183
return files
184184

185185

186-
def _quota_exceeded(first_chunk: bytes) -> bool: # type: ignore[name-defined]
186+
def _quota_exceeded(first_chunk: bytes) -> bool:
187187
try:
188188
return "Google Drive - Quota exceeded" in first_chunk.decode()
189189
except UnicodeDecodeError:
@@ -199,38 +199,28 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
199199
filename (str, optional): Name to save the file under. If None, use the id of the file.
200200
md5 (str, optional): MD5 checksum of the download. If None, do not check
201201
"""
202-
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
203-
import requests
204-
205-
url = "https://docs.google.com/uc?export=download"
202+
url = f"https://drive.google.com/uc?export=download&id={file_id}"
206203

207204
root = os.path.expanduser(root)
208205
if not filename:
209206
filename = file_id
210207
fpath = os.path.join(root, filename)
211208

212-
os.makedirs(root, exist_ok=True)
213-
214209
if os.path.isfile(fpath) and check_integrity(fpath, md5):
215210
print("Using downloaded and verified file: " + fpath)
216-
else:
217-
session = requests.Session()
218-
219-
response = session.get(url, params={"id": file_id}, stream=True)
220-
token = _get_confirm_token(response)
211+
return
221212

222-
if token:
223-
params = {"id": file_id, "confirm": token}
224-
response = session.get(url, params=params, stream=True)
213+
os.makedirs(root, exist_ok=True)
225214

215+
with urllib.request.urlopen(url) as response:
226216
# Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
227217
# with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
228218
# Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
229219
# the first_chunk of the payload
230-
response_content_generator = response.iter_content(32768)
220+
content = iter(lambda: response.read(32768), b"")
231221
first_chunk = None
232222
while not first_chunk: # filter out keep-alive new chunks
233-
first_chunk = next(response_content_generator)
223+
first_chunk = next(content)
234224

235225
if _quota_exceeded(first_chunk):
236226
msg = (
@@ -240,8 +230,7 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
240230
)
241231
raise RuntimeError(msg)
242232

243-
_save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
244-
response.close()
233+
_save_response_content(itertools.chain((first_chunk,), content), fpath)
245234

246235

247236
def _get_confirm_token(response: "requests.models.Response") -> Optional[str]: # type: ignore[name-defined]

0 commit comments

Comments
 (0)