1
1
import contextlib
2
2
import itertools
3
+ import time
3
4
import unittest
4
5
import unittest .mock
6
+ from datetime import datetime
5
7
from os import path
6
- from time import sleep
8
+ from urllib . parse import urlparse
7
9
from urllib .request import urlopen , Request
8
10
9
11
from torchvision import datasets
13
15
from fakedata_generation import places365_root
14
16
15
17
18
+ def limit_requests_per_time (min_secs_between_requests = 2.0 ):
19
+ last_requests = {}
20
+
21
+ def outer_wrapper (fn ):
22
+ def inner_wrapper (request , * args , ** kwargs ):
23
+ url = request .full_url if isinstance (request , Request ) else request
24
+
25
+ netloc = urlparse (url ).netloc
26
+ last_request = last_requests .get (netloc )
27
+ if last_request is not None :
28
+ elapsed_secs = (datetime .now () - last_request ).total_seconds ()
29
+ delta = min_secs_between_requests - elapsed_secs
30
+ if delta > 0 :
31
+ time .sleep (delta )
32
+
33
+ response = fn (request , * args , ** kwargs )
34
+ last_requests [netloc ] = datetime .now ()
35
+
36
+ return response
37
+
38
+ return inner_wrapper
39
+
40
+ return outer_wrapper
41
+
42
+
43
+ urlopen = limit_requests_per_time ()(urlopen )
44
+
45
+
16
46
class DownloadTester (unittest .TestCase ):
17
47
@staticmethod
18
48
@contextlib .contextmanager
@@ -37,7 +67,7 @@ def retry(fn, times=1, wait=5.0):
37
67
return fn ()
38
68
except AssertionError as error :
39
69
msgs .append (str (error ))
40
- sleep (wait )
70
+ time . sleep (wait )
41
71
else :
42
72
raise AssertionError (
43
73
"\n " .join (
@@ -80,7 +110,6 @@ def test_download(self):
80
110
for url , md5 in self .collect_urls_and_md5s ():
81
111
with self .subTest (url = url , md5 = md5 ):
82
112
self .retry (lambda : assert_fn (url , md5 ))
83
- sleep (2.0 )
84
113
85
114
def collect_urls_and_md5s (self ):
86
115
raise NotImplementedError
0 commit comments