 from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
 from urllib.parse import urlparse

+import requests
 import torch
 from torch.utils.model_zoo import tqdm

@@ -199,28 +200,37 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
         filename (str, optional): Name to save the file under. If None, use the id of the file.
         md5 (str, optional): MD5 checksum of the download. If None, do not check
     """
-    url = f"https://drive.google.com/uc?export=download&id={file_id}"
+    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
+
+    url = "https://docs.google.com/uc?export=download"

     root = os.path.expanduser(root)
     if not filename:
         filename = file_id
     fpath = os.path.join(root, filename)

+    os.makedirs(root, exist_ok=True)
+
     if os.path.isfile(fpath) and check_integrity(fpath, md5):
         print("Using downloaded and verified file: " + fpath)
-        return
+    else:
+        session = requests.Session()

-    os.makedirs(root, exist_ok=True)
+        response = session.get(url, params={"id": file_id}, stream=True)
+        token = _get_confirm_token(response)
+
+        if token:
+            params = {"id": file_id, "confirm": token}
+            response = session.get(url, params=params, stream=True)

-    with urllib.request.urlopen(url) as response:
         # Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
         # with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
         # Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
         # the first_chunk of the payload
-        content = iter(lambda: response.read(32768), b"")
+        response_content_generator = response.iter_content(32768)
         first_chunk = None
         while not first_chunk:  # filter out keep-alive new chunks
-            first_chunk = next(content)
+            first_chunk = next(response_content_generator)

         if _quota_exceeded(first_chunk):
             msg = (
@@ -230,12 +240,21 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
             )
             raise RuntimeError(msg)

-        _save_response_content(itertools.chain((first_chunk,), content), fpath)
+        _save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
+        response.close()
+
+
+def _get_confirm_token(response: requests.models.Response) -> Optional[str]:
+    for key, value in response.cookies.items():
+        if key.startswith("download_warning"):
+            return value
+
+    return None


 def _save_response_content(
     response_gen: Iterator[bytes],
-    destination: str,  # type: ignore[name-defined]
+    destination: str,
 ) -> None:
     with open(destination, "wb") as f:
         pbar = tqdm(total=None)
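For context, a minimal self-contained sketch of the requests-based flow this diff introduces: one GET against the export URL, a check of the response cookies for a download_warning confirm token via _get_confirm_token, a second GET carrying the confirm parameter, and streaming the payload to disk in 32768-byte chunks. The names mirror the diff above, but the fetch_from_google_drive wrapper, its signature, and the plain write loop (instead of _save_response_content/tqdm, check_integrity, and the quota check) are simplifications for illustration, not the library's actual helpers.

import os
from typing import Optional

import requests


def _get_confirm_token(response: requests.models.Response) -> Optional[str]:
    # Google Drive signals "too large to virus-scan" downloads with a download_warning cookie.
    for key, value in response.cookies.items():
        if key.startswith("download_warning"):
            return value
    return None


def fetch_from_google_drive(file_id: str, fpath: str) -> None:
    # Hypothetical simplified stand-in for download_file_from_google_drive
    # (no md5 verification, no quota-exceeded detection).
    url = "https://docs.google.com/uc?export=download"
    session = requests.Session()

    response = session.get(url, params={"id": file_id}, stream=True)
    token = _get_confirm_token(response)
    if token:
        # Re-issue the request with the confirm token so the actual download starts.
        response = session.get(url, params={"id": file_id, "confirm": token}, stream=True)

    os.makedirs(os.path.dirname(fpath) or ".", exist_ok=True)
    with open(fpath, "wb") as f:
        for chunk in response.iter_content(32768):
            if chunk:  # skip keep-alive chunks
                f.write(chunk)
    response.close()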