Skip to content

Commit ed57d59

Browse files
committed
Detect correctly http.disconnect event
1 parent 4573ea6 commit ed57d59

File tree

2 files changed

+35
-9
lines changed

2 files changed

+35
-9
lines changed

python/restate/server.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import traceback
1616
from restate.discovery import compute_discovery_json
1717
from restate.endpoint import Endpoint
18-
from restate.server_context import ServerInvocationContext
18+
from restate.server_context import ServerInvocationContext, DisconnectedException
1919
from restate.server_types import Receive, Scope, Send, binary_to_header, header_to_binary
2020
from restate.vm import VMWrapper
2121
from restate._internal import PyIdentityVerifier, IdentityVerificationException # pylint: disable=import-error,no-name-in-module
@@ -81,6 +81,7 @@ async def send_health_check(send: Send):
8181
'more_body': False,
8282
})
8383

84+
8485
async def process_invocation_to_completion(vm: VMWrapper,
8586
handler,
8687
attempt_headers: Dict[str, str],
@@ -128,6 +129,10 @@ async def process_invocation_to_completion(vm: VMWrapper,
128129
except asyncio.exceptions.CancelledError:
129130
context.on_attempt_finished()
130131
raise
132+
except DisconnectedException:
133+
# The client disconnected before we could send the response
134+
context.on_attempt_finished()
135+
return
131136
# pylint: disable=W0718
132137
except Exception:
133138
traceback.print_exc()

python/restate/server_context.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,15 @@
3737
I = TypeVar('I')
3838
O = TypeVar('O')
3939

40+
class DisconnectedException(IOError):
41+
"""
42+
This class implements the disconnected exception.
43+
"""
44+
45+
def __init__(self) -> None:
46+
super().__init__("disconnected")
47+
48+
4049

4150
class ServerTeardownEvent(AttemptFinishedEvent):
4251
"""
@@ -202,7 +211,6 @@ def peek(self) -> Awaitable[Any | None]:
202211
handle = vm.sys_peek_promise(self.name)
203212
serde = self.serde
204213
assert serde is not None
205-
206214
return self.server_context.create_future(handle, serde)
207215

208216

@@ -311,6 +319,17 @@ def on_attempt_finished(self):
311319
self.request_finished_event.set()
312320

313321

322+
async def receive_and_notify_input(self):
323+
"""Receive input from the state machine."""
324+
chunk = await self.receive()
325+
if chunk.get('type') == 'http.request':
326+
assert isinstance(chunk['body'], bytes)
327+
self.vm.notify_input(chunk['body'])
328+
if not chunk.get('more_body', False):
329+
self.vm.notify_input_closed()
330+
if chunk.get('type') == 'http.disconnect':
331+
raise DisconnectedException()
332+
314333
async def take_and_send_output(self):
315334
"""Take output from state machine and send it"""
316335
output = self.vm.take_output()
@@ -344,12 +363,7 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
344363
if isinstance(do_progress_response, DoProgressCancelSignalReceived):
345364
raise TerminalError("cancelled", 409)
346365
if isinstance(do_progress_response, DoProgressReadFromInput):
347-
chunk = await self.receive()
348-
if chunk.get('body', None) is not None:
349-
assert isinstance(chunk['body'], bytes)
350-
self.vm.notify_input(chunk['body'])
351-
if not chunk.get('more_body', False):
352-
self.vm.notify_input_closed()
366+
await self.receive_and_notify_input()
353367
continue
354368
if isinstance(do_progress_response, DoProgressExecuteRun):
355369
fn = self.run_coros_to_execute[do_progress_response.handle]
@@ -364,7 +378,14 @@ async def wrapper(f):
364378
asyncio.create_task(wrapper(fn))
365379
continue
366380
if isinstance(do_progress_response, DoWaitPendingRun):
367-
await self.sync_point.wait()
381+
sync_task = asyncio.create_task(self.sync_point.wait())
382+
read_task = asyncio.create_task(self.receive_and_notify_input())
383+
done, _ = await asyncio.wait([sync_task, read_task], return_when=asyncio.FIRST_COMPLETED)
384+
if read_task in done:
385+
_ = read_task.result() # rethrow any exception
386+
if sync_task in done:
387+
continue
388+
368389

369390
def _create_fetch_result_coroutine(self, handle: int, serde: Serde[T] | None = None):
370391
"""Create a coroutine that fetches a result from a notification handle."""

0 commit comments

Comments
 (0)