18
18
import typing
19
19
import traceback
20
20
21
- from restate .context import DurablePromise , ObjectContext , Request
21
+ from restate .context import DurablePromise , ObjectContext , Request , SendHandle
22
22
from restate .exceptions import TerminalError
23
23
from restate .handler import Handler , handler_from_callable , invoke_handler
24
24
from restate .serde import BytesSerde , JsonSerde , Serde
36
36
# disable line too long
37
37
# pylint: disable=C0301
38
38
39
+ # disable too few public methods
40
+ # pylint: disable=R0903
41
+
42
+ class ServerSendHandle (SendHandle ):
43
+ """This class implements the send API"""
44
+ _invocation_id : typing .Optional [str ]
45
+
46
+ def __init__ (self , context , handle : int ) -> None :
47
+ self .handle = handle
48
+ self .context = context
49
+ self ._invocation_id = None
50
+
51
+ async def invocation_id (self ) -> str :
52
+ """Get the invocation id."""
53
+ if self ._invocation_id is not None :
54
+ return self ._invocation_id
55
+ res = await self .context .create_poll_or_cancel_coroutine (self .handle )
56
+ self ._invocation_id = res
57
+ return res
58
+
59
+
39
60
async def async_value (n : Callable [[], T ]) -> T :
40
61
"""convert a simple value to a coroutine."""
41
62
return n ()
@@ -317,7 +338,7 @@ def do_call(self,
317
338
parameter : I ,
318
339
key : Optional [str ] = None ,
319
340
send_delay : Optional [timedelta ] = None ,
320
- send : bool = False ) -> Awaitable [O ] | None :
341
+ send : bool = False ) -> Awaitable [O ] | SendHandle :
321
342
"""Make an RPC call to the given handler"""
322
343
target_handler = handler_from_callable (tpe )
323
344
service = target_handler .service_tag .name
@@ -335,16 +356,16 @@ def do_raw_call(self,
335
356
output_serde : Serde [O ],
336
357
key : Optional [str ] = None ,
337
358
send_delay : Optional [timedelta ] = None ,
338
- send : bool = False ) -> Awaitable [O ] | None :
359
+ send : bool = False ) -> Awaitable [O ] | SendHandle :
339
360
"""Make an RPC call to the given handler"""
340
361
parameter = input_serde .serialize (input_param )
341
362
if send_delay :
342
363
ms = int (send_delay .total_seconds () * 1000 )
343
- self .vm .sys_send (service , handler , parameter , key , delay = ms )
344
- return None
364
+ send_handle = self .vm .sys_send (service , handler , parameter , key , delay = ms )
365
+ return ServerSendHandle ( self , send_handle )
345
366
if send :
346
- self .vm .sys_send (service , handler , parameter , key )
347
- return None
367
+ send_handle = self .vm .sys_send (service , handler , parameter , key )
368
+ return ServerSendHandle ( self , send_handle )
348
369
349
370
handle = self .vm .sys_call (service = service ,
350
371
handler = handler ,
@@ -362,11 +383,13 @@ def service_call(self,
362
383
tpe : Callable [[Any , I ], Awaitable [O ]],
363
384
arg : I ) -> Awaitable [O ]:
364
385
coro = self .do_call (tpe , arg )
365
- assert coro is not None
366
- return coro
386
+ assert coro is not SendHandle
387
+ return coro # type: ignore
367
388
368
- def service_send (self , tpe : Callable [[Any , I ], Awaitable [O ]], arg : I , send_delay : timedelta | None = None ) -> None :
369
- self .do_call (tpe = tpe , parameter = arg , send_delay = send_delay , send = True )
389
+ def service_send (self , tpe : Callable [[Any , I ], Awaitable [O ]], arg : I , send_delay : timedelta | None = None ) -> SendHandle :
390
+ send = self .do_call (tpe = tpe , parameter = arg , send_delay = send_delay , send = True )
391
+ assert send is SendHandle
392
+ return send # type: ignore
370
393
371
394
def object_call (self ,
372
395
tpe : Callable [[Any , I ],Awaitable [O ]],
@@ -375,26 +398,30 @@ def object_call(self,
375
398
send_delay : Optional [timedelta ] = None ,
376
399
send : bool = False ) -> Awaitable [O ]:
377
400
coro = self .do_call (tpe , arg , key , send_delay , send )
378
- assert coro is not None
379
- return coro
401
+ assert coro is not SendHandle
402
+ return coro # type: ignore
380
403
381
- def object_send (self , tpe : Callable [[Any , I ], Awaitable [O ]], key : str , arg : I , send_delay : timedelta | None = None ) -> None :
382
- self .do_call (tpe = tpe , key = key , parameter = arg , send_delay = send_delay , send = True )
404
+ def object_send (self , tpe : Callable [[Any , I ], Awaitable [O ]], key : str , arg : I , send_delay : timedelta | None = None ) -> SendHandle :
405
+ send = self .do_call (tpe = tpe , key = key , parameter = arg , send_delay = send_delay , send = True )
406
+ assert send is SendHandle
407
+ return send # type: ignore
383
408
384
409
def workflow_call (self ,
385
410
tpe : Callable [[Any , I ], Awaitable [O ]],
386
411
key : str ,
387
412
arg : I ) -> Awaitable [O ]:
388
413
return self .object_call (tpe , key , arg )
389
414
390
- def workflow_send (self , tpe : Callable [[Any , I ], Awaitable [O ]], key : str , arg : I , send_delay : timedelta | None = None ) -> None :
391
- return self .object_send (tpe , key , arg , send_delay )
415
+ def workflow_send (self , tpe : Callable [[Any , I ], Awaitable [O ]], key : str , arg : I , send_delay : timedelta | None = None ) -> SendHandle :
416
+ send = self .object_send (tpe , key , arg , send_delay )
417
+ assert send is SendHandle
418
+ return send # type: ignore
392
419
393
420
def generic_call (self , service : str , handler : str , arg : bytes , key : str | None = None ) -> Awaitable [bytes ]:
394
421
serde = BytesSerde ()
395
422
return self .do_raw_call (service , handler , arg , serde , serde , key ) # type: ignore
396
423
397
- def generic_send (self , service : str , handler : str , arg : bytes , key : str | None = None , send_delay : timedelta | None = None ) -> None :
424
+ def generic_send (self , service : str , handler : str , arg : bytes , key : str | None = None , send_delay : timedelta | None = None ) -> SendHandle :
398
425
serde = BytesSerde ()
399
426
return self .do_raw_call (service , handler , arg , serde , serde , key , send_delay , True ) # type: ignore
400
427
0 commit comments