
Commit 91a43ed

Merge pull request #1027 from dimitri-yatsenko/master
minor optimization of progress bar in autopopulate
2 parents: 0ff34f2 + f99fe44

File tree: 6 files changed, +29 -30 lines

.github/workflows/development.yaml

Lines changed: 1 addition & 3 deletions

@@ -36,8 +36,6 @@ jobs:
         include:
           - py_ver: "3.7"
             mysql_ver: "5.7"
-          - py_ver: "3.6"
-            mysql_ver: "5.7"
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python ${{matrix.py_ver}}
@@ -106,4 +104,4 @@ jobs:
         with:
           branch: gh-pages
           directory: gh-pages
-          github_token: ${{secrets.GITHUB_TOKEN}}
+          github_token: ${{secrets.GITHUB_TOKEN}}

datajoint/autopopulate.py

Lines changed: 12 additions & 12 deletions

@@ -9,6 +9,7 @@
 from .errors import DataJointError, LostConnectionError
 import signal
 import multiprocessing as mp
+import contextlib
 
 # noinspection PyExceptionInherit,PyCallingNonCallable
 
@@ -213,7 +214,7 @@ def handler(signum, frame):
         if not nkeys:
             return
 
-        processes = min(*(_ for _ in (processes, nkeys, mp.cpu_count()) if _))
+        processes = min(_ for _ in (processes, nkeys, mp.cpu_count()) if _)
 
         error_list = []
         populate_kwargs = dict(
@@ -235,17 +236,16 @@ def handler(signum, frame):
             del self.connection._conn.ctx  # SSLContext is not pickleable
             with mp.Pool(
                 processes, _initialize_populate, (self, jobs, populate_kwargs)
-            ) as pool:
-                if display_progress:
-                    with tqdm(desc="Processes: ", total=nkeys) as pbar:
-                        for error in pool.imap(_call_populate1, keys, chunksize=1):
-                            if error is not None:
-                                error_list.append(error)
-                            pbar.update()
-                else:
-                    for error in pool.imap(_call_populate1, keys):
-                        if error is not None:
-                            error_list.append(error)
+            ) as pool, (
+                tqdm(desc="Processes: ", total=nkeys)
+                if display_progress
+                else contextlib.nullcontext()
+            ) as progress_bar:
+                for error in pool.imap(_call_populate1, keys, chunksize=1):
+                    if error is not None:
+                        error_list.append(error)
+                    if display_progress:
+                        progress_bar.update()
         self.connection.connect()  # reconnect parent process to MySQL server
 
         # restore original signal handler:
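
Two details of this change are worth noting. First, the old `min(*(_ for _ in ... if _))` unpacked the filtered generator into separate positional arguments, a fragile form that raises `TypeError` if the filter ever leaves a single value; passing the generator itself, `min(_ for _ in ... if _)`, handles any number of survivors. Second, the two progress branches collapse into one loop by entering either `tqdm` or `contextlib.nullcontext()` (a no-op context manager, available since Python 3.7, consistent with this commit dropping 3.6) in the same `with` statement. A minimal standalone sketch of that pattern, with illustrative names (`process_all` is not part of DataJoint):

    import contextlib
    from tqdm import tqdm

    def process_all(items, display_progress=False):
        # One loop serves both modes: tqdm when a bar is wanted,
        # a do-nothing context manager otherwise.
        with (
            tqdm(desc="Processes: ", total=len(items))
            if display_progress
            else contextlib.nullcontext()
        ) as progress_bar:
            for item in items:
                _ = item * item  # stand-in for the real per-key work
                if display_progress:
                    progress_bar.update()

    process_all(range(100), display_progress=True)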

datajoint/connection.py

Lines changed: 13 additions & 11 deletions

@@ -21,6 +21,9 @@
 query_log_max_length = 300
 
 
+cache_key = "query_cache"  # the key to lookup the query_cache folder in dj.config
+
+
 def get_host_hook(host_input):
     if "://" in host_input:
         plugin_name = host_input.split("://")[0]
@@ -220,7 +223,7 @@ def connect(self):
                     k: v
                     for k, v in self.conn_info.items()
                     if k not in ["ssl_input", "host_input"]
-                }
+                },
             )
         except client.err.InternalError:
             self._conn = client.connect(
@@ -236,7 +239,7 @@ def connect(self):
                         or k == "ssl"
                         and self.conn_info["ssl_input"] is None
                     )
-                }
+                },
             )
         self._conn.autocommit(True)
 
@@ -254,13 +257,12 @@ def set_query_cache(self, query_cache=None):
     def purge_query_cache(self):
         """Purges all query cache."""
         if (
-            "query_cache" in config
-            and isinstance(config["query_cache"], str)
-            and pathlib.Path(config["query_cache"]).is_dir()
+            isinstance(config.get(cache_key), str)
+            and pathlib.Path(config[cache_key]).is_dir()
         ):
-            path_iter = pathlib.Path(config["query_cache"]).glob("**/*")
-            for path in path_iter:
-                path.unlink()
+            for path in pathlib.Path(config[cache_key]).iterdir():
+                if not path.is_dir():
+                    path.unlink()
 
     def close(self):
         self._conn.close()
@@ -313,15 +315,15 @@ def query(
                 "Only SELECT queries are allowed when query caching is on."
             )
         if use_query_cache:
-            if not config["query_cache"]:
+            if not config[cache_key]:
                 raise errors.DataJointError(
-                    "Provide filepath dj.config['query_cache'] when using query caching."
+                    f"Provide filepath dj.config['{cache_key}'] when using query caching."
                )
             hash_ = uuid_from_buffer(
                 (str(self._query_cache) + re.sub(r"`\$\w+`", "", query)).encode()
                 + pack(args)
             )
-            cache_path = pathlib.Path(config["query_cache"]) / str(hash_)
+            cache_path = pathlib.Path(config[cache_key]) / str(hash_)
             try:
                 buffer = cache_path.read_bytes()
             except FileNotFoundError:
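
Besides introducing the `cache_key` constant in place of four scattered `"query_cache"` literals, the purge now iterates only the top level of the cache folder and skips subdirectories; the old recursive `glob("**/*")` would have called `unlink()` on directories as well, which fails. A rough sketch of the cache workflow from user code, assuming a configured connection (the folder path and the `"main"` tag are illustrative):

    import datajoint as dj

    dj.config["query_cache"] = "/tmp/dj_query_cache"  # folder looked up via cache_key

    conn = dj.conn()
    conn.set_query_cache(query_cache="main")  # enable caching; the tag is mixed into the result hash
    # ... read-only queries here are memoized as files in the folder;
    # non-SELECT queries raise while caching is on ...
    conn.set_query_cache(None)  # disable caching
    conn.purge_query_cache()    # delete cached result files (top-level files only)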

datajoint/table.py

Lines changed: 1 addition & 1 deletion

@@ -311,7 +311,7 @@ def update1(self, row):
             raise DataJointError("Update cannot be applied to a restricted table.")
         key = {k: row[k] for k in self.primary_key}
         if len(self & key) != 1:
-            raise DataJointError("Update entry must exist.")
+            raise DataJointError("Update can only be applied to one existing entry.")
         # UPDATE query
         row = [
             self.__make_placeholder(k, v)
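
Since `key` restricts by the full primary key, at most one row can match, so the `!= 1` check effectively catches a missing entry; the new message states the contract more precisely. A hypothetical illustration of `update1` semantics (schema and table are made up, and a configured connection is assumed):

    import datajoint as dj

    schema = dj.schema("tutorial")  # hypothetical schema

    @schema
    class Subject(dj.Manual):
        definition = """
        subject_id : int
        ---
        weight : float
        """

    Subject.insert1(dict(subject_id=1, weight=20.0))
    Subject.update1(dict(subject_id=1, weight=21.5))  # exactly one existing entry: succeeds
    Subject.update1(dict(subject_id=2, weight=9.9))   # no matching entry: raises DataJointError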

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 from os import path
 import sys
 
-min_py_version = (3, 6)
+min_py_version = (3, 7)
 
 if sys.version_info < min_py_version:
     sys.exit('DataJoint is only supported for Python {}.{} or higher'.format(*min_py_version))

tests/schema_simple.py

Lines changed: 1 addition & 2 deletions

@@ -200,9 +200,8 @@ class Website(dj.Part):
         """
 
     def populate_random(self, n=10):
-        faker.Faker.seed(0)
         fake = faker.Faker()
-        faker.Faker.seed(0)  # make tests deterministic
+        faker.Faker.seed(0)  # make test deterministic
         for _ in range(n):
             profile = fake.profile()
             with self.connection.transaction:
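
`Faker.seed()` is a classmethod that seeds the random source shared by all `Faker` instances, so a single call suffices and it works whether placed before or after instantiation; the duplicate call removed here was redundant. A quick standalone check (assuming the `faker` package):

    import faker

    fake = faker.Faker()
    faker.Faker.seed(0)  # classmethod: seeds the shared generator for all instances
    print(fake.name())   # same output on every run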
