@@ -183,7 +183,7 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
     return files
 
 
-def _quota_exceeded(first_chunk: bytes) -> bool:  # type: ignore[name-defined]
+def _quota_exceeded(first_chunk: bytes) -> bool:
     try:
         return "Google Drive - Quota exceeded" in first_chunk.decode()
     except UnicodeDecodeError:
@@ -199,38 +199,28 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
         filename (str, optional): Name to save the file under. If None, use the id of the file.
         md5 (str, optional): MD5 checksum of the download. If None, do not check
     """
-    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
-    import requests
-
-    url = "https://docs.google.com/uc?export=download"
+    url = f"https://drive.google.com/uc?export=download&id={file_id}"
 
     root = os.path.expanduser(root)
     if not filename:
         filename = file_id
     fpath = os.path.join(root, filename)
 
-    os.makedirs(root, exist_ok=True)
-
     if os.path.isfile(fpath) and check_integrity(fpath, md5):
         print("Using downloaded and verified file: " + fpath)
-    else:
-        session = requests.Session()
-
-        response = session.get(url, params={"id": file_id}, stream=True)
-        token = _get_confirm_token(response)
+        return
 
-        if token:
-            params = {"id": file_id, "confirm": token}
-            response = session.get(url, params=params, stream=True)
+    os.makedirs(root, exist_ok=True)
 
+    with urllib.request.urlopen(url) as response:
         # Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
         # with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
         # Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
         # the first_chunk of the payload
-        response_content_generator = response.iter_content(32768)
+        content = iter(lambda: response.read(32768), b"")
         first_chunk = None
         while not first_chunk:  # filter out keep-alive new chunks
-            first_chunk = next(response_content_generator)
+            first_chunk = next(content)
 
         if _quota_exceeded(first_chunk):
             msg = (
@@ -240,8 +230,7 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
             )
             raise RuntimeError(msg)
 
-        _save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
-        response.close()
+        _save_response_content(itertools.chain((first_chunk,), content), fpath)
 
 
 def _get_confirm_token(response: "requests.models.Response") -> Optional[str]:  # type: ignore[name-defined]
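For reference, the pattern the new code relies on is: open the URL with `urllib.request.urlopen`, turn `response.read(32768)` into a chunk iterator via `iter(callable, sentinel)`, sniff the first non-empty chunk for Google Drive's quota message, then chain that chunk back onto the iterator before writing to disk. Below is a minimal, self-contained sketch of that pattern; the helper name `stream_to_file`, the `CHUNK_SIZE` constant, and the inlined quota check are illustrative and not part of the patch.

```python
import itertools
import os
import urllib.request

CHUNK_SIZE = 32768  # same block size as in the patch (hypothetical constant name)


def stream_to_file(url: str, fpath: str) -> None:
    """Illustrative helper: stream `url` to `fpath` in fixed-size chunks."""
    os.makedirs(os.path.dirname(fpath) or ".", exist_ok=True)

    with urllib.request.urlopen(url) as response:
        # iter(callable, sentinel) calls response.read(CHUNK_SIZE) repeatedly
        # and stops once it returns b"" (end of stream).
        chunks = iter(lambda: response.read(CHUNK_SIZE), b"")

        # Mirror the patch: skip any empty keep-alive chunks before inspecting.
        first_chunk = b""
        while not first_chunk:
            first_chunk = next(chunks)

        # Google Drive signals quota errors in the payload rather than the
        # HTTP status, hence the sniff of the first chunk (see the patch).
        if b"Google Drive - Quota exceeded" in first_chunk:
            raise RuntimeError(f"Google Drive quota exceeded for {url}")

        # Put the inspected chunk back in front of the remaining stream
        # before writing everything to disk.
        with open(fpath, "wb") as fh:
            for chunk in itertools.chain((first_chunk,), chunks):
                fh.write(chunk)
```

Called as, e.g., `stream_to_file(f"https://drive.google.com/uc?export=download&id={file_id}", fpath)`, this mirrors the download path of the patched function without the `requests` dependency.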