Skip to content

Commit 168c4a5

Browse files
author
hhsecond
committed
pipeline
1 parent 9ffe297 commit 168c4a5

File tree

3 files changed

+199
-94
lines changed

3 files changed

+199
-94
lines changed

redisai/client.py

Lines changed: 161 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import warnings
44

55
from redis import StrictRedis
6+
from redis.client import Pipeline as RedisPipeline
67
import numpy as np
78

89
from . import command_builder as builder
@@ -12,78 +13,6 @@
1213
processor = Processor()
1314

1415

15-
def enable_debug(f):
16-
@wraps(f)
17-
def wrapper(*args):
18-
print(*args)
19-
return f(*args)
20-
return wrapper
21-
22-
23-
class Dag:
24-
def __init__(self, load, persist, executor, readonly=False):
25-
self.result_processors = []
26-
if readonly:
27-
if persist:
28-
raise RuntimeError("READONLY requests cannot write (duh!) and should not "
29-
"have PERSISTing values")
30-
self.commands = ['AI.DAGRUN_RO']
31-
else:
32-
self.commands = ['AI.DAGRUN']
33-
if load:
34-
if not isinstance(load, (list, tuple)):
35-
self.commands += ["LOAD", 1, load]
36-
else:
37-
self.commands += ["LOAD", len(load), *load]
38-
if persist:
39-
if not isinstance(persist, (list, tuple)):
40-
self.commands += ["PERSIST", 1, persist, '|>']
41-
else:
42-
self.commands += ["PERSIST", len(persist), *persist, '|>']
43-
elif load:
44-
self.commands.append('|>')
45-
self.executor = executor
46-
47-
def tensorset(self,
48-
key: AnyStr,
49-
tensor: Union[np.ndarray, list, tuple],
50-
shape: Sequence[int] = None,
51-
dtype: str = None) -> Any:
52-
args = builder.tensorset(key, tensor, shape, dtype)
53-
self.commands.extend(args)
54-
self.commands.append("|>")
55-
self.result_processors.append(bytes.decode)
56-
return self
57-
58-
def tensorget(self,
59-
key: AnyStr, as_numpy: bool = True,
60-
meta_only: bool = False) -> Any:
61-
args = builder.tensorget(key, as_numpy, meta_only)
62-
self.commands.extend(args)
63-
self.commands.append("|>")
64-
self.result_processors.append(partial(processor.tensorget,
65-
as_numpy=as_numpy,
66-
meta_only=meta_only))
67-
return self
68-
69-
def modelrun(self,
70-
key: AnyStr,
71-
inputs: Union[AnyStr, List[AnyStr]],
72-
outputs: Union[AnyStr, List[AnyStr]]) -> Any:
73-
args = builder.modelrun(key, inputs, outputs)
74-
self.commands.extend(args)
75-
self.commands.append("|>")
76-
self.result_processors.append(bytes.decode)
77-
return self
78-
79-
def run(self):
80-
results = self.executor(*self.commands)
81-
out = []
82-
for res, fn in zip(results, self.result_processors):
83-
out.append(fn(res))
84-
return out
85-
86-
8716
class Client(StrictRedis):
8817
"""
8918
Redis client build specifically for the RedisAI module. It takes all the necessary
@@ -96,20 +25,47 @@ class Client(StrictRedis):
9625
debug : bool
9726
If debug mode is ON, then each command that is sent to the server is
9827
printed to the terminal
28+
enable_postprocess : bool
29+
Flag to enable post processing. If enabled, all the bytestring-ed returns
30+
are converted to python strings recursively and key value pairs will be converted
31+
to dictionaries. Also note that, this flag doesn't work with pipeline() function
32+
since pipeline function could have native redis commands (along with RedisAI
33+
commands)
9934
10035
Example
10136
-------
10237
>>> from redisai import Client
10338
>>> con = Client(host='localhost', port=6379)
10439
"""
105-
def __init__(self, debug=False, *args, **kwargs):
40+
def __init__(self, debug=False, enable_postprocess=True, *args, **kwargs):
10641
super().__init__(*args, **kwargs)
10742
if debug:
10843
self.execute_command = enable_debug(super().execute_command)
44+
self.enable_postprocess = enable_postprocess
45+
46+
def pipeline(self, transaction: bool = True, shard_hint: bool = None) -> 'Pipeline':
47+
"""
48+
It follows the same pipeline implementation of native redis client but enables it
49+
to access redisai operation as well. This function is experimental in the
50+
current release.
51+
52+
Example
53+
-------
54+
>>> pipe = con.pipeline(transaction=False)
55+
>>> pipe = pipe.set('nativeKey', 1)
56+
>>> pipe = pipe.tensorset('redisaiKey', np.array([1, 2]))
57+
>>> pipe.execute()
58+
[True, b'OK']
59+
"""
60+
return Pipeline(self.enable_postprocess,
61+
self.connection_pool,
62+
self.response_callbacks,
63+
transaction=True, shard_hint=None)
10964

11065
def dag(self, load: Sequence = None, persist: Sequence = None,
111-
readonly: bool = False) -> Dag:
112-
""" It returns a DAG object on which other DAG-allowed operations can be called. For
66+
readonly: bool = False) -> 'Dag':
67+
"""
68+
It returns a DAG object on which other DAG-allowed operations can be called. For
11369
more details about DAG in RedisAI, refer to the RedisAI documentation.
11470
11571
Parameters
@@ -141,7 +97,7 @@ def dag(self, load: Sequence = None, persist: Sequence = None,
14197
>>> # You can even chain the operations
14298
>>> result = dag.tensorset(**akwargs).modelrun(**bkwargs).tensorget(**ckwargs).run()
14399
"""
144-
return Dag(load, persist, self.execute_command, readonly)
100+
return Dag(load, persist, self.execute_command, readonly, self.enable_postprocess)
145101

146102
def loadbackend(self, identifier: AnyStr, path: AnyStr) -> str:
147103
"""
@@ -168,7 +124,7 @@ def loadbackend(self, identifier: AnyStr, path: AnyStr) -> str:
168124
"""
169125
args = builder.loadbackend(identifier, path)
170126
res = self.execute_command(*args)
171-
return processor.loadbackend(res)
127+
return res if not self.enable_postprocess else processor.loadbackend(res)
172128

173129
def modelset(self,
174130
key: AnyStr,
@@ -227,7 +183,7 @@ def modelset(self,
227183
args = builder.modelset(key, backend, device, data,
228184
batch, minbatch, tag, inputs, outputs)
229185
res = self.execute_command(*args)
230-
return processor.modelset(res)
186+
return res if not self.enable_postprocess else processor.modelset(res)
231187

232188
def modelget(self, key: AnyStr, meta_only=False) -> dict:
233189
"""
@@ -253,7 +209,7 @@ def modelget(self, key: AnyStr, meta_only=False) -> dict:
253209
"""
254210
args = builder.modelget(key, meta_only)
255211
res = self.execute_command(*args)
256-
return processor.modelget(res)
212+
return res if not self.enable_postprocess else processor.modelget(res)
257213

258214
def modeldel(self, key: AnyStr) -> str:
259215
"""
@@ -276,7 +232,7 @@ def modeldel(self, key: AnyStr) -> str:
276232
"""
277233
args = builder.modeldel(key)
278234
res = self.execute_command(*args)
279-
return processor.modeldel(res)
235+
return res if not self.enable_postprocess else processor.modeldel(res)
280236

281237
def modelrun(self,
282238
key: AnyStr,
@@ -318,7 +274,7 @@ def modelrun(self,
318274
"""
319275
args = builder.modelrun(key, inputs, outputs)
320276
res = self.execute_command(*args)
321-
return processor.modelrun(res)
277+
return res if not self.enable_postprocess else processor.modelrun(res)
322278

323279
def modelscan(self) -> List[List[AnyStr]]:
324280
"""
@@ -340,7 +296,7 @@ def modelscan(self) -> List[List[AnyStr]]:
340296
"in the future without any notice", UserWarning)
341297
args = builder.modelscan()
342298
res = self.execute_command(*args)
343-
return processor.modelscan(res)
299+
return res if not self.enable_postprocess else processor.modelscan(res)
344300

345301
def tensorset(self,
346302
key: AnyStr,
@@ -376,7 +332,7 @@ def tensorset(self,
376332
"""
377333
args = builder.tensorset(key, tensor, shape, dtype)
378334
res = self.execute_command(*args)
379-
return processor.tensorset(res)
335+
return res if not self.enable_postprocess else processor.tensorset(res)
380336

381337
def tensorget(self,
382338
key: AnyStr, as_numpy: bool = True,
@@ -412,7 +368,8 @@ def tensorget(self,
412368
"""
413369
args = builder.tensorget(key, as_numpy, meta_only)
414370
res = self.execute_command(*args)
415-
return processor.tensorget(res, as_numpy, meta_only)
371+
return res if not self.enable_postprocess else processor.tensorget(res,
372+
as_numpy, meta_only)
416373

417374
def scriptset(self, key: AnyStr, device: str, script: str, tag: AnyStr = None) -> str:
418375
"""
@@ -456,7 +413,7 @@ def scriptset(self, key: AnyStr, device: str, script: str, tag: AnyStr = None) -
456413
"""
457414
args = builder.scriptset(key, device, script, tag)
458415
res = self.execute_command(*args)
459-
return processor.scriptset(res)
416+
return res if not self.enable_postprocess else processor.scriptset(res)
460417

461418
def scriptget(self, key: AnyStr, meta_only=False) -> dict:
462419
"""
@@ -481,7 +438,7 @@ def scriptget(self, key: AnyStr, meta_only=False) -> dict:
481438
"""
482439
args = builder.scriptget(key, meta_only)
483440
res = self.execute_command(*args)
484-
return processor.scriptget(res)
441+
return res if not self.enable_postprocess else processor.scriptget(res)
485442

486443
def scriptdel(self, key: AnyStr) -> str:
487444
"""
@@ -504,7 +461,7 @@ def scriptdel(self, key: AnyStr) -> str:
504461
"""
505462
args = builder.scriptdel(key)
506463
res = self.execute_command(*args)
507-
return processor.scriptdel(res)
464+
return res if not self.enable_postprocess else processor.scriptdel(res)
508465

509466
def scriptrun(self,
510467
key: AnyStr,
@@ -540,7 +497,7 @@ def scriptrun(self,
540497
"""
541498
args = builder.scriptrun(key, function, inputs, outputs)
542499
res = self.execute_command(*args)
543-
return processor.scriptrun(res)
500+
return res if not self.enable_postprocess else processor.scriptrun(res)
544501

545502
def scriptscan(self) -> List[List[AnyStr]]:
546503
"""
@@ -561,7 +518,7 @@ def scriptscan(self) -> List[List[AnyStr]]:
561518
"in the future without any notice", UserWarning)
562519
args = builder.scriptscan()
563520
res = self.execute_command(*args)
564-
return processor.scriptscan(res)
521+
return res if not self.enable_postprocess else processor.scriptscan(res)
565522

566523
def infoget(self, key: AnyStr) -> dict:
567524
"""
@@ -590,7 +547,7 @@ def infoget(self, key: AnyStr) -> dict:
590547
"""
591548
args = builder.infoget(key)
592549
res = self.execute_command(*args)
593-
return processor.infoget(res)
550+
return res if not self.enable_postprocess else processor.infoget(res)
594551

595552
def inforeset(self, key: AnyStr) -> str:
596553
"""
@@ -613,4 +570,117 @@ def inforeset(self, key: AnyStr) -> str:
613570
"""
614571
args = builder.inforeset(key)
615572
res = self.execute_command(*args)
616-
return processor.inforeset(res)
573+
return res if not self.enable_postprocess else processor.inforeset(res)
574+
575+
576+
class Pipeline(RedisPipeline, Client):
577+
def __init__(self, enable_postprocess, *args, **kwargs):
578+
warnings.warn("Pipeling AI commands through this client is experimental.",
579+
UserWarning)
580+
self.enable_postprocess = False
581+
if enable_postprocess:
582+
warnings.warn("Postprocessing is enabled but not allowed in pipelines."
583+
"Disable postprocessing to remove this warning.", UserWarning)
584+
self.tensorget_processors = []
585+
super().__init__(*args, **kwargs)
586+
587+
def dag(self, *args, **kwargs):
588+
raise RuntimeError("Pipeline object doesn't allow DAG creation currently")
589+
590+
def tensorget(self, key, as_numpy=True, meta_only=False):
591+
self.tensorget_processors.append(partial(processor.tensorget,
592+
as_numpy=as_numpy,
593+
meta_only=meta_only))
594+
return super().tensorget(key, as_numpy, meta_only)
595+
596+
def _execute_transaction(self, *args, **kwargs):
597+
res = super()._execute_transaction(*args, **kwargs)
598+
for i in range(len(res)):
599+
# tensorget will have minimum 4 values if meta_only = True
600+
if isinstance(res[i], list) and len(res[i]) >= 4:
601+
res[i] = self.tensorget_processors.pop(0)(res[i])
602+
return res
603+
604+
def _execute_pipeline(self, *args, **kwargs):
605+
res = super()._execute_pipeline(*args, **kwargs)
606+
for i in range(len(res)):
607+
# tensorget will have minimum 4 values if meta_only = True
608+
if isinstance(res[i], list) and len(res[i]) >= 4:
609+
res[i] = self.tensorget_processors.pop(0)(res[i])
610+
return res
611+
612+
613+
class Dag:
614+
def __init__(self, load, persist, executor, readonly=False, postprocess=True):
615+
self.result_processors = []
616+
self.enable_postprocess = True
617+
if readonly:
618+
if persist:
619+
raise RuntimeError("READONLY requests cannot write (duh!) and should not "
620+
"have PERSISTing values")
621+
self.commands = ['AI.DAGRUN_RO']
622+
else:
623+
self.commands = ['AI.DAGRUN']
624+
if load:
625+
if not isinstance(load, (list, tuple)):
626+
self.commands += ["LOAD", 1, load]
627+
else:
628+
self.commands += ["LOAD", len(load), *load]
629+
if persist:
630+
if not isinstance(persist, (list, tuple)):
631+
self.commands += ["PERSIST", 1, persist, '|>']
632+
else:
633+
self.commands += ["PERSIST", len(persist), *persist, '|>']
634+
elif load:
635+
self.commands.append('|>')
636+
self.executor = executor
637+
638+
def tensorset(self,
639+
key: AnyStr,
640+
tensor: Union[np.ndarray, list, tuple],
641+
shape: Sequence[int] = None,
642+
dtype: str = None) -> Any:
643+
args = builder.tensorset(key, tensor, shape, dtype)
644+
self.commands.extend(args)
645+
self.commands.append("|>")
646+
self.result_processors.append(bytes.decode)
647+
return self
648+
649+
def tensorget(self,
650+
key: AnyStr, as_numpy: bool = True,
651+
meta_only: bool = False) -> Any:
652+
args = builder.tensorget(key, as_numpy, meta_only)
653+
self.commands.extend(args)
654+
self.commands.append("|>")
655+
self.result_processors.append(partial(processor.tensorget,
656+
as_numpy=as_numpy,
657+
meta_only=meta_only))
658+
return self
659+
660+
def modelrun(self,
661+
key: AnyStr,
662+
inputs: Union[AnyStr, List[AnyStr]],
663+
outputs: Union[AnyStr, List[AnyStr]]) -> Any:
664+
args = builder.modelrun(key, inputs, outputs)
665+
self.commands.extend(args)
666+
self.commands.append("|>")
667+
self.result_processors.append(bytes.decode)
668+
return self
669+
670+
def run(self):
671+
results = self.executor(*self.commands)
672+
if self.enable_postprocess:
673+
out = []
674+
for res, fn in zip(results, self.result_processors):
675+
out.append(fn(res))
676+
else:
677+
out = results
678+
return out
679+
680+
681+
def enable_debug(f):
682+
@wraps(f)
683+
def wrapper(*args):
684+
print(*args)
685+
return f(*args)
686+
return wrapper

0 commit comments

Comments
 (0)