20
20
from codecs import decode
21
21
import inspect
22
22
from io import BytesIO
23
+ import re
23
24
from struct import pack as struct_pack
24
25
from struct import unpack as struct_unpack
25
26
@@ -77,7 +78,8 @@ class StructTagV1:
77
78
78
79
79
80
class StructTagV2 (StructTagV1 ):
80
- pass
81
+ date_time = b"\x49 "
82
+ date_time_zone_id = b"\x69 "
81
83
82
84
83
85
class Structure :
@@ -124,9 +126,14 @@ def __repr__(self):
124
126
125
127
def __eq__ (self , other ):
126
128
try :
129
+ assert all (
130
+ StructTagV1 .path == value .path
131
+ for key , value in locals ().items ()
132
+ if re .match (r"^StructTagV[1-9]\d*$" , key )
133
+ )
127
134
if self .tag == StructTagV1 .path :
128
135
# path struct => order of nodes and rels is irrelevant
129
- return (other .tag == StructTagV1 . path
136
+ return (other .tag == self . tag
130
137
and len (other .fields ) == 3
131
138
and sorted (self .fields [0 ]) == sorted (other .fields [0 ])
132
139
and sorted (self .fields [1 ]) == sorted (other .fields [1 ])
@@ -169,7 +176,8 @@ def match_jolt_wildcard(self, wildcard: jolt_common_types.JoltWildcard):
169
176
if self .tag == struct_tags .local_time :
170
177
return True
171
178
elif issubclass (t , jolt_types_ .JoltDateTime ):
172
- if self .tag == struct_tags .date_time :
179
+ if self .tag in (struct_tags .date_time ,
180
+ struct_tags .date_time_zone_id ):
173
181
return True
174
182
elif issubclass (t , jolt_types_ .JoltLocalDateTime ):
175
183
if self .tag == struct_tags .local_date_time :
@@ -201,8 +209,13 @@ def _from_jolt_v1_type(cls, jolt: jolt_v1_types.JoltType):
201
209
return cls (StructTagV1 .local_time , jolt .nanoseconds ,
202
210
packstream_version = 1 )
203
211
if isinstance (jolt , jolt_v1_types .JoltDateTime ):
204
- return cls (StructTagV1 .date_time , * jolt .seconds_nanoseconds ,
205
- jolt .time .utc_offset , packstream_version = 1 )
212
+ if jolt .time .zone_id :
213
+ return cls (StructTagV1 .date_time_zone_id ,
214
+ * jolt .seconds_nanoseconds , jolt .time .zone_id ,
215
+ packstream_version = 1 )
216
+ else :
217
+ return cls (StructTagV1 .date_time , * jolt .seconds_nanoseconds ,
218
+ jolt .time .utc_offset , packstream_version = 1 )
206
219
if isinstance (jolt , jolt_v1_types .JoltLocalDateTime ):
207
220
return cls (StructTagV1 .local_date_time , * jolt .seconds_nanoseconds ,
208
221
packstream_version = 1 )
@@ -276,34 +289,39 @@ def _from_jolt_v1_type(cls, jolt: jolt_v1_types.JoltType):
276
289
@classmethod
277
290
def _from_jolt_v2_type (cls , jolt : jolt_v1_types .JoltType ):
278
291
if isinstance (jolt , jolt_v2_types .JoltDate ):
279
- return cls (StructTagV1 .date , jolt .days , packstream_version = 2 )
292
+ return cls (StructTagV2 .date , jolt .days , packstream_version = 2 )
280
293
if isinstance (jolt , jolt_v2_types .JoltTime ):
281
- return cls (StructTagV1 .time , jolt .nanoseconds , jolt .utc_offset ,
294
+ return cls (StructTagV2 .time , jolt .nanoseconds , jolt .utc_offset ,
282
295
packstream_version = 2 )
283
296
if isinstance (jolt , jolt_v2_types .JoltLocalTime ):
284
- return cls (StructTagV1 .local_time , jolt .nanoseconds ,
297
+ return cls (StructTagV2 .local_time , jolt .nanoseconds ,
285
298
packstream_version = 2 )
286
299
if isinstance (jolt , jolt_v2_types .JoltDateTime ):
287
- return cls (StructTagV1 .date_time , * jolt .seconds_nanoseconds ,
288
- jolt .time .utc_offset , packstream_version = 2 )
300
+ if jolt .time .zone_id :
301
+ return cls (StructTagV2 .date_time_zone_id ,
302
+ * jolt .seconds_nanoseconds , jolt .time .zone_id ,
303
+ packstream_version = 2 )
304
+ else :
305
+ return cls (StructTagV2 .date_time , * jolt .seconds_nanoseconds ,
306
+ jolt .time .utc_offset , packstream_version = 2 )
289
307
if isinstance (jolt , jolt_v2_types .JoltLocalDateTime ):
290
- return cls (StructTagV1 .local_date_time , * jolt .seconds_nanoseconds ,
308
+ return cls (StructTagV2 .local_date_time , * jolt .seconds_nanoseconds ,
291
309
packstream_version = 2 )
292
310
if isinstance (jolt , jolt_v2_types .JoltDuration ):
293
- return cls (StructTagV1 .duration , jolt .months , jolt .days ,
311
+ return cls (StructTagV2 .duration , jolt .months , jolt .days ,
294
312
jolt .seconds , jolt .nanoseconds , packstream_version = 2 )
295
313
if isinstance (jolt , jolt_v2_types .JoltPoint ):
296
314
if jolt .z is None : # 2D
297
- return cls (StructTagV1 .point_2d , jolt .srid , jolt .x , jolt .y ,
315
+ return cls (StructTagV2 .point_2d , jolt .srid , jolt .x , jolt .y ,
298
316
packstream_version = 2 )
299
317
else :
300
- return cls (StructTagV1 .point_3d , jolt .srid , jolt .x , jolt .y ,
318
+ return cls (StructTagV2 .point_3d , jolt .srid , jolt .x , jolt .y ,
301
319
jolt .z , packstream_version = 2 )
302
320
if isinstance (jolt , jolt_v2_types .JoltNode ):
303
- return cls (StructTagV1 .node , jolt .id , jolt .labels ,
321
+ return cls (StructTagV2 .node , jolt .id , jolt .labels ,
304
322
jolt .properties , jolt .element_id , packstream_version = 2 )
305
323
if isinstance (jolt , jolt_v2_types .JoltRelationship ):
306
- return cls (StructTagV1 .relationship , jolt .id , jolt .start_node_id ,
324
+ return cls (StructTagV2 .relationship , jolt .id , jolt .start_node_id ,
307
325
jolt .end_node_id , jolt .rel_type , jolt .properties ,
308
326
jolt .element_id , jolt .start_node_element_id ,
309
327
jolt .end_node_element_id , packstream_version = 2 )
@@ -331,7 +349,7 @@ def _from_jolt_v2_type(cls, jolt: jolt_v1_types.JoltType):
331
349
for rel in jolt .path [1 ::2 ]:
332
350
rels .append (rel )
333
351
334
- ub_rel = cls (StructTagV1 .unbound_relationship , rel .id ,
352
+ ub_rel = cls (StructTagV2 .unbound_relationship , rel .id ,
335
353
rel .rel_type , rel .properties , rel .element_id ,
336
354
packstream_version = 2 )
337
355
if ub_rel not in uniq_rels :
@@ -353,7 +371,7 @@ def _from_jolt_v2_type(cls, jolt: jolt_v1_types.JoltType):
353
371
else :
354
372
ids .append (- index )
355
373
356
- return cls (StructTagV1 .path , uniq_nodes , uniq_rels , ids ,
374
+ return cls (StructTagV2 .path , uniq_nodes , uniq_rels , ids ,
357
375
packstream_version = 2 )
358
376
raise TypeError ("Unsupported jolt type: {}" .format (type (jolt )))
359
377
@@ -372,7 +390,7 @@ def _to_jolt_v1_type(self):
372
390
return jolt_v1_types .JoltTime .new (* self .fields )
373
391
if self .tag == StructTagV1 .local_time :
374
392
return jolt_v1_types .JoltLocalTime .new (* self .fields )
375
- if self .tag == StructTagV1 .date_time :
393
+ if self .tag in ( StructTagV1 .date_time , StructTagV1 . date_time_zone_id ) :
376
394
return jolt_v1_types .JoltDateTime .new (* self .fields )
377
395
if self .tag == StructTagV1 .local_date_time :
378
396
return jolt_v1_types .JoltLocalDateTime .new (* self .fields )
@@ -421,7 +439,7 @@ def _to_jolt_v2_type(self):
421
439
return jolt_v2_types .JoltTime .new (* self .fields )
422
440
if self .tag == StructTagV2 .local_time :
423
441
return jolt_v2_types .JoltLocalTime .new (* self .fields )
424
- if self .tag == StructTagV2 .date_time :
442
+ if self .tag in ( StructTagV2 .date_time , StructTagV2 . date_time_zone_id ) :
425
443
return jolt_v2_types .JoltDateTime .new (* self .fields )
426
444
if self .tag == StructTagV2 .local_date_time :
427
445
return jolt_v2_types .JoltLocalDateTime .new (* self .fields )
@@ -663,12 +681,28 @@ def _verify_relationship(cls, structure, fields):
663
681
664
682
@classmethod
665
683
def verify_fields (cls , structure : Structure ):
666
- # assert tags didn't change
667
684
assert all (
668
685
hasattr (StructTagV1 , tag )
669
686
and getattr (StructTagV1 , tag ) == getattr (StructTagV2 , tag )
670
- for tag in dir (StructTagV2 ) if not tag .startswith ("_" )
687
+ for tag in dir (StructTagV2 ) if not (
688
+ tag .startswith ("_" )
689
+ or tag in ("date_time" , "date_time_zone_id" )
690
+ )
671
691
)
692
+
693
+ tag , fields = structure .tag , structure .fields
694
+
695
+ field_validator = {
696
+ StructTagV2 .date_time : cls ._build_generic_verifier (
697
+ (int , int , int ,), "DateTime"
698
+ ),
699
+ StructTagV2 .date_time_zone_id : cls ._build_generic_verifier (
700
+ (int , int , str ), "DateTimeZoneId"
701
+ ),
702
+ }
703
+
704
+ if tag in field_validator :
705
+ return field_validator [tag ](structure , fields )
672
706
return super ().verify_fields (structure )
673
707
674
708
0 commit comments