-
Notifications
You must be signed in to change notification settings - Fork 7
Chore/fix type hints and general clarity #57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
8e20702
6db2ace
07d3e24
1fb1dfa
30c80a7
8d3b579
06772d8
3660364
8d2a144
0ec7950
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -87,7 +87,7 @@ def __await__(self): | |
class ServerDurableSleepFuture(RestateDurableSleepFuture, ServerDurableFuture[None]): | ||
"""This class implements a durable sleep future API""" | ||
|
||
def __await__(self): | ||
def __await__(self) -> typing.Generator[Any, Any, None]: | ||
return self.future.__await__() | ||
|
||
class ServerCallDurableFuture(RestateDurableCallFuture[T], ServerDurableFuture[T]): | ||
|
@@ -210,7 +210,7 @@ def __init__(self, | |
self.attempt_headers = attempt_headers | ||
self.send = send | ||
self.receive = receive | ||
self.run_coros_to_execute: dict[int, Callable[[], Awaitable[typing.Union[bytes | Failure]]]] = {} | ||
self.run_coros_to_execute: dict[int, Callable[[], Awaitable[bytes | Failure]]] = {} | ||
self.sync_point = SyncPoint() | ||
|
||
async def enter(self): | ||
|
@@ -324,54 +324,44 @@ async def wrapper(f): | |
if isinstance(do_progress_response, DoWaitPendingRun): | ||
await self.sync_point.wait() | ||
|
||
def create_future(self, handle: int, serde: Serde[T] | None = None) -> ServerDurableFuture[T]: | ||
"""Create a durable future.""" | ||
|
||
async def transform(): | ||
def _create_fetch_result_coroutine(self, handle: int, serde: Serde[T] | None = None): | ||
"""Create a coroutine that fetches a result from a notification handle.""" | ||
async def fetch_result(): | ||
if not self.vm.is_completed(handle): | ||
await self.create_poll_or_cancel_coroutine([handle]) | ||
res = self.must_take_notification(handle) | ||
if res is None or serde is None: | ||
if res is None or serde is None or not isinstance(res, bytes): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure about this check. I'd need to review this in depth, can you please revert this check back for now? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently the function runs
Let me know if you want it changed back, but I think it would be a bug? |
||
return res | ||
return serde.deserialize(res) | ||
|
||
return ServerDurableFuture(self, handle, transform) | ||
return fetch_result | ||
|
||
def create_future(self, handle: int, serde: Serde[T] | None = None) -> ServerDurableFuture[T]: | ||
"""Create a durable future for handling asynchronous state operations.""" | ||
return ServerDurableFuture(self, handle, self._create_fetch_result_coroutine(handle, serde)) | ||
|
||
def create_sleep_future(self, handle: int) -> ServerDurableSleepFuture: | ||
"""Create a durable sleep future.""" | ||
|
||
async def transform(): | ||
if not self.vm.is_completed(handle): | ||
await self.create_poll_or_cancel_coroutine([handle]) | ||
self.must_take_notification(handle) | ||
|
||
return ServerDurableSleepFuture(self, handle, transform) | ||
|
||
|
||
def create_call_future(self, handle: int, invocation_id_handle: int, serde: Serde[T] | None = None) -> ServerCallDurableFuture[T]: | ||
"""Create a durable future.""" | ||
|
||
async def transform(): | ||
if not self.vm.is_completed(handle): | ||
await self.create_poll_or_cancel_coroutine([handle]) | ||
res = self.must_take_notification(handle) | ||
if res is None or serde is None: | ||
return res | ||
return serde.deserialize(res) | ||
|
||
async def inv_id_factory(): | ||
if not self.vm.is_completed(invocation_id_handle): | ||
await self.create_poll_or_cancel_coroutine([invocation_id_handle]) | ||
return self.must_take_notification(invocation_id_handle) | ||
|
||
return ServerCallDurableFuture(self, handle, transform, inv_id_factory) | ||
|
||
return ServerCallDurableFuture(self, handle, self._create_fetch_result_coroutine(handle, serde), inv_id_factory) | ||
|
||
def get(self, name: str, serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[T]]: | ||
handle = self.vm.sys_get_state(name) | ||
return self.create_future(handle, serde) # type: ignore | ||
|
||
def state_keys(self) -> Awaitable[List[str]]: | ||
def state_keys(self) -> RestateDurableFuture[List[str]]: | ||
ouatu-ro marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return self.create_future(self.vm.sys_get_state_keys()) # type: ignore | ||
|
||
def set(self, name: str, value: T, serde: Serde[T] = JsonSerde()) -> None: | ||
|
@@ -402,15 +392,17 @@ async def create_run_coroutine(self, | |
"""Create a coroutine to poll the handle.""" | ||
try: | ||
if inspect.iscoroutinefunction(action): | ||
action_result = await action() # type: ignore | ||
action_result: T = await action() # type: ignore | ||
else: | ||
action_result = await asyncio.to_thread(action) | ||
action_result = typing.cast(T, await asyncio.to_thread(action)) | ||
|
||
buffer = serde.serialize(action_result) | ||
self.vm.propose_run_completion_success(handle, buffer) | ||
return buffer | ||
ouatu-ro marked this conversation as resolved.
Show resolved
Hide resolved
|
||
except TerminalError as t: | ||
failure = Failure(code=t.status_code, message=t.message) | ||
self.vm.propose_run_completion_failure(handle, failure) | ||
return failure | ||
ouatu-ro marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# pylint: disable=W0718 | ||
except Exception as e: | ||
if max_attempts is None and max_retry_duration is None: | ||
|
@@ -420,7 +412,7 @@ async def create_run_coroutine(self, | |
max_duration_ms = None if max_retry_duration is None else int(max_retry_duration.total_seconds() * 1000) | ||
config = RunRetryConfig(max_attempts=max_attempts, max_duration=max_duration_ms) | ||
self.vm.propose_run_completion_transient(handle, failure=failure, attempt_duration_ms=1, config=config) | ||
|
||
return failure | ||
ouatu-ro marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# pylint: disable=W0236 | ||
# pylint: disable=R0914 | ||
def run(self, | ||
|
Uh oh!
There was an error while loading. Please reload this page.