18
18
import typing
19
19
import traceback
20
20
21
- from restate .context import DurablePromise , ObjectContext , Request , RestateDurableFuture , SendHandle
21
+ from restate .context import DurablePromise , ObjectContext , Request , RestateDurableCallFuture , RestateDurableFuture , SendHandle
22
22
from restate .exceptions import TerminalError
23
23
from restate .handler import Handler , handler_from_callable , invoke_handler
24
24
from restate .serde import BytesSerde , DefaultSerde , JsonSerde , Serde
45
45
class ServerDurableFuture (RestateDurableFuture [T ]):
46
46
"""This class implements a durable future API"""
47
47
value : T | None = None
48
- metadata : Dict [str , Any ] | None = None
49
48
50
49
def __init__ (self , handle : int , factory ) -> None :
51
50
super ().__init__ ()
52
51
self .factory = factory
53
52
self .handle = handle
54
53
55
- def with_metadata (self , ** metadata ) -> 'ServerDurableFuture' :
56
- """Add metadata to the future."""
57
- self .metadata = metadata
58
- return self
59
-
60
54
def __await__ (self ):
61
- print ("..........Awaiting............" , flush = True )
62
55
task = asyncio .create_task (self .factory ())
63
56
return task .__await__ ()
64
57
65
58
59
+ class ServerCallDurableFuture (RestateDurableCallFuture [T ], ServerDurableFuture [T ]):
60
+ """This class implements a durable future but for calls"""
61
+ _invocation_id : typing .Optional [str ] = None
62
+
63
+ def __init__ (self , result_handle : int ,
64
+ result_factory ,
65
+ invocation_id_handle : int ,
66
+ invocation_id_factory ) -> None :
67
+ super ().__init__ (result_handle , result_factory )
68
+ self .invocation_id_handle = invocation_id_handle
69
+ self .invocation_id_factory = invocation_id_factory
70
+
71
+ async def invocation_id (self ) -> str :
72
+ """Get the invocation id."""
73
+ if self ._invocation_id is None :
74
+ self ._invocation_id = await self .invocation_id_factory ()
75
+ return self ._invocation_id
76
+
77
+
66
78
class ServerSendHandle (SendHandle ):
67
79
"""This class implements the send API"""
68
80
_invocation_id : typing .Optional [str ]
@@ -279,6 +291,21 @@ async def transform():
279
291
return ServerDurableFuture (handle , transform )
280
292
281
293
294
+
295
+ def create_call_df (self , handle : int , invocation_id_handle : int , serde : Serde [T ] | None = None ) -> ServerCallDurableFuture [T ]:
296
+ """Create a durable future."""
297
+
298
+ async def transform ():
299
+ res = await self .create_poll_or_cancel_coroutine (handle )
300
+ if res is None or serde is None :
301
+ return res
302
+ return serde .deserialize (res )
303
+
304
+ def inv_id_factory ():
305
+ return self .create_poll_or_cancel_coroutine (invocation_id_handle )
306
+
307
+ return ServerCallDurableFuture (handle , transform , invocation_id_handle , inv_id_factory )
308
+
282
309
283
310
def get (self , name : str , serde : Serde [T ] = JsonSerde ()) -> Awaitable [Optional [T ]]:
284
311
handle = self .vm .sys_get_state (name )
@@ -366,7 +393,7 @@ def do_call(self,
366
393
send : bool = False ,
367
394
idempotency_key : str | None = None ,
368
395
headers : typing .List [typing .Tuple [str , str ]] | None = None
369
- ) -> RestateDurableFuture [O ] | SendHandle :
396
+ ) -> RestateDurableCallFuture [O ] | SendHandle :
370
397
"""Make an RPC call to the given handler"""
371
398
target_handler = handler_from_callable (tpe )
372
399
service = target_handler .service_tag .name
@@ -387,7 +414,7 @@ def do_raw_call(self,
387
414
send : bool = False ,
388
415
idempotency_key : str | None = None ,
389
416
headers : typing .List [typing .Tuple [str , str ]] | None = None
390
- ) -> RestateDurableFuture [O ] | SendHandle :
417
+ ) -> RestateDurableCallFuture [O ] | SendHandle :
391
418
"""Make an RPC call to the given handler"""
392
419
parameter = input_serde .serialize (input_param )
393
420
if send_delay :
@@ -405,15 +432,16 @@ def do_raw_call(self,
405
432
idempotency_key = idempotency_key ,
406
433
headers = headers )
407
434
408
- # TODO: specialize this future for calls!
409
- return self .create_df (handle = handle .result_handle , serde = output_serde ).with_metadata (invocation_id = handle .invocation_id_handle )
435
+ return self .create_call_df (handle = handle .result_handle ,
436
+ invocation_id_handle = handle .invocation_id_handle ,
437
+ serde = output_serde )
410
438
411
439
def service_call (self ,
412
440
tpe : Callable [[Any , I ], Awaitable [O ]],
413
441
arg : I ,
414
442
idempotency_key : str | None = None ,
415
443
headers : typing .List [typing .Tuple [str , str ]] | None = None
416
- ) -> RestateDurableFuture [O ]:
444
+ ) -> RestateDurableCallFuture [O ]:
417
445
coro = self .do_call (tpe , arg , idempotency_key = idempotency_key , headers = headers )
418
446
assert not isinstance (coro , SendHandle )
419
447
return coro
@@ -429,7 +457,7 @@ def object_call(self,
429
457
arg : I ,
430
458
idempotency_key : str | None = None ,
431
459
headers : typing .List [typing .Tuple [str , str ]] | None = None
432
- ) -> RestateDurableFuture [O ]:
460
+ ) -> RestateDurableCallFuture [O ]:
433
461
coro = self .do_call (tpe , arg , key , idempotency_key = idempotency_key , headers = headers )
434
462
assert not isinstance (coro , SendHandle )
435
463
return coro
@@ -445,15 +473,15 @@ def workflow_call(self,
445
473
arg : I ,
446
474
idempotency_key : str | None = None ,
447
475
headers : typing .List [typing .Tuple [str , str ]] | None = None
448
- ) -> RestateDurableFuture [O ]:
476
+ ) -> RestateDurableCallFuture [O ]:
449
477
return self .object_call (tpe , key , arg , idempotency_key = idempotency_key , headers = headers )
450
478
451
479
def workflow_send (self , tpe : Callable [[Any , I ], Awaitable [O ]], key : str , arg : I , send_delay : timedelta | None = None , idempotency_key : str | None = None , headers : typing .List [typing .Tuple [str , str ]] | None = None ) -> SendHandle :
452
480
send = self .object_send (tpe , key , arg , send_delay , idempotency_key = idempotency_key , headers = headers )
453
481
assert isinstance (send , SendHandle )
454
482
return send
455
483
456
- def generic_call (self , service : str , handler : str , arg : bytes , key : str | None = None , idempotency_key : str | None = None , headers : typing .List [typing .Tuple [str , str ]] | None = None ) -> RestateDurableFuture [bytes ]:
484
+ def generic_call (self , service : str , handler : str , arg : bytes , key : str | None = None , idempotency_key : str | None = None , headers : typing .List [typing .Tuple [str , str ]] | None = None ) -> RestateDurableCallFuture [bytes ]:
457
485
serde = BytesSerde ()
458
486
call_handle = self .do_raw_call (service = service ,
459
487
handler = handler ,
0 commit comments