Skip to content

Commit 24ae217

Browse files
committed
Add SendHandle
1 parent 7f78a2a commit 24ae217

File tree

3 files changed

+67
-25
lines changed

3 files changed

+67
-25
lines changed

python/restate/context.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,18 @@ def clear(self, name: str) -> None:
8181
def clear_all(self) -> None:
8282
"""clear all the values in the store."""
8383

84+
# pylint: disable=R0903
85+
class SendHandle(abc.ABC):
86+
"""
87+
Represents a send operation.
88+
"""
89+
90+
@abc.abstractmethod
91+
async def invocation_id(self) -> str:
92+
"""
93+
Returns the invocation id of the send operation.
94+
"""
95+
8496
class Context(abc.ABC):
8597
"""
8698
Represents the context of the current invocation.
@@ -133,7 +145,7 @@ def service_send(self,
133145
tpe: Callable[[Any, I], Awaitable[O]],
134146
arg: I,
135147
send_delay: Optional[timedelta] = None,
136-
) -> None:
148+
) -> SendHandle:
137149
"""
138150
Invokes the given service with the given argument.
139151
"""
@@ -153,7 +165,7 @@ def object_send(self,
153165
key: str,
154166
arg: I,
155167
send_delay: Optional[timedelta] = None,
156-
) -> None:
168+
) -> SendHandle:
157169
"""
158170
Send a message to an object with the given argument.
159171
"""
@@ -173,7 +185,7 @@ def workflow_send(self,
173185
key: str,
174186
arg: I,
175187
send_delay: Optional[timedelta] = None,
176-
) -> None:
188+
) -> SendHandle:
177189
"""
178190
Send a message to an object with the given argument.
179191
"""
@@ -195,7 +207,7 @@ def generic_send(self,
195207
handler: str,
196208
arg: bytes,
197209
key: Optional[str] = None,
198-
send_delay: Optional[timedelta] = None) -> None:
210+
send_delay: Optional[timedelta] = None) -> SendHandle:
199211
"""
200212
Send a message to a generic service/handler with the given argument.
201213
"""

python/restate/server_context.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import typing
1919
import traceback
2020

21-
from restate.context import DurablePromise, ObjectContext, Request
21+
from restate.context import DurablePromise, ObjectContext, Request, SendHandle
2222
from restate.exceptions import TerminalError
2323
from restate.handler import Handler, handler_from_callable, invoke_handler
2424
from restate.serde import BytesSerde, JsonSerde, Serde
@@ -36,6 +36,27 @@
3636
# disable line too long
3737
# pylint: disable=C0301
3838

39+
# disable too few public methods
40+
# pylint: disable=R0903
41+
42+
class ServerSendHandle(SendHandle):
43+
"""This class implements the send API"""
44+
_invocation_id: typing.Optional[str]
45+
46+
def __init__(self, context, handle: int) -> None:
47+
self.handle = handle
48+
self.context = context
49+
self._invocation_id = None
50+
51+
async def invocation_id(self) -> str:
52+
"""Get the invocation id."""
53+
if self._invocation_id is not None:
54+
return self._invocation_id
55+
res = await self.context.create_poll_or_cancel_coroutine(self.handle)
56+
self._invocation_id = res
57+
return res
58+
59+
3960
async def async_value(n: Callable[[], T]) -> T:
4061
"""convert a simple value to a coroutine."""
4162
return n()
@@ -317,7 +338,7 @@ def do_call(self,
317338
parameter: I,
318339
key: Optional[str] = None,
319340
send_delay: Optional[timedelta] = None,
320-
send: bool = False) -> Awaitable[O] | None:
341+
send: bool = False) -> Awaitable[O] | SendHandle:
321342
"""Make an RPC call to the given handler"""
322343
target_handler = handler_from_callable(tpe)
323344
service=target_handler.service_tag.name
@@ -335,16 +356,16 @@ def do_raw_call(self,
335356
output_serde: Serde[O],
336357
key: Optional[str] = None,
337358
send_delay: Optional[timedelta] = None,
338-
send: bool = False) -> Awaitable[O] | None:
359+
send: bool = False) -> Awaitable[O] | SendHandle:
339360
"""Make an RPC call to the given handler"""
340361
parameter = input_serde.serialize(input_param)
341362
if send_delay:
342363
ms = int(send_delay.total_seconds() * 1000)
343-
self.vm.sys_send(service, handler, parameter, key, delay=ms)
344-
return None
364+
send_handle = self.vm.sys_send(service, handler, parameter, key, delay=ms)
365+
return ServerSendHandle(self, send_handle)
345366
if send:
346-
self.vm.sys_send(service, handler, parameter, key)
347-
return None
367+
send_handle = self.vm.sys_send(service, handler, parameter, key)
368+
return ServerSendHandle(self, send_handle)
348369

349370
handle = self.vm.sys_call(service=service,
350371
handler=handler,
@@ -362,11 +383,13 @@ def service_call(self,
362383
tpe: Callable[[Any, I], Awaitable[O]],
363384
arg: I) -> Awaitable[O]:
364385
coro = self.do_call(tpe, arg)
365-
assert coro is not None
366-
return coro
386+
assert coro is not SendHandle
387+
return coro # type: ignore
367388

368-
def service_send(self, tpe: Callable[[Any, I], Awaitable[O]], arg: I, send_delay: timedelta | None = None) -> None:
369-
self.do_call(tpe=tpe, parameter=arg, send_delay=send_delay, send=True)
389+
def service_send(self, tpe: Callable[[Any, I], Awaitable[O]], arg: I, send_delay: timedelta | None = None) -> SendHandle:
390+
send = self.do_call(tpe=tpe, parameter=arg, send_delay=send_delay, send=True)
391+
assert send is SendHandle
392+
return send # type: ignore
370393

371394
def object_call(self,
372395
tpe: Callable[[Any, I],Awaitable[O]],
@@ -375,26 +398,30 @@ def object_call(self,
375398
send_delay: Optional[timedelta] = None,
376399
send: bool = False) -> Awaitable[O]:
377400
coro = self.do_call(tpe, arg, key, send_delay, send)
378-
assert coro is not None
379-
return coro
401+
assert coro is not SendHandle
402+
return coro # type: ignore
380403

381-
def object_send(self, tpe: Callable[[Any, I], Awaitable[O]], key: str, arg: I, send_delay: timedelta | None = None) -> None:
382-
self.do_call(tpe=tpe, key=key, parameter=arg, send_delay=send_delay, send=True)
404+
def object_send(self, tpe: Callable[[Any, I], Awaitable[O]], key: str, arg: I, send_delay: timedelta | None = None) -> SendHandle:
405+
send = self.do_call(tpe=tpe, key=key, parameter=arg, send_delay=send_delay, send=True)
406+
assert send is SendHandle
407+
return send # type: ignore
383408

384409
def workflow_call(self,
385410
tpe: Callable[[Any, I], Awaitable[O]],
386411
key: str,
387412
arg: I) -> Awaitable[O]:
388413
return self.object_call(tpe, key, arg)
389414

390-
def workflow_send(self, tpe: Callable[[Any, I], Awaitable[O]], key: str, arg: I, send_delay: timedelta | None = None) -> None:
391-
return self.object_send(tpe, key, arg, send_delay)
415+
def workflow_send(self, tpe: Callable[[Any, I], Awaitable[O]], key: str, arg: I, send_delay: timedelta | None = None) -> SendHandle:
416+
send = self.object_send(tpe, key, arg, send_delay)
417+
assert send is SendHandle
418+
return send # type: ignore
392419

393420
def generic_call(self, service: str, handler: str, arg: bytes, key: str | None = None) -> Awaitable[bytes]:
394421
serde = BytesSerde()
395422
return self.do_raw_call(service, handler, arg, serde, serde, key) # type: ignore
396423

397-
def generic_send(self, service: str, handler: str, arg: bytes, key: str | None = None, send_delay: timedelta | None = None) -> None:
424+
def generic_send(self, service: str, handler: str, arg: bytes, key: str | None = None, send_delay: timedelta | None = None) -> SendHandle:
398425
serde = BytesSerde()
399426
return self.do_raw_call(service, handler, arg, serde, serde , key, send_delay, True) # type: ignore
400427

python/restate/vm.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,12 @@ def sys_send(self,
253253
handler: str,
254254
parameter: bytes,
255255
key: typing.Optional[str] = None,
256-
delay: typing.Optional[int] = None) -> None:
257-
"""send an invocation to a service (no response)"""
258-
self.vm.sys_send(service, handler, parameter, key, delay)
256+
delay: typing.Optional[int] = None) -> int:
257+
"""
258+
send an invocation to a service, and return the handle
259+
to the promise that will resolve with the invocation id
260+
"""
261+
return self.vm.sys_send(service, handler, parameter, key, delay)
259262

260263
def sys_run(self, name: str) -> int:
261264
"""

0 commit comments

Comments
 (0)