Skip to content

Commit 6a43a1f

Browse files
authored
limit requests per time in download tests (#2699)
1 parent 1b41525 commit 6a43a1f

File tree

1 file changed

+32
-3
lines changed

1 file changed

+32
-3
lines changed

test/test_datasets_download.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import contextlib
22
import itertools
3+
import time
34
import unittest
45
import unittest.mock
6+
from datetime import datetime
57
from os import path
6-
from time import sleep
8+
from urllib.parse import urlparse
79
from urllib.request import urlopen, Request
810

911
from torchvision import datasets
@@ -13,6 +15,34 @@
1315
from fakedata_generation import places365_root
1416

1517

18+
def limit_requests_per_time(min_secs_between_requests=2.0):
19+
last_requests = {}
20+
21+
def outer_wrapper(fn):
22+
def inner_wrapper(request, *args, **kwargs):
23+
url = request.full_url if isinstance(request, Request) else request
24+
25+
netloc = urlparse(url).netloc
26+
last_request = last_requests.get(netloc)
27+
if last_request is not None:
28+
elapsed_secs = (datetime.now() - last_request).total_seconds()
29+
delta = min_secs_between_requests - elapsed_secs
30+
if delta > 0:
31+
time.sleep(delta)
32+
33+
response = fn(request, *args, **kwargs)
34+
last_requests[netloc] = datetime.now()
35+
36+
return response
37+
38+
return inner_wrapper
39+
40+
return outer_wrapper
41+
42+
43+
urlopen = limit_requests_per_time()(urlopen)
44+
45+
1646
class DownloadTester(unittest.TestCase):
1747
@staticmethod
1848
@contextlib.contextmanager
@@ -37,7 +67,7 @@ def retry(fn, times=1, wait=5.0):
3767
return fn()
3868
except AssertionError as error:
3969
msgs.append(str(error))
40-
sleep(wait)
70+
time.sleep(wait)
4171
else:
4272
raise AssertionError(
4373
"\n".join(
@@ -80,7 +110,6 @@ def test_download(self):
80110
for url, md5 in self.collect_urls_and_md5s():
81111
with self.subTest(url=url, md5=md5):
82112
self.retry(lambda: assert_fn(url, md5))
83-
sleep(2.0)
84113

85114
def collect_urls_and_md5s(self):
86115
raise NotImplementedError

0 commit comments

Comments
 (0)