Skip to content

Commit 50b2f91

Browse files
kohr-hfmassa
authored andcommitted
Fix broken progress bar (#524)
- Fix broken update calculation - Make progress bar use the neat `unit_scale` feature of tqdm
1 parent 3f6c23c commit 50b2f91

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

torchvision/datasets/utils.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
from tqdm import tqdm
66

77

8-
def gen_bar_updator(pbar):
8+
def gen_bar_updater(pbar):
99
def bar_update(count, block_size, total_size):
10-
pbar.total = total_size / block_size
11-
pbar.update(count)
10+
if pbar.total is None and total_size:
11+
pbar.total = total_size
12+
progress_bytes = count * block_size
13+
pbar.update(progress_bytes - pbar.n)
1214

1315
return bar_update
1416

@@ -47,13 +49,19 @@ def download_url(url, root, filename, md5):
4749
else:
4850
try:
4951
print('Downloading ' + url + ' to ' + fpath)
50-
urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updator(tqdm()))
52+
urllib.request.urlretrieve(
53+
url, fpath,
54+
reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True))
55+
)
5156
except:
5257
if url[:5] == 'https':
5358
url = url.replace('https:', 'http:')
5459
print('Failed download. Trying https -> http instead.'
5560
' Downloading ' + url + ' to ' + fpath)
56-
urllib.request.urlretrieve(url, fpath)
61+
urllib.request.urlretrieve(
62+
url, fpath,
63+
reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True))
64+
)
5765

5866

5967
def list_dir(root, prefix=False):

0 commit comments

Comments
 (0)