3
3
import warnings
4
4
5
5
from redis import StrictRedis
6
+ from redis .client import Pipeline as RedisPipeline
6
7
import numpy as np
7
8
8
9
from . import command_builder as builder
12
13
processor = Processor ()
13
14
14
15
15
- def enable_debug (f ):
16
- @wraps (f )
17
- def wrapper (* args ):
18
- print (* args )
19
- return f (* args )
20
- return wrapper
21
-
22
-
23
- class Dag :
24
- def __init__ (self , load , persist , executor , readonly = False ):
25
- self .result_processors = []
26
- if readonly :
27
- if persist :
28
- raise RuntimeError ("READONLY requests cannot write (duh!) and should not "
29
- "have PERSISTing values" )
30
- self .commands = ['AI.DAGRUN_RO' ]
31
- else :
32
- self .commands = ['AI.DAGRUN' ]
33
- if load :
34
- if not isinstance (load , (list , tuple )):
35
- self .commands += ["LOAD" , 1 , load ]
36
- else :
37
- self .commands += ["LOAD" , len (load ), * load ]
38
- if persist :
39
- if not isinstance (persist , (list , tuple )):
40
- self .commands += ["PERSIST" , 1 , persist , '|>' ]
41
- else :
42
- self .commands += ["PERSIST" , len (persist ), * persist , '|>' ]
43
- elif load :
44
- self .commands .append ('|>' )
45
- self .executor = executor
46
-
47
- def tensorset (self ,
48
- key : AnyStr ,
49
- tensor : Union [np .ndarray , list , tuple ],
50
- shape : Sequence [int ] = None ,
51
- dtype : str = None ) -> Any :
52
- args = builder .tensorset (key , tensor , shape , dtype )
53
- self .commands .extend (args )
54
- self .commands .append ("|>" )
55
- self .result_processors .append (bytes .decode )
56
- return self
57
-
58
- def tensorget (self ,
59
- key : AnyStr , as_numpy : bool = True ,
60
- meta_only : bool = False ) -> Any :
61
- args = builder .tensorget (key , as_numpy , meta_only )
62
- self .commands .extend (args )
63
- self .commands .append ("|>" )
64
- self .result_processors .append (partial (processor .tensorget ,
65
- as_numpy = as_numpy ,
66
- meta_only = meta_only ))
67
- return self
68
-
69
- def modelrun (self ,
70
- key : AnyStr ,
71
- inputs : Union [AnyStr , List [AnyStr ]],
72
- outputs : Union [AnyStr , List [AnyStr ]]) -> Any :
73
- args = builder .modelrun (key , inputs , outputs )
74
- self .commands .extend (args )
75
- self .commands .append ("|>" )
76
- self .result_processors .append (bytes .decode )
77
- return self
78
-
79
- def run (self ):
80
- results = self .executor (* self .commands )
81
- out = []
82
- for res , fn in zip (results , self .result_processors ):
83
- out .append (fn (res ))
84
- return out
85
-
86
-
87
16
class Client (StrictRedis ):
88
17
"""
89
18
Redis client build specifically for the RedisAI module. It takes all the necessary
@@ -96,20 +25,47 @@ class Client(StrictRedis):
96
25
debug : bool
97
26
If debug mode is ON, then each command that is sent to the server is
98
27
printed to the terminal
28
+ enable_postprocess : bool
29
+ Flag to enable post processing. If enabled, all the bytestring-ed returns
30
+ are converted to python strings recursively and key value pairs will be converted
31
+ to dictionaries. Also note that, this flag doesn't work with pipeline() function
32
+ since pipeline function could have native redis commands (along with RedisAI
33
+ commands)
99
34
100
35
Example
101
36
-------
102
37
>>> from redisai import Client
103
38
>>> con = Client(host='localhost', port=6379)
104
39
"""
105
- def __init__ (self , debug = False , * args , ** kwargs ):
40
+ def __init__ (self , debug = False , enable_postprocess = True , * args , ** kwargs ):
106
41
super ().__init__ (* args , ** kwargs )
107
42
if debug :
108
43
self .execute_command = enable_debug (super ().execute_command )
44
+ self .enable_postprocess = enable_postprocess
45
+
46
+ def pipeline (self , transaction : bool = True , shard_hint : bool = None ) -> 'Pipeline' :
47
+ """
48
+ It follows the same pipeline implementation of native redis client but enables it
49
+ to access redisai operation as well. This function is experimental in the
50
+ current release.
51
+
52
+ Example
53
+ -------
54
+ >>> pipe = con.pipeline(transaction=False)
55
+ >>> pipe = pipe.set('nativeKey', 1)
56
+ >>> pipe = pipe.tensorset('redisaiKey', np.array([1, 2]))
57
+ >>> pipe.execute()
58
+ [True, b'OK']
59
+ """
60
+ return Pipeline (self .enable_postprocess ,
61
+ self .connection_pool ,
62
+ self .response_callbacks ,
63
+ transaction = True , shard_hint = None )
109
64
110
65
def dag (self , load : Sequence = None , persist : Sequence = None ,
111
- readonly : bool = False ) -> Dag :
112
- """ It returns a DAG object on which other DAG-allowed operations can be called. For
66
+ readonly : bool = False ) -> 'Dag' :
67
+ """
68
+ It returns a DAG object on which other DAG-allowed operations can be called. For
113
69
more details about DAG in RedisAI, refer to the RedisAI documentation.
114
70
115
71
Parameters
@@ -141,7 +97,7 @@ def dag(self, load: Sequence = None, persist: Sequence = None,
141
97
>>> # You can even chain the operations
142
98
>>> result = dag.tensorset(**akwargs).modelrun(**bkwargs).tensorget(**ckwargs).run()
143
99
"""
144
- return Dag (load , persist , self .execute_command , readonly )
100
+ return Dag (load , persist , self .execute_command , readonly , self . enable_postprocess )
145
101
146
102
def loadbackend (self , identifier : AnyStr , path : AnyStr ) -> str :
147
103
"""
@@ -168,7 +124,7 @@ def loadbackend(self, identifier: AnyStr, path: AnyStr) -> str:
168
124
"""
169
125
args = builder .loadbackend (identifier , path )
170
126
res = self .execute_command (* args )
171
- return processor .loadbackend (res )
127
+ return res if not self . enable_postprocess else processor .loadbackend (res )
172
128
173
129
def modelset (self ,
174
130
key : AnyStr ,
@@ -227,7 +183,7 @@ def modelset(self,
227
183
args = builder .modelset (key , backend , device , data ,
228
184
batch , minbatch , tag , inputs , outputs )
229
185
res = self .execute_command (* args )
230
- return processor .modelset (res )
186
+ return res if not self . enable_postprocess else processor .modelset (res )
231
187
232
188
def modelget (self , key : AnyStr , meta_only = False ) -> dict :
233
189
"""
@@ -253,7 +209,7 @@ def modelget(self, key: AnyStr, meta_only=False) -> dict:
253
209
"""
254
210
args = builder .modelget (key , meta_only )
255
211
res = self .execute_command (* args )
256
- return processor .modelget (res )
212
+ return res if not self . enable_postprocess else processor .modelget (res )
257
213
258
214
def modeldel (self , key : AnyStr ) -> str :
259
215
"""
@@ -276,7 +232,7 @@ def modeldel(self, key: AnyStr) -> str:
276
232
"""
277
233
args = builder .modeldel (key )
278
234
res = self .execute_command (* args )
279
- return processor .modeldel (res )
235
+ return res if not self . enable_postprocess else processor .modeldel (res )
280
236
281
237
def modelrun (self ,
282
238
key : AnyStr ,
@@ -318,7 +274,7 @@ def modelrun(self,
318
274
"""
319
275
args = builder .modelrun (key , inputs , outputs )
320
276
res = self .execute_command (* args )
321
- return processor .modelrun (res )
277
+ return res if not self . enable_postprocess else processor .modelrun (res )
322
278
323
279
def modelscan (self ) -> List [List [AnyStr ]]:
324
280
"""
@@ -340,7 +296,7 @@ def modelscan(self) -> List[List[AnyStr]]:
340
296
"in the future without any notice" , UserWarning )
341
297
args = builder .modelscan ()
342
298
res = self .execute_command (* args )
343
- return processor .modelscan (res )
299
+ return res if not self . enable_postprocess else processor .modelscan (res )
344
300
345
301
def tensorset (self ,
346
302
key : AnyStr ,
@@ -376,7 +332,7 @@ def tensorset(self,
376
332
"""
377
333
args = builder .tensorset (key , tensor , shape , dtype )
378
334
res = self .execute_command (* args )
379
- return processor .tensorset (res )
335
+ return res if not self . enable_postprocess else processor .tensorset (res )
380
336
381
337
def tensorget (self ,
382
338
key : AnyStr , as_numpy : bool = True ,
@@ -412,7 +368,8 @@ def tensorget(self,
412
368
"""
413
369
args = builder .tensorget (key , as_numpy , meta_only )
414
370
res = self .execute_command (* args )
415
- return processor .tensorget (res , as_numpy , meta_only )
371
+ return res if not self .enable_postprocess else processor .tensorget (res ,
372
+ as_numpy , meta_only )
416
373
417
374
def scriptset (self , key : AnyStr , device : str , script : str , tag : AnyStr = None ) -> str :
418
375
"""
@@ -456,7 +413,7 @@ def scriptset(self, key: AnyStr, device: str, script: str, tag: AnyStr = None) -
456
413
"""
457
414
args = builder .scriptset (key , device , script , tag )
458
415
res = self .execute_command (* args )
459
- return processor .scriptset (res )
416
+ return res if not self . enable_postprocess else processor .scriptset (res )
460
417
461
418
def scriptget (self , key : AnyStr , meta_only = False ) -> dict :
462
419
"""
@@ -481,7 +438,7 @@ def scriptget(self, key: AnyStr, meta_only=False) -> dict:
481
438
"""
482
439
args = builder .scriptget (key , meta_only )
483
440
res = self .execute_command (* args )
484
- return processor .scriptget (res )
441
+ return res if not self . enable_postprocess else processor .scriptget (res )
485
442
486
443
def scriptdel (self , key : AnyStr ) -> str :
487
444
"""
@@ -504,7 +461,7 @@ def scriptdel(self, key: AnyStr) -> str:
504
461
"""
505
462
args = builder .scriptdel (key )
506
463
res = self .execute_command (* args )
507
- return processor .scriptdel (res )
464
+ return res if not self . enable_postprocess else processor .scriptdel (res )
508
465
509
466
def scriptrun (self ,
510
467
key : AnyStr ,
@@ -540,7 +497,7 @@ def scriptrun(self,
540
497
"""
541
498
args = builder .scriptrun (key , function , inputs , outputs )
542
499
res = self .execute_command (* args )
543
- return processor .scriptrun (res )
500
+ return res if not self . enable_postprocess else processor .scriptrun (res )
544
501
545
502
def scriptscan (self ) -> List [List [AnyStr ]]:
546
503
"""
@@ -561,7 +518,7 @@ def scriptscan(self) -> List[List[AnyStr]]:
561
518
"in the future without any notice" , UserWarning )
562
519
args = builder .scriptscan ()
563
520
res = self .execute_command (* args )
564
- return processor .scriptscan (res )
521
+ return res if not self . enable_postprocess else processor .scriptscan (res )
565
522
566
523
def infoget (self , key : AnyStr ) -> dict :
567
524
"""
@@ -590,7 +547,7 @@ def infoget(self, key: AnyStr) -> dict:
590
547
"""
591
548
args = builder .infoget (key )
592
549
res = self .execute_command (* args )
593
- return processor .infoget (res )
550
+ return res if not self . enable_postprocess else processor .infoget (res )
594
551
595
552
def inforeset (self , key : AnyStr ) -> str :
596
553
"""
@@ -613,4 +570,117 @@ def inforeset(self, key: AnyStr) -> str:
613
570
"""
614
571
args = builder .inforeset (key )
615
572
res = self .execute_command (* args )
616
- return processor .inforeset (res )
573
+ return res if not self .enable_postprocess else processor .inforeset (res )
574
+
575
+
576
+ class Pipeline (RedisPipeline , Client ):
577
+ def __init__ (self , enable_postprocess , * args , ** kwargs ):
578
+ warnings .warn ("Pipeling AI commands through this client is experimental." ,
579
+ UserWarning )
580
+ self .enable_postprocess = False
581
+ if enable_postprocess :
582
+ warnings .warn ("Postprocessing is enabled but not allowed in pipelines."
583
+ "Disable postprocessing to remove this warning." , UserWarning )
584
+ self .tensorget_processors = []
585
+ super ().__init__ (* args , ** kwargs )
586
+
587
+ def dag (self , * args , ** kwargs ):
588
+ raise RuntimeError ("Pipeline object doesn't allow DAG creation currently" )
589
+
590
+ def tensorget (self , key , as_numpy = True , meta_only = False ):
591
+ self .tensorget_processors .append (partial (processor .tensorget ,
592
+ as_numpy = as_numpy ,
593
+ meta_only = meta_only ))
594
+ return super ().tensorget (key , as_numpy , meta_only )
595
+
596
+ def _execute_transaction (self , * args , ** kwargs ):
597
+ res = super ()._execute_transaction (* args , ** kwargs )
598
+ for i in range (len (res )):
599
+ # tensorget will have minimum 4 values if meta_only = True
600
+ if isinstance (res [i ], list ) and len (res [i ]) >= 4 :
601
+ res [i ] = self .tensorget_processors .pop (0 )(res [i ])
602
+ return res
603
+
604
+ def _execute_pipeline (self , * args , ** kwargs ):
605
+ res = super ()._execute_pipeline (* args , ** kwargs )
606
+ for i in range (len (res )):
607
+ # tensorget will have minimum 4 values if meta_only = True
608
+ if isinstance (res [i ], list ) and len (res [i ]) >= 4 :
609
+ res [i ] = self .tensorget_processors .pop (0 )(res [i ])
610
+ return res
611
+
612
+
613
+ class Dag :
614
+ def __init__ (self , load , persist , executor , readonly = False , postprocess = True ):
615
+ self .result_processors = []
616
+ self .enable_postprocess = True
617
+ if readonly :
618
+ if persist :
619
+ raise RuntimeError ("READONLY requests cannot write (duh!) and should not "
620
+ "have PERSISTing values" )
621
+ self .commands = ['AI.DAGRUN_RO' ]
622
+ else :
623
+ self .commands = ['AI.DAGRUN' ]
624
+ if load :
625
+ if not isinstance (load , (list , tuple )):
626
+ self .commands += ["LOAD" , 1 , load ]
627
+ else :
628
+ self .commands += ["LOAD" , len (load ), * load ]
629
+ if persist :
630
+ if not isinstance (persist , (list , tuple )):
631
+ self .commands += ["PERSIST" , 1 , persist , '|>' ]
632
+ else :
633
+ self .commands += ["PERSIST" , len (persist ), * persist , '|>' ]
634
+ elif load :
635
+ self .commands .append ('|>' )
636
+ self .executor = executor
637
+
638
+ def tensorset (self ,
639
+ key : AnyStr ,
640
+ tensor : Union [np .ndarray , list , tuple ],
641
+ shape : Sequence [int ] = None ,
642
+ dtype : str = None ) -> Any :
643
+ args = builder .tensorset (key , tensor , shape , dtype )
644
+ self .commands .extend (args )
645
+ self .commands .append ("|>" )
646
+ self .result_processors .append (bytes .decode )
647
+ return self
648
+
649
+ def tensorget (self ,
650
+ key : AnyStr , as_numpy : bool = True ,
651
+ meta_only : bool = False ) -> Any :
652
+ args = builder .tensorget (key , as_numpy , meta_only )
653
+ self .commands .extend (args )
654
+ self .commands .append ("|>" )
655
+ self .result_processors .append (partial (processor .tensorget ,
656
+ as_numpy = as_numpy ,
657
+ meta_only = meta_only ))
658
+ return self
659
+
660
+ def modelrun (self ,
661
+ key : AnyStr ,
662
+ inputs : Union [AnyStr , List [AnyStr ]],
663
+ outputs : Union [AnyStr , List [AnyStr ]]) -> Any :
664
+ args = builder .modelrun (key , inputs , outputs )
665
+ self .commands .extend (args )
666
+ self .commands .append ("|>" )
667
+ self .result_processors .append (bytes .decode )
668
+ return self
669
+
670
+ def run (self ):
671
+ results = self .executor (* self .commands )
672
+ if self .enable_postprocess :
673
+ out = []
674
+ for res , fn in zip (results , self .result_processors ):
675
+ out .append (fn (res ))
676
+ else :
677
+ out = results
678
+ return out
679
+
680
+
681
+ def enable_debug (f ):
682
+ @wraps (f )
683
+ def wrapper (* args ):
684
+ print (* args )
685
+ return f (* args )
686
+ return wrapper
0 commit comments