Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .github/workflows/development.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ jobs:
include:
- py_ver: "3.7"
mysql_ver: "5.7"
- py_ver: "3.6"
mysql_ver: "5.7"
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{matrix.py_ver}}
Expand Down Expand Up @@ -106,4 +104,4 @@ jobs:
with:
branch: gh-pages
directory: gh-pages
github_token: ${{secrets.GITHUB_TOKEN}}
github_token: ${{secrets.GITHUB_TOKEN}}
24 changes: 12 additions & 12 deletions datajoint/autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .errors import DataJointError, LostConnectionError
import signal
import multiprocessing as mp
import contextlib

# noinspection PyExceptionInherit,PyCallingNonCallable

Expand Down Expand Up @@ -213,7 +214,7 @@ def handler(signum, frame):
if not nkeys:
return

processes = min(*(_ for _ in (processes, nkeys, mp.cpu_count()) if _))
processes = min(_ for _ in (processes, nkeys, mp.cpu_count()) if _)

error_list = []
populate_kwargs = dict(
Expand All @@ -235,17 +236,16 @@ def handler(signum, frame):
del self.connection._conn.ctx # SSLContext is not pickleable
with mp.Pool(
processes, _initialize_populate, (self, jobs, populate_kwargs)
) as pool:
if display_progress:
with tqdm(desc="Processes: ", total=nkeys) as pbar:
for error in pool.imap(_call_populate1, keys, chunksize=1):
if error is not None:
error_list.append(error)
pbar.update()
else:
for error in pool.imap(_call_populate1, keys):
if error is not None:
error_list.append(error)
) as pool, (
tqdm(desc="Processes: ", total=nkeys)
if display_progress
else contextlib.nullcontext()
) as progress_bar:
for error in pool.imap(_call_populate1, keys, chunksize=1):
if error is not None:
error_list.append(error)
if display_progress:
progress_bar.update()
self.connection.connect() # reconnect parent process to MySQL server

# restore original signal handler:
Expand Down
24 changes: 13 additions & 11 deletions datajoint/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
query_log_max_length = 300


cache_key = "query_cache" # the key to lookup the query_cache folder in dj.config


def get_host_hook(host_input):
if "://" in host_input:
plugin_name = host_input.split("://")[0]
Expand Down Expand Up @@ -220,7 +223,7 @@ def connect(self):
k: v
for k, v in self.conn_info.items()
if k not in ["ssl_input", "host_input"]
}
},
)
except client.err.InternalError:
self._conn = client.connect(
Expand All @@ -236,7 +239,7 @@ def connect(self):
or k == "ssl"
and self.conn_info["ssl_input"] is None
)
}
},
)
self._conn.autocommit(True)

Expand All @@ -254,13 +257,12 @@ def set_query_cache(self, query_cache=None):
def purge_query_cache(self):
"""Purges all query cache."""
if (
"query_cache" in config
and isinstance(config["query_cache"], str)
and pathlib.Path(config["query_cache"]).is_dir()
isinstance(config.get(cache_key), str)
and pathlib.Path(config[cache_key]).is_dir()
):
path_iter = pathlib.Path(config["query_cache"]).glob("**/*")
for path in path_iter:
path.unlink()
for path in pathlib.Path(config[cache_key]).iterdir():
if not path.is_dir():
path.unlink()

def close(self):
self._conn.close()
Expand Down Expand Up @@ -313,15 +315,15 @@ def query(
"Only SELECT queries are allowed when query caching is on."
)
if use_query_cache:
if not config["query_cache"]:
if not config[cache_key]:
raise errors.DataJointError(
"Provide filepath dj.config['query_cache'] when using query caching."
f"Provide filepath dj.config['{cache_key}'] when using query caching."
)
hash_ = uuid_from_buffer(
(str(self._query_cache) + re.sub(r"`\$\w+`", "", query)).encode()
+ pack(args)
)
cache_path = pathlib.Path(config["query_cache"]) / str(hash_)
cache_path = pathlib.Path(config[cache_key]) / str(hash_)
try:
buffer = cache_path.read_bytes()
except FileNotFoundError:
Expand Down
2 changes: 1 addition & 1 deletion datajoint/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def update1(self, row):
raise DataJointError("Update cannot be applied to a restricted table.")
key = {k: row[k] for k in self.primary_key}
if len(self & key) != 1:
raise DataJointError("Update entry must exist.")
raise DataJointError("Update can only be applied to one existing entry.")
# UPDATE query
row = [
self.__make_placeholder(k, v)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from os import path
import sys

min_py_version = (3, 6)
min_py_version = (3, 7)

if sys.version_info < min_py_version:
sys.exit('DataJoint is only supported for Python {}.{} or higher'.format(*min_py_version))
Expand Down
3 changes: 1 addition & 2 deletions tests/schema_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,8 @@ class Website(dj.Part):
"""

def populate_random(self, n=10):
faker.Faker.seed(0)
fake = faker.Faker()
faker.Faker.seed(0) # make tests deterministic
faker.Faker.seed(0) # make test deterministic
for _ in range(n):
profile = fake.profile()
with self.connection.transaction:
Expand Down