Skip to content

Commit ea3772d

Browse files
committed
Detect correctly http.disconnect event
1 parent 4573ea6 commit ea3772d

File tree

2 files changed

+38
-9
lines changed

2 files changed

+38
-9
lines changed

python/restate/server.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import traceback
1616
from restate.discovery import compute_discovery_json
1717
from restate.endpoint import Endpoint
18-
from restate.server_context import ServerInvocationContext
18+
from restate.server_context import ServerInvocationContext, DisconnectedException
1919
from restate.server_types import Receive, Scope, Send, binary_to_header, header_to_binary
2020
from restate.vm import VMWrapper
2121
from restate._internal import PyIdentityVerifier, IdentityVerificationException # pylint: disable=import-error,no-name-in-module
@@ -81,6 +81,7 @@ async def send_health_check(send: Send):
8181
'more_body': False,
8282
})
8383

84+
8485
async def process_invocation_to_completion(vm: VMWrapper,
8586
handler,
8687
attempt_headers: Dict[str, str],
@@ -128,6 +129,10 @@ async def process_invocation_to_completion(vm: VMWrapper,
128129
except asyncio.exceptions.CancelledError:
129130
context.on_attempt_finished()
130131
raise
132+
except DisconnectedException:
133+
# The client disconnected before we could send the response
134+
context.on_attempt_finished()
135+
return
131136
# pylint: disable=W0718
132137
except Exception:
133138
traceback.print_exc()

python/restate/server_context.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@
3737
I = TypeVar('I')
3838
O = TypeVar('O')
3939

40+
class DisconnectedException(Exception):
    """
    Raised when the connection to the restate server is lost.

    This can be due to a network error or a long inactivity timeout.
    The restate-server will automatically retry the attempt.
    """

    def __init__(self) -> None:
        # Fixed message: the exception carries no per-instance details.
        super().__init__("Disconnected. The connection to the restate server was lost. Restate will retry the attempt.")
49+
4050

4151
class ServerTeardownEvent(AttemptFinishedEvent):
4252
"""
@@ -202,7 +212,6 @@ def peek(self) -> Awaitable[Any | None]:
202212
handle = vm.sys_peek_promise(self.name)
203213
serde = self.serde
204214
assert serde is not None
205-
206215
return self.server_context.create_future(handle, serde)
207216

208217

@@ -263,6 +272,8 @@ async def enter(self):
263272
# pylint: disable=W0718
264273
except SuspendedException:
265274
pass
275+
except DisconnectedException:
276+
raise
266277
except Exception as e:
267278
stacktrace = '\n'.join(traceback.format_exception(e))
268279
self.vm.notify_error(repr(e), stacktrace)
@@ -311,6 +322,17 @@ def on_attempt_finished(self):
311322
self.request_finished_event.set()
312323

313324

325+
async def receive_and_notify_input(self):
    """Receive one event from the transport and feed it to the state machine.

    Awaits a single event from ``self.receive()``:

    * ``http.request``: the body bytes are passed to ``vm.notify_input``;
      when ``more_body`` is absent or false, ``vm.notify_input_closed`` is
      called as well.
    * ``http.disconnect``: raises ``DisconnectedException``.

    Events of any other type are ignored.

    Raises:
        DisconnectedException: if the received event is ``http.disconnect``.
    """
    chunk = await self.receive()
    event_type = chunk.get('type')
    if event_type == 'http.request':
        # ASGI makes 'body' optional (defaulting to b''), so do not index
        # the key directly: a final, empty chunk may omit it.
        body = chunk.get('body', b'')
        assert isinstance(body, bytes)
        self.vm.notify_input(body)
        if not chunk.get('more_body', False):
            self.vm.notify_input_closed()
    if event_type == 'http.disconnect':
        raise DisconnectedException()
335+
314336
async def take_and_send_output(self):
315337
"""Take output from state machine and send it"""
316338
output = self.vm.take_output()
@@ -344,12 +366,7 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
344366
if isinstance(do_progress_response, DoProgressCancelSignalReceived):
345367
raise TerminalError("cancelled", 409)
346368
if isinstance(do_progress_response, DoProgressReadFromInput):
347-
chunk = await self.receive()
348-
if chunk.get('body', None) is not None:
349-
assert isinstance(chunk['body'], bytes)
350-
self.vm.notify_input(chunk['body'])
351-
if not chunk.get('more_body', False):
352-
self.vm.notify_input_closed()
369+
await self.receive_and_notify_input()
353370
continue
354371
if isinstance(do_progress_response, DoProgressExecuteRun):
355372
fn = self.run_coros_to_execute[do_progress_response.handle]
@@ -364,7 +381,14 @@ async def wrapper(f):
364381
asyncio.create_task(wrapper(fn))
365382
continue
366383
if isinstance(do_progress_response, DoWaitPendingRun):
367-
await self.sync_point.wait()
384+
sync_task = asyncio.create_task(self.sync_point.wait())
385+
read_task = asyncio.create_task(self.receive_and_notify_input())
386+
done, _ = await asyncio.wait([sync_task, read_task], return_when=asyncio.FIRST_COMPLETED)
387+
if read_task in done:
388+
_ = read_task.result() # rethrow any exception
389+
if sync_task in done:
390+
continue
391+
368392

369393
def _create_fetch_result_coroutine(self, handle: int, serde: Serde[T] | None = None):
370394
"""Create a coroutine that fetches a result from a notification handle."""

0 commit comments

Comments
 (0)