diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 2fd89e4c5a2..e3222eca41d 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -183,7 +183,7 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]: return files -def _quota_exceeded(first_chunk: bytes) -> bool: # type: ignore[name-defined] +def _quota_exceeded(first_chunk: bytes) -> bool: try: return "Google Drive - Quota exceeded" in first_chunk.decode() except UnicodeDecodeError: @@ -199,38 +199,28 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[ filename (str, optional): Name to save the file under. If None, use the id of the file. md5 (str, optional): MD5 checksum of the download. If None, do not check """ - # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url - import requests - - url = "https://docs.google.com/uc?export=download" + url = f"https://drive.google.com/uc?export=download&id={file_id}" root = os.path.expanduser(root) if not filename: filename = file_id fpath = os.path.join(root, filename) - os.makedirs(root, exist_ok=True) - if os.path.isfile(fpath) and check_integrity(fpath, md5): print("Using downloaded and verified file: " + fpath) - else: - session = requests.Session() - - response = session.get(url, params={"id": file_id}, stream=True) - token = _get_confirm_token(response) + return - if token: - params = {"id": file_id, "confirm": token} - response = session.get(url, params=params, stream=True) + os.makedirs(root, exist_ok=True) + with urllib.request.urlopen(url) as response: # Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent # with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517. # Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding # the first_chunk of the payload - response_content_generator = response.iter_content(32768) + content = iter(lambda: response.read(32768), b"") first_chunk = None while not first_chunk: # filter out keep-alive new chunks - first_chunk = next(response_content_generator) + first_chunk = next(content) if _quota_exceeded(first_chunk): msg = ( @@ -240,8 +230,7 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[ ) raise RuntimeError(msg) - _save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath) - response.close() + _save_response_content(itertools.chain((first_chunk,), content), fpath) def _get_confirm_token(response: "requests.models.Response") -> Optional[str]: # type: ignore[name-defined]