Skip to content

Add tests for Bolt fix for broken DateTime encoding #470

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 56 additions & 22 deletions boltstub/packstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from codecs import decode
import inspect
from io import BytesIO
import re
from struct import pack as struct_pack
from struct import unpack as struct_unpack

Expand Down Expand Up @@ -77,7 +78,8 @@ class StructTagV1:


class StructTagV2(StructTagV1):
pass
date_time = b"\x49"
date_time_zone_id = b"\x69"


class Structure:
Expand Down Expand Up @@ -124,9 +126,14 @@ def __repr__(self):

def __eq__(self, other):
try:
assert all(
StructTagV1.path == value.path
for key, value in locals().items()
if re.match(r"^StructTagV[1-9]\d*$", key)
)
if self.tag == StructTagV1.path:
# path struct => order of nodes and rels is irrelevant
return (other.tag == StructTagV1.path
return (other.tag == self.tag
and len(other.fields) == 3
and sorted(self.fields[0]) == sorted(other.fields[0])
and sorted(self.fields[1]) == sorted(other.fields[1])
Expand Down Expand Up @@ -169,7 +176,8 @@ def match_jolt_wildcard(self, wildcard: jolt_common_types.JoltWildcard):
if self.tag == struct_tags.local_time:
return True
elif issubclass(t, jolt_types_.JoltDateTime):
if self.tag == struct_tags.date_time:
if self.tag in (struct_tags.date_time,
struct_tags.date_time_zone_id):
return True
elif issubclass(t, jolt_types_.JoltLocalDateTime):
if self.tag == struct_tags.local_date_time:
Expand Down Expand Up @@ -201,8 +209,13 @@ def _from_jolt_v1_type(cls, jolt: jolt_v1_types.JoltType):
return cls(StructTagV1.local_time, jolt.nanoseconds,
packstream_version=1)
if isinstance(jolt, jolt_v1_types.JoltDateTime):
return cls(StructTagV1.date_time, *jolt.seconds_nanoseconds,
jolt.time.utc_offset, packstream_version=1)
if jolt.time.zone_id:
return cls(StructTagV1.date_time_zone_id,
*jolt.seconds_nanoseconds, jolt.time.zone_id,
packstream_version=1)
else:
return cls(StructTagV1.date_time, *jolt.seconds_nanoseconds,
jolt.time.utc_offset, packstream_version=1)
if isinstance(jolt, jolt_v1_types.JoltLocalDateTime):
return cls(StructTagV1.local_date_time, *jolt.seconds_nanoseconds,
packstream_version=1)
Expand Down Expand Up @@ -276,34 +289,39 @@ def _from_jolt_v1_type(cls, jolt: jolt_v1_types.JoltType):
@classmethod
def _from_jolt_v2_type(cls, jolt: jolt_v1_types.JoltType):
if isinstance(jolt, jolt_v2_types.JoltDate):
return cls(StructTagV1.date, jolt.days, packstream_version=2)
return cls(StructTagV2.date, jolt.days, packstream_version=2)
if isinstance(jolt, jolt_v2_types.JoltTime):
return cls(StructTagV1.time, jolt.nanoseconds, jolt.utc_offset,
return cls(StructTagV2.time, jolt.nanoseconds, jolt.utc_offset,
packstream_version=2)
if isinstance(jolt, jolt_v2_types.JoltLocalTime):
return cls(StructTagV1.local_time, jolt.nanoseconds,
return cls(StructTagV2.local_time, jolt.nanoseconds,
packstream_version=2)
if isinstance(jolt, jolt_v2_types.JoltDateTime):
return cls(StructTagV1.date_time, *jolt.seconds_nanoseconds,
jolt.time.utc_offset, packstream_version=2)
if jolt.time.zone_id:
return cls(StructTagV2.date_time_zone_id,
*jolt.seconds_nanoseconds, jolt.time.zone_id,
packstream_version=2)
else:
return cls(StructTagV2.date_time, *jolt.seconds_nanoseconds,
jolt.time.utc_offset, packstream_version=2)
if isinstance(jolt, jolt_v2_types.JoltLocalDateTime):
return cls(StructTagV1.local_date_time, *jolt.seconds_nanoseconds,
return cls(StructTagV2.local_date_time, *jolt.seconds_nanoseconds,
packstream_version=2)
if isinstance(jolt, jolt_v2_types.JoltDuration):
return cls(StructTagV1.duration, jolt.months, jolt.days,
return cls(StructTagV2.duration, jolt.months, jolt.days,
jolt.seconds, jolt.nanoseconds, packstream_version=2)
if isinstance(jolt, jolt_v2_types.JoltPoint):
if jolt.z is None: # 2D
return cls(StructTagV1.point_2d, jolt.srid, jolt.x, jolt.y,
return cls(StructTagV2.point_2d, jolt.srid, jolt.x, jolt.y,
packstream_version=2)
else:
return cls(StructTagV1.point_3d, jolt.srid, jolt.x, jolt.y,
return cls(StructTagV2.point_3d, jolt.srid, jolt.x, jolt.y,
jolt.z, packstream_version=2)
if isinstance(jolt, jolt_v2_types.JoltNode):
return cls(StructTagV1.node, jolt.id, jolt.labels,
return cls(StructTagV2.node, jolt.id, jolt.labels,
jolt.properties, jolt.element_id, packstream_version=2)
if isinstance(jolt, jolt_v2_types.JoltRelationship):
return cls(StructTagV1.relationship, jolt.id, jolt.start_node_id,
return cls(StructTagV2.relationship, jolt.id, jolt.start_node_id,
jolt.end_node_id, jolt.rel_type, jolt.properties,
jolt.element_id, jolt.start_node_element_id,
jolt.end_node_element_id, packstream_version=2)
Expand Down Expand Up @@ -331,7 +349,7 @@ def _from_jolt_v2_type(cls, jolt: jolt_v1_types.JoltType):
for rel in jolt.path[1::2]:
rels.append(rel)

ub_rel = cls(StructTagV1.unbound_relationship, rel.id,
ub_rel = cls(StructTagV2.unbound_relationship, rel.id,
rel.rel_type, rel.properties, rel.element_id,
packstream_version=2)
if ub_rel not in uniq_rels:
Expand All @@ -353,7 +371,7 @@ def _from_jolt_v2_type(cls, jolt: jolt_v1_types.JoltType):
else:
ids.append(-index)

return cls(StructTagV1.path, uniq_nodes, uniq_rels, ids,
return cls(StructTagV2.path, uniq_nodes, uniq_rels, ids,
packstream_version=2)
raise TypeError("Unsupported jolt type: {}".format(type(jolt)))

Expand All @@ -372,7 +390,7 @@ def _to_jolt_v1_type(self):
return jolt_v1_types.JoltTime.new(*self.fields)
if self.tag == StructTagV1.local_time:
return jolt_v1_types.JoltLocalTime.new(*self.fields)
if self.tag == StructTagV1.date_time:
if self.tag in (StructTagV1.date_time, StructTagV1.date_time_zone_id):
return jolt_v1_types.JoltDateTime.new(*self.fields)
if self.tag == StructTagV1.local_date_time:
return jolt_v1_types.JoltLocalDateTime.new(*self.fields)
Expand Down Expand Up @@ -421,7 +439,7 @@ def _to_jolt_v2_type(self):
return jolt_v2_types.JoltTime.new(*self.fields)
if self.tag == StructTagV2.local_time:
return jolt_v2_types.JoltLocalTime.new(*self.fields)
if self.tag == StructTagV2.date_time:
if self.tag in (StructTagV2.date_time, StructTagV2.date_time_zone_id):
return jolt_v2_types.JoltDateTime.new(*self.fields)
if self.tag == StructTagV2.local_date_time:
return jolt_v2_types.JoltLocalDateTime.new(*self.fields)
Expand Down Expand Up @@ -663,12 +681,28 @@ def _verify_relationship(cls, structure, fields):

@classmethod
def verify_fields(cls, structure: Structure):
# assert tags didn't change
assert all(
hasattr(StructTagV1, tag)
and getattr(StructTagV1, tag) == getattr(StructTagV2, tag)
for tag in dir(StructTagV2) if not tag.startswith("_")
for tag in dir(StructTagV2) if not (
tag.startswith("_")
or tag in ("date_time", "date_time_zone_id")
)
)

tag, fields = structure.tag, structure.fields

field_validator = {
StructTagV2.date_time: cls._build_generic_verifier(
(int, int, int,), "DateTime"
),
StructTagV2.date_time_zone_id: cls._build_generic_verifier(
(int, int, str), "DateTimeZoneId"
),
}

if tag in field_validator:
return field_validator[tag](structure, fields)
return super().verify_fields(structure)


Expand Down
18 changes: 10 additions & 8 deletions boltstub/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
ServerExit,
)
from .packstream import Structure
from .simple_jolt.common.types import JoltWildcard
from .simple_jolt.common.types import (
JoltType,
JoltWildcard,
)


def load_parser():
Expand Down Expand Up @@ -229,12 +232,12 @@ def parse_jolt(self, jolt_package):
self,
"message fields failed JOLT parser"
) from e
decoded = self._jolt_to_struct(decoded, jolt_package)
decoded = self._jolt_to_struct(decoded)
jolt_fields.append(decoded)
self.jolt_parsed = self.parsed[0], jolt_fields
return self.jolt_parsed

def _jolt_to_struct(self, decoded, jolt_package):
def _jolt_to_struct(self, decoded):
if isinstance(decoded, JoltWildcard):
if not self.allow_jolt_wildcard:
raise LineError(
Expand All @@ -243,14 +246,12 @@ def _jolt_to_struct(self, decoded, jolt_package):
)
else:
return decoded
if isinstance(decoded, jolt_package.types.JoltType):
if isinstance(decoded, JoltType):
return Structure.from_jolt_type(decoded)
if isinstance(decoded, (list, tuple)):
return type(decoded)(self._jolt_to_struct(d, jolt_package)
for d in decoded)
return type(decoded)(self._jolt_to_struct(d) for d in decoded)
if isinstance(decoded, dict):
return {k: self._jolt_to_struct(v, jolt_package)
for k, v in decoded.items()}
return {k: self._jolt_to_struct(v) for k, v in decoded.items()}
return decoded


Expand Down Expand Up @@ -1062,6 +1063,7 @@ def __str__(self):
res += ":\n"
res += "\n".join(map(str, self.expected_lines))
res += "\n\nReceived:\n" + str(self.received)
res += "\n => " + repr(self.received)
return res


Expand Down
9 changes: 8 additions & 1 deletion boltstub/simple_jolt/v1/codec.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import importlib
import inspect
import re
import sys
Expand Down Expand Up @@ -561,7 +562,13 @@ class Codec:
def decode(cls, value):
def transform(value_):
if isinstance(value_, dict) and len(value_) == 1:
sigil = next(iter(value_))
sigil, content = next(iter(value_.items()))
match = re.match(r"(.+)(v\d+)", sigil)
if match:
sigil, version = match.groups()
other_codec = importlib.import_module(f"..{version}.codec",
package=__package__)
return other_codec.Codec.decode({sigil: content})
transformer = cls.sigil_to_type.get(sigil)
if transformer:
return transformer.decode_full(value_[sigil], transform)
Expand Down
Loading