Skip to content

Commit 4573ea6

Browse files
committed
Add attempt finished signal
1 parent 2a65a13 commit 4573ea6

File tree

4 files changed

+61
-3
lines changed

4 files changed

+61
-3
lines changed

python/restate/context.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,31 @@ class RestateDurableSleepFuture(RestateDurableFuture[None]):
7272
def __await__(self) -> typing.Generator[Any, Any, None]:
7373
pass
7474

75+
class AttemptFinishedEvent(abc.ABC):
76+
"""
77+
Represents an attempt finished event.
78+
79+
This event is used to signal that an attempt has finished (either successfully or with an error), and it is now
80+
safe to cleanup any attempt related resources, such as pending ctx.run() 3rd party calls, or any other resources that
81+
are only valid for the duration of the attempt.
82+
83+
An attempt is considered finished when either the connection to the restate server is closed, the invocation is completed, or a transient
84+
error occurs.
85+
"""
86+
87+
@abc.abstractmethod
88+
def is_set(self) -> bool:
89+
"""
90+
Returns True if the event is set, False otherwise.
91+
"""
92+
93+
94+
@abc.abstractmethod
95+
async def wait(self):
96+
"""
97+
Waits for the event to be set.
98+
"""
99+
75100

76101
@dataclass
77102
class Request:
@@ -83,11 +108,13 @@ class Request:
83108
headers (dict[str, str]): The headers of the request.
84109
attempt_headers (dict[str, str]): The attempt headers of the request.
85110
body (bytes): The body of the request.
111+
attempt_finished_event (AttemptFinishedEvent): The teardown event of the request.
86112
"""
87113
id: str
88114
headers: Dict[str, str]
89115
attempt_headers: Dict[str,str]
90116
body: bytes
117+
attempt_finished_event: AttemptFinishedEvent
91118

92119

93120
class KeyValueStore(abc.ABC):
@@ -158,6 +185,7 @@ async def cancel_invocation(self) -> None:
158185
await ctx.cancel_invocation(await f.invocation_id())
159186
"""
160187

188+
161189
class Context(abc.ABC):
162190
"""
163191
Represents the context of the current invocation.

python/restate/discovery.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def compute_discovery(endpoint: RestateEndpoint, discovered_as : typing.Literal[
184184
if handler.handler_io.input_type and handler.handler_io.input_type.is_void:
185185
inp = {}
186186
else:
187-
inp =InputPayload(required=False,
187+
inp = InputPayload(required=False,
188188
contentType=handler.handler_io.accept,
189189
jsonSchema=json_schema_from_type_hint(handler.handler_io.input_type))
190190
# output

python/restate/server.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,15 @@ async def process_invocation_to_completion(vm: VMWrapper,
126126
try:
127127
await context.enter()
128128
except asyncio.exceptions.CancelledError:
129+
context.on_attempt_finished()
129130
raise
130131
# pylint: disable=W0718
131132
except Exception:
132133
traceback.print_exc()
133-
await context.leave()
134+
try:
135+
await context.leave()
136+
finally:
137+
context.on_attempt_finished()
134138

135139
class LifeSpanNotImplemented(ValueError):
136140
"""Signal to the asgi server that we didn't implement lifespans"""

python/restate/server_context.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import typing
2525
import traceback
2626

27-
from restate.context import DurablePromise, ObjectContext, Request, RestateDurableCallFuture, RestateDurableFuture, SendHandle, RestateDurableSleepFuture
27+
from restate.context import DurablePromise, AttemptFinishedEvent, ObjectContext, Request, RestateDurableCallFuture, RestateDurableFuture, SendHandle, RestateDurableSleepFuture
2828
from restate.exceptions import TerminalError
2929
from restate.handler import Handler, handler_from_callable, invoke_handler
3030
from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde
@@ -37,6 +37,24 @@
3737
I = TypeVar('I')
3838
O = TypeVar('O')
3939

40+
41+
class ServerTeardownEvent(AttemptFinishedEvent):
42+
"""
43+
This class implements the teardown event for the server.
44+
"""
45+
46+
def __init__(self, event: asyncio.Event) -> None:
47+
super().__init__()
48+
self.event = event
49+
50+
def is_set(self):
51+
return self.event.is_set()
52+
53+
async def wait(self):
54+
"""Wait for the event to be set."""
55+
await self.event.wait()
56+
57+
4058
class LazyFuture:
4159
"""
4260
Creates a task lazily, and allows multiple awaiters to the same coroutine.
@@ -229,6 +247,7 @@ def __init__(self,
229247
self.receive = receive
230248
self.run_coros_to_execute: dict[int, Callable[[], Awaitable[None]]] = {}
231249
self.sync_point = SyncPoint()
250+
self.request_finished_event = asyncio.Event()
232251

233252
async def enter(self):
234253
"""Invoke the user code."""
@@ -285,6 +304,12 @@ async def leave(self):
285304
'body': b'',
286305
'more_body': False,
287306
})
307+
# notify to any holder of the abort signal that we are done
308+
309+
def on_attempt_finished(self):
310+
"""Notify the attempt finished event."""
311+
self.request_finished_event.set()
312+
288313

289314
async def take_and_send_output(self):
290315
"""Take output from state machine and send it"""
@@ -407,6 +432,7 @@ def request(self) -> Request:
407432
headers=dict(self.invocation.headers),
408433
attempt_headers=self.attempt_headers,
409434
body=self.invocation.input_buffer,
435+
attempt_finished_event=ServerTeardownEvent(self.request_finished_event),
410436
)
411437

412438
async def create_run_coroutine(self,

0 commit comments

Comments
 (0)