37
37
I = TypeVar ('I' )
38
38
O = TypeVar ('O' )
39
39
40
+ class DisconnectedException (Exception ):
41
+ """
42
+ This exception is raised when the connection to the restate server is lost.
43
+ This can be due to a network error or a long inactivity timeout.
44
+ The restate-server will automatically retry the attempt.
45
+ """
46
+
47
+ def __init__ (self ) -> None :
48
+ super ().__init__ ("Disconnected. The connection to the restate server was lost. Restate will retry the attempt." )
49
+
40
50
41
51
class ServerTeardownEvent (AttemptFinishedEvent ):
42
52
"""
@@ -202,7 +212,6 @@ def peek(self) -> Awaitable[Any | None]:
202
212
handle = vm .sys_peek_promise (self .name )
203
213
serde = self .serde
204
214
assert serde is not None
205
-
206
215
return self .server_context .create_future (handle , serde )
207
216
208
217
@@ -263,6 +272,8 @@ async def enter(self):
263
272
# pylint: disable=W0718
264
273
except SuspendedException :
265
274
pass
275
+ except DisconnectedException :
276
+ raise
266
277
except Exception as e :
267
278
stacktrace = '\n ' .join (traceback .format_exception (e ))
268
279
self .vm .notify_error (repr (e ), stacktrace )
@@ -311,6 +322,17 @@ def on_attempt_finished(self):
311
322
self .request_finished_event .set ()
312
323
313
324
325
+ async def receive_and_notify_input (self ):
326
+ """Receive input from the state machine."""
327
+ chunk = await self .receive ()
328
+ if chunk .get ('type' ) == 'http.request' :
329
+ assert isinstance (chunk ['body' ], bytes )
330
+ self .vm .notify_input (chunk ['body' ])
331
+ if not chunk .get ('more_body' , False ):
332
+ self .vm .notify_input_closed ()
333
+ if chunk .get ('type' ) == 'http.disconnect' :
334
+ raise DisconnectedException ()
335
+
314
336
async def take_and_send_output (self ):
315
337
"""Take output from state machine and send it"""
316
338
output = self .vm .take_output ()
@@ -344,12 +366,7 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
344
366
if isinstance (do_progress_response , DoProgressCancelSignalReceived ):
345
367
raise TerminalError ("cancelled" , 409 )
346
368
if isinstance (do_progress_response , DoProgressReadFromInput ):
347
- chunk = await self .receive ()
348
- if chunk .get ('body' , None ) is not None :
349
- assert isinstance (chunk ['body' ], bytes )
350
- self .vm .notify_input (chunk ['body' ])
351
- if not chunk .get ('more_body' , False ):
352
- self .vm .notify_input_closed ()
369
+ await self .receive_and_notify_input ()
353
370
continue
354
371
if isinstance (do_progress_response , DoProgressExecuteRun ):
355
372
fn = self .run_coros_to_execute [do_progress_response .handle ]
@@ -364,7 +381,14 @@ async def wrapper(f):
364
381
asyncio .create_task (wrapper (fn ))
365
382
continue
366
383
if isinstance (do_progress_response , DoWaitPendingRun ):
367
- await self .sync_point .wait ()
384
+ sync_task = asyncio .create_task (self .sync_point .wait ())
385
+ read_task = asyncio .create_task (self .receive_and_notify_input ())
386
+ done , _ = await asyncio .wait ([sync_task , read_task ], return_when = asyncio .FIRST_COMPLETED )
387
+ if read_task in done :
388
+ _ = read_task .result () # rethrow any exception
389
+ if sync_task in done :
390
+ continue
391
+
368
392
369
393
def _create_fetch_result_coroutine (self , handle : int , serde : Serde [T ] | None = None ):
370
394
"""Create a coroutine that fetches a result from a notification handle."""
0 commit comments