28
28
from restate .handler import Handler , handler_from_callable , invoke_handler
29
29
from restate .serde import BytesSerde , DefaultSerde , JsonSerde , Serde
30
30
from restate .server_types import Receive , Send
31
- from restate .vm import Failure , Invocation , NotReady , SuspendedException , VMWrapper , RunRetryConfig # pylint: disable=line-too-long
31
+ from restate .vm import Failure , Invocation , NotReady , NotificationType , SuspendedException , VMWrapper , RunRetryConfig # pylint: disable=line-too-long
32
32
from restate .vm import DoProgressAnyCompleted , DoProgressCancelSignalReceived , DoProgressReadFromInput , DoProgressExecuteRun # pylint: disable=line-too-long
33
33
34
34
T = TypeVar ('T' )
@@ -42,7 +42,7 @@ class ServerDurableFuture(RestateDurableFuture[T]):
42
42
error : TerminalError | None = None
43
43
state : typing .Literal ["pending" , "fulfilled" , "rejected" ] = "pending"
44
44
45
- def __init__ (self , context : "ServerInvocationContext" , handle : int , awaitable_factory ) -> None :
45
+ def __init__ (self , context : "ServerInvocationContext" , handle : int , awaitable_factory : Callable [[], Awaitable [ T ]] ) -> None :
46
46
super ().__init__ ()
47
47
self .context = context
48
48
self .source_notification_handle = handle
@@ -60,7 +60,7 @@ def is_completed(self):
60
60
return True
61
61
62
62
63
- def __await__ (self ):
63
+ def __await__ (self ) -> typing . Generator [ Any , Any , T ] :
64
64
65
65
async def await_point ():
66
66
match self .state :
@@ -74,6 +74,7 @@ async def await_point():
74
74
self .state = "rejected"
75
75
raise t
76
76
case "fulfilled" :
77
+ assert self .value is not None
77
78
return self .value
78
79
case "rejected" :
79
80
assert self .error is not None
@@ -91,7 +92,7 @@ def __init__(self,
91
92
result_handle : int ,
92
93
result_factory ,
93
94
invocation_id_handle : int ,
94
- invocation_id_factory ) -> None :
95
+ invocation_id_factory : Callable [[], Awaitable [ NotificationType ]] ) -> None :
95
96
super ().__init__ (context , result_handle , result_factory )
96
97
self .invocation_id_handle = invocation_id_handle
97
98
self .invocation_id_factory = invocation_id_factory
@@ -100,8 +101,12 @@ def __init__(self,
100
101
async def invocation_id (self ) -> str :
101
102
"""Get the invocation id."""
102
103
if self ._invocation_id is None :
103
- self ._invocation_id = await self .invocation_id_factory ()
104
-
104
+ res = await self .invocation_id_factory ()
105
+ if isinstance (res , str ):
106
+ self ._invocation_id = res
107
+ else :
108
+ raise ValueError (f"Unexpected notification type for handle { self .invocation_id_handle } : { type (res ).__name__ } . "
109
+ "Expected str." )
105
110
if self ._invocation_id is None :
106
111
raise ValueError ("invocation_id is None" )
107
112
return self ._invocation_id
@@ -147,6 +152,7 @@ async def await_point():
147
152
res = self .server_context .must_take_notification (handle )
148
153
if res is None :
149
154
return None
155
+ assert isinstance (res , bytes )
150
156
return serde .deserialize (res )
151
157
152
158
return await_point ()
@@ -185,6 +191,7 @@ async def await_point():
185
191
res = self .server_context .must_take_notification (handle )
186
192
if res is None :
187
193
return None
194
+ assert isinstance (res , bytes )
188
195
return serde .deserialize (res )
189
196
190
197
return await_point ()
@@ -278,11 +285,11 @@ async def take_and_send_output(self):
278
285
'more_body' : True ,
279
286
})
280
287
281
- def must_take_notification (self , handle : int ) -> bytes | None :
288
+ def must_take_notification (self , handle : int ) -> bytes | None | str | List [ str ] :
282
289
"""Take notification, which must be present. It must be either bytes or None"""
283
290
res = self .vm .take_notification (handle )
284
291
if isinstance (res , NotReady ):
285
- raise ValueError (f"Notification for handle { handle } is not ready. "
292
+ raise RuntimeError (f"Notification for handle { handle } is not ready. "
286
293
"This likely indicates an unexpected async state." )
287
294
288
295
if res is None :
@@ -316,17 +323,35 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
316
323
await self .take_and_send_output ()
317
324
318
325
319
- def create_df (self , handle : int , serde : Serde [T ] | None = None ) -> ServerDurableFuture [T ]:
320
- """Create a durable future."""
321
-
322
- async def transform ():
326
+ def create_df (self , handle : int , serde : Serde [T ] | None = None ) -> ServerDurableFuture [Any ]:
327
+ """Create a durable future for handling asynchronous state operations.
328
+
329
+ This is a general-purpose factory method that creates a future that will resolve
330
+ to whatever the VM notification system returns for the given handle. The specific
331
+ return type depends on the operation that was performed.
332
+
333
+ Args:
334
+ handle: The notification handle from the VM
335
+ serde: Optional serializer/deserializer for converting raw bytes to specific types
336
+
337
+ Returns:
338
+ A durable future that will eventually resolve to the appropriate value.
339
+ The caller is responsible for knowing what type to expect based on the context.
340
+ """
341
+
342
+ async def fetch_result ():
323
343
await self .create_poll_or_cancel_coroutine ([handle ])
324
344
res = self .must_take_notification (handle )
325
345
if res is None or serde is None :
326
346
return res
327
- return serde .deserialize (res )
347
+ if isinstance (res , bytes ):
348
+ return serde .deserialize (res )
349
+ if isinstance (res , (str , list )):
350
+ return res
351
+ raise ValueError (f"Unexpected notification type for handle { handle } : { type (res ).__name__ } . "
352
+ "Expected bytes or None." )
328
353
329
- return ServerDurableFuture (self , handle , transform )
354
+ return ServerDurableFuture (self , handle , fetch_result )
330
355
331
356
332
357
@@ -338,9 +363,12 @@ async def transform():
338
363
res = self .must_take_notification (handle )
339
364
if res is None or serde is None :
340
365
return res
341
- return serde .deserialize (res )
366
+ if isinstance (res , bytes ):
367
+ return serde .deserialize (res )
368
+ if isinstance (res , (str , list )):
369
+ return res
342
370
343
- async def inv_id_factory ():
371
+ async def inv_id_factory () -> NotificationType :
344
372
await self .create_poll_or_cancel_coroutine ([invocation_id_handle ])
345
373
return self .must_take_notification (invocation_id_handle )
346
374
@@ -351,7 +379,7 @@ def get(self, name: str, serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[T]
351
379
handle = self .vm .sys_get_state (name )
352
380
return self .create_df (handle , serde ) # type: ignore
353
381
354
- def state_keys (self ) -> Awaitable [List [str ]]:
382
+ def state_keys (self ) -> RestateDurableFuture [List [str ]]:
355
383
return self .create_df (self .vm .sys_get_state_keys ())
356
384
357
385
def set (self , name : str , value : T , serde : Serde [T ] = JsonSerde ()) -> None :
0 commit comments