Skip to content

Commit 52e2cfc

Browse files
authored
Merge pull request #1204 from dimitri-yatsenko/master
fix #1170: Support long make calls
2 parents ddcdc72 + a284616 commit 52e2cfc

File tree

10 files changed

+64
-25
lines changed

10 files changed

+64
-25
lines changed

datajoint/autopopulate.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import signal
1313
import multiprocessing as mp
1414
import contextlib
15+
import deepdiff
1516

1617
# noinspection PyExceptionInherit,PyCallingNonCallable
1718

@@ -309,17 +310,46 @@ def _populate1(
309310
):
310311
return False
311312

312-
self.connection.start_transaction()
313+
# if make is a generator, the transaction can be delayed until the final stage
314+
is_generator = inspect.isgeneratorfunction(make)
315+
if not is_generator:
316+
self.connection.start_transaction()
317+
313318
if key in self.target: # already populated
314-
self.connection.cancel_transaction()
319+
if not is_generator:
320+
self.connection.cancel_transaction()
315321
if jobs is not None:
316322
jobs.complete(self.target.table_name, self._job_key(key))
317323
return False
318324

319325
logger.debug(f"Making {key} -> {self.target.full_table_name}")
320326
self.__class__._allow_insert = True
327+
321328
try:
322-
make(dict(key), **(make_kwargs or {}))
329+
if not is_generator:
330+
make(dict(key), **(make_kwargs or {}))
331+
else:
332+
# tripartite make - transaction is delayed until the final stage
333+
gen = make(dict(key), **(make_kwargs or {}))
334+
fetched_data = next(gen)
335+
fetch_hash = deepdiff.DeepHash(
336+
fetched_data, ignore_iterable_order=False
337+
)[fetched_data]
338+
computed_result = next(gen) # perform the computation
339+
# fetch and insert inside a transaction
340+
self.connection.start_transaction()
341+
gen = make(dict(key), **(make_kwargs or {})) # restart make
342+
fetched_data = next(gen)
343+
if (
344+
fetch_hash
345+
!= deepdiff.DeepHash(fetched_data, ignore_iterable_order=False)[
346+
fetched_data
347+
]
348+
): # rollback due to referential integrity failure
349+
self.connection.cancel_transaction()
350+
return False
351+
gen.send(computed_result) # insert
352+
323353
except (KeyboardInterrupt, SystemExit, Exception) as error:
324354
try:
325355
self.connection.cancel_transaction()

datajoint/blob.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def pack_blob(self, obj):
204204
return self.pack_dict(obj)
205205
if isinstance(obj, str):
206206
return self.pack_string(obj)
207-
if isinstance(obj, collections.abc.ByteString):
207+
if isinstance(obj, (bytes, bytearray)):
208208
return self.pack_bytes(obj)
209209
if isinstance(obj, collections.abc.MutableSequence):
210210
return self.pack_list(obj)

datajoint/connection.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pathlib
1313

1414
from .settings import config
15-
from . import errors
15+
from . import errors, __version__
1616
from .dependencies import Dependencies
1717
from .blob import pack, unpack
1818
from .hash import uuid_from_buffer
@@ -190,15 +190,20 @@ def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None)
190190
self.conn_info["ssl_input"] = use_tls
191191
self.conn_info["host_input"] = host_input
192192
self.init_fun = init_fun
193-
logger.info("Connecting {user}@{host}:{port}".format(**self.conn_info))
194193
self._conn = None
195194
self._query_cache = None
196195
connect_host_hook(self)
197196
if self.is_connected:
198-
logger.info("Connected {user}@{host}:{port}".format(**self.conn_info))
197+
logger.info(
198+
"DataJoint {version} connected to {user}@{host}:{port}".format(
199+
version=__version__, **self.conn_info
200+
)
201+
)
199202
self.connection_id = self.query("SELECT connection_id()").fetchone()[0]
200203
else:
201-
raise errors.LostConnectionError("Connection failed.")
204+
raise errors.LostConnectionError(
205+
"Connection failed {user}@{host}:{port}".format(**self.conn_info)
206+
)
202207
self._in_transaction = False
203208
self.schemas = dict()
204209
self.dependencies = Dependencies(self)
@@ -344,7 +349,7 @@ def query(
344349
except errors.LostConnectionError:
345350
if not reconnect:
346351
raise
347-
logger.warning("MySQL server has gone away. Reconnecting to the server.")
352+
logger.warning("Reconnecting to MySQL server.")
348353
connect_host_hook(self)
349354
if self._in_transaction:
350355
self.cancel_transaction()

datajoint/external.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
def subfold(name, folds):
2424
"""
25-
subfolding for external storage: e.g. subfold('aBCdefg', (2, 3)) --> ['ab','cde']
25+
subfolding for external storage: e.g. subfold('aBCdefg', (2, 3)) --> ['ab','cde']
2626
"""
2727
return (
2828
(name[: folds[0]].lower(),) + subfold(name[folds[0] :], folds[1:])
@@ -278,7 +278,7 @@ def upload_filepath(self, local_filepath):
278278

279279
# check if the remote file already exists and verify that it matches
280280
check_hash = (self & {"hash": uuid}).fetch("contents_hash")
281-
if check_hash:
281+
if check_hash.size:
282282
# the tracking entry exists, check that it's the same file as before
283283
if contents_hash != check_hash[0]:
284284
raise DataJointError(

datajoint/schemas.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -482,8 +482,8 @@ def list_tables(self):
482482
return [
483483
t
484484
for d, t in (
485-
full_t.replace("`", "").split(".")
486-
for full_t in self.connection.dependencies.topo_sort()
485+
table_name.replace("`", "").split(".")
486+
for table_name in self.connection.dependencies.topo_sort()
487487
)
488488
if d == self.database
489489
]

datajoint/settings.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Settings for DataJoint.
2+
Settings for DataJoint
33
"""
44

55
from contextlib import contextmanager
@@ -48,7 +48,8 @@
4848
"database.use_tls": None,
4949
"enable_python_native_blobs": True, # python-native/dj0 encoding support
5050
"add_hidden_timestamp": False,
51-
"filepath_checksum_size_limit": None, # file size limit for when to disable checksums
51+
# file size limit for when to disable checksums
52+
"filepath_checksum_size_limit": None,
5253
}
5354
)
5455

@@ -117,6 +118,7 @@ def load(self, filename):
117118
if filename is None:
118119
filename = LOCALCONFIG
119120
with open(filename, "r") as fid:
121+
logger.info(f"DataJoint is configured from {os.path.abspath(filename)}")
120122
self._conf.update(json.load(fid))
121123

122124
def save_local(self, verbose=False):
@@ -236,7 +238,8 @@ class __Config:
236238

237239
def __init__(self, *args, **kwargs):
238240
self._conf = dict(default)
239-
self._conf.update(dict(*args, **kwargs)) # use the free update to set keys
241+
# use the free update to set keys
242+
self._conf.update(dict(*args, **kwargs))
240243

241244
def __getitem__(self, key):
242245
return self._conf[key]
@@ -250,7 +253,9 @@ def __setitem__(self, key, value):
250253
valid_logging_levels = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
251254
if key == "loglevel":
252255
if value not in valid_logging_levels:
253-
raise ValueError(f"{'value'} is not a valid logging value")
256+
raise ValueError(
257+
f"'{value}' is not a valid logging value {tuple(valid_logging_levels)}"
258+
)
254259
logger.setLevel(value)
255260

256261

@@ -260,11 +265,9 @@ def __setitem__(self, key, value):
260265
os.path.expanduser(n) for n in (LOCALCONFIG, os.path.join("~", GLOBALCONFIG))
261266
)
262267
try:
263-
config_file = next(n for n in config_files if os.path.exists(n))
268+
config.load(next(n for n in config_files if os.path.exists(n)))
264269
except StopIteration:
265-
pass
266-
else:
267-
config.load(config_file)
270+
logger.info("No config file was found.")
268271

269272
# override login credentials with environment variables
270273
mapping = {
@@ -292,6 +295,8 @@ def __setitem__(self, key, value):
292295
)
293296
if v is not None
294297
}
295-
config.update(mapping)
298+
if mapping:
299+
logger.info(f"Overloaded settings {tuple(mapping)} from environment variables.")
300+
config.update(mapping)
296301

297302
logger.setLevel(log_levels[config["loglevel"]])

datajoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__version__ = "0.14.3"
1+
__version__ = "0.14.4"
22

33
assert len(__version__) <= 10  # The log table limits version to 10 characters

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ version = "0.14.3"
44
dependencies = [
55
"numpy",
66
"pymysql>=0.7.2",
7+
"deepdiff",
78
"pyparsing",
89
"ipython",
910
"pandas",

tests/test_declare.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,6 @@ class WithSuchALongPartNameThatItCrashesMySQL(dj.Part):
359359

360360

361361
def test_regex_mismatch(schema_any):
362-
363362
class IndexAttribute(dj.Manual):
364363
definition = """
365364
index: int

tests/test_relational_operand.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,6 @@ def test_union_multiple(schema_simp_pop):
574574

575575

576576
class TestDjTop:
577-
578577
def test_restrictions_by_top(self, schema_simp_pop):
579578
a = L() & dj.Top()
580579
b = L() & dj.Top(order_by=["cond_in_l", "KEY"])

0 commit comments

Comments
 (0)