Skip to content

Commit e38e9dc

Browse files
committed
Add invocation_id to a call promise
1 parent ad78d56 commit e38e9dc

File tree

2 files changed

+61
-20
lines changed

2 files changed

+61
-20
lines changed

python/restate/context.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,19 @@ def __await__(self):
3838
pass
3939

4040

41+
# pylint: disable=R0903
42+
class RestateDurableCallFuture(RestateDurableFuture[T]):
43+
"""
44+
Represents a durable call future.
45+
"""
46+
47+
@abc.abstractmethod
48+
async def invocation_id(self) -> str:
49+
"""
50+
Returns the invocation id of the call.
51+
"""
52+
53+
4154

4255
@dataclass
4356
class Request:
@@ -147,7 +160,7 @@ def sleep(self, delta: timedelta) -> RestateDurableFuture[None]:
147160
def service_call(self,
148161
tpe: Callable[[Any, I], Awaitable[O]],
149162
arg: I,
150-
idempotency_key: str | None = None) -> RestateDurableFuture[O]:
163+
idempotency_key: str | None = None) -> RestateDurableCallFuture[O]:
151164
"""
152165
Invokes the given service with the given argument.
153166
"""
@@ -170,7 +183,7 @@ def object_call(self,
170183
key: str,
171184
arg: I,
172185
idempotency_key: str | None = None,
173-
) -> RestateDurableFuture[O]:
186+
) -> RestateDurableCallFuture[O]:
174187
"""
175188
Invokes the given object with the given argument.
176189
"""
@@ -193,7 +206,7 @@ def workflow_call(self,
193206
key: str,
194207
arg: I,
195208
idempotency_key: str | None = None,
196-
) -> RestateDurableFuture[O]:
209+
) -> RestateDurableCallFuture[O]:
197210
"""
198211
Invokes the given workflow with the given argument.
199212
"""
@@ -217,7 +230,7 @@ def generic_call(self,
217230
handler: str,
218231
arg: bytes,
219232
key: Optional[str] = None,
220-
idempotency_key: str | None = None) -> RestateDurableFuture[bytes]:
233+
idempotency_key: str | None = None) -> RestateDurableCallFuture[bytes]:
221234
"""
222235
Invokes the given generic service/handler with the given argument.
223236
"""

python/restate/server_context.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import typing
1919
import traceback
2020

21-
from restate.context import DurablePromise, ObjectContext, Request, RestateDurableFuture, SendHandle
21+
from restate.context import DurablePromise, ObjectContext, Request, RestateDurableCallFuture, RestateDurableFuture, SendHandle
2222
from restate.exceptions import TerminalError
2323
from restate.handler import Handler, handler_from_callable, invoke_handler
2424
from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde
@@ -45,24 +45,36 @@
4545
class ServerDurableFuture(RestateDurableFuture[T]):
4646
"""This class implements a durable future API"""
4747
value: T | None = None
48-
metadata: Dict[str, Any] | None = None
4948

5049
def __init__(self, handle: int, factory) -> None:
5150
super().__init__()
5251
self.factory = factory
5352
self.handle = handle
5453

55-
def with_metadata(self, **metadata) -> 'ServerDurableFuture':
56-
"""Add metadata to the future."""
57-
self.metadata = metadata
58-
return self
59-
6054
def __await__(self):
61-
print("..........Awaiting............", flush=True)
6255
task = asyncio.create_task(self.factory())
6356
return task.__await__()
6457

6558

59+
class ServerCallDurableFuture(RestateDurableCallFuture[T], ServerDurableFuture[T]):
60+
"""This class implements a durable future but for calls"""
61+
_invocation_id: typing.Optional[str] = None
62+
63+
def __init__(self, result_handle: int,
64+
result_factory,
65+
invocation_id_handle: int,
66+
invocation_id_factory) -> None:
67+
super().__init__(result_handle, result_factory)
68+
self.invocation_id_handle = invocation_id_handle
69+
self.invocation_id_factory = invocation_id_factory
70+
71+
async def invocation_id(self) -> str:
72+
"""Get the invocation id."""
73+
if self._invocation_id is None:
74+
self._invocation_id = await self.invocation_id_factory()
75+
return self._invocation_id
76+
77+
6678
class ServerSendHandle(SendHandle):
6779
"""This class implements the send API"""
6880
_invocation_id: typing.Optional[str]
@@ -279,6 +291,21 @@ async def transform():
279291
return ServerDurableFuture(handle, transform)
280292

281293

294+
295+
def create_call_df(self, handle: int, invocation_id_handle: int, serde: Serde[T] | None = None) -> ServerCallDurableFuture[T]:
296+
"""Create a durable future."""
297+
298+
async def transform():
299+
res = await self.create_poll_or_cancel_coroutine(handle)
300+
if res is None or serde is None:
301+
return res
302+
return serde.deserialize(res)
303+
304+
def inv_id_factory():
305+
return self.create_poll_or_cancel_coroutine(invocation_id_handle)
306+
307+
return ServerCallDurableFuture(handle, transform, invocation_id_handle, inv_id_factory)
308+
282309

283310
def get(self, name: str, serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[T]]:
284311
handle = self.vm.sys_get_state(name)
@@ -366,7 +393,7 @@ def do_call(self,
366393
send: bool = False,
367394
idempotency_key: str | None = None,
368395
headers: typing.List[typing.Tuple[str, str]] | None = None
369-
) -> RestateDurableFuture[O] | SendHandle:
396+
) -> RestateDurableCallFuture[O] | SendHandle:
370397
"""Make an RPC call to the given handler"""
371398
target_handler = handler_from_callable(tpe)
372399
service=target_handler.service_tag.name
@@ -387,7 +414,7 @@ def do_raw_call(self,
387414
send: bool = False,
388415
idempotency_key: str | None = None,
389416
headers: typing.List[typing.Tuple[str, str]] | None = None
390-
) -> RestateDurableFuture[O] | SendHandle:
417+
) -> RestateDurableCallFuture[O] | SendHandle:
391418
"""Make an RPC call to the given handler"""
392419
parameter = input_serde.serialize(input_param)
393420
if send_delay:
@@ -405,15 +432,16 @@ def do_raw_call(self,
405432
idempotency_key=idempotency_key,
406433
headers=headers)
407434

408-
# TODO: specialize this future for calls!
409-
return self.create_df(handle=handle.result_handle, serde=output_serde).with_metadata(invocation_id=handle.invocation_id_handle)
435+
return self.create_call_df(handle=handle.result_handle,
436+
invocation_id_handle=handle.invocation_id_handle,
437+
serde=output_serde)
410438

411439
def service_call(self,
412440
tpe: Callable[[Any, I], Awaitable[O]],
413441
arg: I,
414442
idempotency_key: str | None = None,
415443
headers: typing.List[typing.Tuple[str, str]] | None = None
416-
) -> RestateDurableFuture[O]:
444+
) -> RestateDurableCallFuture[O]:
417445
coro = self.do_call(tpe, arg, idempotency_key=idempotency_key, headers=headers)
418446
assert not isinstance(coro, SendHandle)
419447
return coro
@@ -429,7 +457,7 @@ def object_call(self,
429457
arg: I,
430458
idempotency_key: str | None = None,
431459
headers: typing.List[typing.Tuple[str, str]] | None = None
432-
) -> RestateDurableFuture[O]:
460+
) -> RestateDurableCallFuture[O]:
433461
coro = self.do_call(tpe, arg, key, idempotency_key=idempotency_key, headers=headers)
434462
assert not isinstance(coro, SendHandle)
435463
return coro
@@ -445,15 +473,15 @@ def workflow_call(self,
445473
arg: I,
446474
idempotency_key: str | None = None,
447475
headers: typing.List[typing.Tuple[str, str]] | None = None
448-
) -> RestateDurableFuture[O]:
476+
) -> RestateDurableCallFuture[O]:
449477
return self.object_call(tpe, key, arg, idempotency_key=idempotency_key, headers=headers)
450478

451479
def workflow_send(self, tpe: Callable[[Any, I], Awaitable[O]], key: str, arg: I, send_delay: timedelta | None = None, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None) -> SendHandle:
452480
send = self.object_send(tpe, key, arg, send_delay, idempotency_key=idempotency_key, headers=headers)
453481
assert isinstance(send, SendHandle)
454482
return send
455483

456-
def generic_call(self, service: str, handler: str, arg: bytes, key: str | None = None, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None) -> RestateDurableFuture[bytes]:
484+
def generic_call(self, service: str, handler: str, arg: bytes, key: str | None = None, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None) -> RestateDurableCallFuture[bytes]:
457485
serde = BytesSerde()
458486
call_handle = self.do_raw_call(service=service,
459487
handler=handler,

0 commit comments

Comments
 (0)