Skip to content

Commit 4282c9f

Browse files
authored
Revert removing requests and make it a hard dependency instead (#5047)
* Revert "[FBcode->GH] remove unused requests functionality (#5014)"
  This reverts commit 33123be.
* Revert "replace requests with urllib (#4973)"
  This reverts commit 8d25de7.
* add requests as hard dependency
* install library stubs in CI
* fix syntax
* add requests to conda dependencies
* fix mypy CI
1 parent e250db3 commit 4282c9f

File tree

5 files changed

+31
-10
lines changed

5 files changed

+31
-10
lines changed

.circleci/config.yml

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.circleci/config.yml.in

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -311,7 +311,7 @@ jobs:
311311
descr: Install Python type check utilities
312312
- run:
313313
name: Check Python types statically
314-
command: mypy --config-file mypy.ini
314+
command: mypy --install-types --non-interactive --config-file mypy.ini
315315

316316
unittest_torchhub:
317317
docker:

packaging/torchvision/meta.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ requirements:
2424
run:
2525
- python
2626
- defaults::numpy >=1.11
27+
- requests
2728
- libpng
2829
- ffmpeg >=4.2 # [not win]
2930
- jpeg

setup.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,7 @@ def write_version_file():
5959

6060
requirements = [
6161
"numpy",
62+
"requests",
6263
pytorch_dep,
6364
]
6465

torchvision/datasets/utils.py

Lines changed: 27 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
1616
from urllib.parse import urlparse
1717

18+
import requests
1819
import torch
1920
from torch.utils.model_zoo import tqdm
2021

@@ -199,28 +200,37 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
199200
filename (str, optional): Name to save the file under. If None, use the id of the file.
200201
md5 (str, optional): MD5 checksum of the download. If None, do not check
201202
"""
202-
url = f"https://drive.google.com/uc?export=download&id={file_id}"
203+
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
204+
205+
url = "https://docs.google.com/uc?export=download"
203206

204207
root = os.path.expanduser(root)
205208
if not filename:
206209
filename = file_id
207210
fpath = os.path.join(root, filename)
208211

212+
os.makedirs(root, exist_ok=True)
213+
209214
if os.path.isfile(fpath) and check_integrity(fpath, md5):
210215
print("Using downloaded and verified file: " + fpath)
211-
return
216+
else:
217+
session = requests.Session()
212218

213-
os.makedirs(root, exist_ok=True)
219+
response = session.get(url, params={"id": file_id}, stream=True)
220+
token = _get_confirm_token(response)
221+
222+
if token:
223+
params = {"id": file_id, "confirm": token}
224+
response = session.get(url, params=params, stream=True)
214225

215-
with urllib.request.urlopen(url) as response:
216226
# Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
217227
# with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
218228
# Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
219229
# the first_chunk of the payload
220-
content = iter(lambda: response.read(32768), b"")
230+
response_content_generator = response.iter_content(32768)
221231
first_chunk = None
222232
while not first_chunk: # filter out keep-alive new chunks
223-
first_chunk = next(content)
233+
first_chunk = next(response_content_generator)
224234

225235
if _quota_exceeded(first_chunk):
226236
msg = (
@@ -230,12 +240,21 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
230240
)
231241
raise RuntimeError(msg)
232242

233-
_save_response_content(itertools.chain((first_chunk,), content), fpath)
243+
_save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
244+
response.close()
245+
246+
247+
def _get_confirm_token(response: requests.models.Response) -> Optional[str]:
248+
for key, value in response.cookies.items():
249+
if key.startswith("download_warning"):
250+
return value
251+
252+
return None
234253

235254

236255
def _save_response_content(
237256
response_gen: Iterator[bytes],
238-
destination: str, # type: ignore[name-defined]
257+
destination: str,
239258
) -> None:
240259
with open(destination, "wb") as f:
241260
pbar = tqdm(total=None)

0 commit comments

Comments
 (0)