37
37
I = TypeVar ('I' )
38
38
O = TypeVar ('O' )
39
39
40
+ class DisconnectedException (IOError ):
41
+ """
42
+ This class implements the disconnected exception.
43
+ """
44
+
45
+ def __init__ (self ) -> None :
46
+ super ().__init__ ("disconnected" )
47
+
48
+
40
49
41
50
class ServerTeardownEvent (AttemptFinishedEvent ):
42
51
"""
@@ -202,7 +211,6 @@ def peek(self) -> Awaitable[Any | None]:
202
211
handle = vm .sys_peek_promise (self .name )
203
212
serde = self .serde
204
213
assert serde is not None
205
-
206
214
return self .server_context .create_future (handle , serde )
207
215
208
216
@@ -311,6 +319,17 @@ def on_attempt_finished(self):
311
319
self .request_finished_event .set ()
312
320
313
321
322
+ async def receive_and_notify_input (self ):
323
+ """Receive input from the state machine."""
324
+ chunk = await self .receive ()
325
+ if chunk .get ('type' ) == 'http.request' :
326
+ assert isinstance (chunk ['body' ], bytes )
327
+ self .vm .notify_input (chunk ['body' ])
328
+ if not chunk .get ('more_body' , False ):
329
+ self .vm .notify_input_closed ()
330
+ if chunk .get ('type' ) == 'http.disconnect' :
331
+ raise DisconnectedException ()
332
+
314
333
async def take_and_send_output (self ):
315
334
"""Take output from state machine and send it"""
316
335
output = self .vm .take_output ()
@@ -344,12 +363,7 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
344
363
if isinstance (do_progress_response , DoProgressCancelSignalReceived ):
345
364
raise TerminalError ("cancelled" , 409 )
346
365
if isinstance (do_progress_response , DoProgressReadFromInput ):
347
- chunk = await self .receive ()
348
- if chunk .get ('body' , None ) is not None :
349
- assert isinstance (chunk ['body' ], bytes )
350
- self .vm .notify_input (chunk ['body' ])
351
- if not chunk .get ('more_body' , False ):
352
- self .vm .notify_input_closed ()
366
+ await self .receive_and_notify_input ()
353
367
continue
354
368
if isinstance (do_progress_response , DoProgressExecuteRun ):
355
369
fn = self .run_coros_to_execute [do_progress_response .handle ]
@@ -364,7 +378,14 @@ async def wrapper(f):
364
378
asyncio .create_task (wrapper (fn ))
365
379
continue
366
380
if isinstance (do_progress_response , DoWaitPendingRun ):
367
- await self .sync_point .wait ()
381
+ sync_task = asyncio .create_task (self .sync_point .wait ())
382
+ read_task = asyncio .create_task (self .receive_and_notify_input ())
383
+ done , _ = await asyncio .wait ([sync_task , read_task ], return_when = asyncio .FIRST_COMPLETED )
384
+ if read_task in done :
385
+ _ = read_task .result () # rethrow any exception
386
+ if sync_task in done :
387
+ continue
388
+
368
389
369
390
def _create_fetch_result_coroutine (self , handle : int , serde : Serde [T ] | None = None ):
370
391
"""Create a coroutine that fetches a result from a notification handle."""
0 commit comments