Skip to content

Commit f7ec282

Browse files
committed
Agressive type checks and and clarifying return types in context and server_context modules
- Updated the __await__ method signatures in RestateDurableFuture and ServerDurableFuture to specify return types as typing.Generator. - Changed the state_keys method return type in ServerInvocationContext to RestateDurableFuture[List[str]] for improved type safety. - Refined the create_df method to handle various notification types and added detailed docstring for clarity. - Ensured consistent handling of notification types across methods in ServerInvocationContext.
1 parent 8d3b579 commit f7ec282

File tree

2 files changed

+47
-19
lines changed

2 files changed

+47
-19
lines changed

python/restate/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def is_completed(self) -> bool:
4040
"""
4141

4242
@abc.abstractmethod
43-
def __await__(self):
43+
def __await__(self) -> typing.Generator[Any, Any, T]:
4444
pass
4545

4646

@@ -94,7 +94,7 @@ def get(self,
9494
"""
9595

9696
@abc.abstractmethod
97-
def state_keys(self) -> Awaitable[List[str]]:
97+
def state_keys(self) -> RestateDurableFuture[List[str]]:
9898
"""Returns the list of keys in the store."""
9999

100100
@abc.abstractmethod

python/restate/server_context.py

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from restate.handler import Handler, handler_from_callable, invoke_handler
2929
from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde
3030
from restate.server_types import Receive, Send
31-
from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper, RunRetryConfig # pylint: disable=line-too-long
31+
from restate.vm import Failure, Invocation, NotReady, NotificationType, SuspendedException, VMWrapper, RunRetryConfig # pylint: disable=line-too-long
3232
from restate.vm import DoProgressAnyCompleted, DoProgressCancelSignalReceived, DoProgressReadFromInput, DoProgressExecuteRun # pylint: disable=line-too-long
3333

3434
T = TypeVar('T')
@@ -42,7 +42,7 @@ class ServerDurableFuture(RestateDurableFuture[T]):
4242
error: TerminalError | None = None
4343
state: typing.Literal["pending", "fulfilled", "rejected"] = "pending"
4444

45-
def __init__(self, context: "ServerInvocationContext", handle: int, awaitable_factory) -> None:
45+
def __init__(self, context: "ServerInvocationContext", handle: int, awaitable_factory: Callable[[], Awaitable[T]]) -> None:
4646
super().__init__()
4747
self.context = context
4848
self.source_notification_handle = handle
@@ -60,7 +60,7 @@ def is_completed(self):
6060
return True
6161

6262

63-
def __await__(self):
63+
def __await__(self) -> typing.Generator[Any, Any, T]:
6464

6565
async def await_point():
6666
match self.state:
@@ -74,6 +74,7 @@ async def await_point():
7474
self.state = "rejected"
7575
raise t
7676
case "fulfilled":
77+
assert self.value is not None
7778
return self.value
7879
case "rejected":
7980
assert self.error is not None
@@ -91,7 +92,7 @@ def __init__(self,
9192
result_handle: int,
9293
result_factory,
9394
invocation_id_handle: int,
94-
invocation_id_factory) -> None:
95+
invocation_id_factory: Callable[[], Awaitable[NotificationType]]) -> None:
9596
super().__init__(context, result_handle, result_factory)
9697
self.invocation_id_handle = invocation_id_handle
9798
self.invocation_id_factory = invocation_id_factory
@@ -100,8 +101,12 @@ def __init__(self,
100101
async def invocation_id(self) -> str:
101102
"""Get the invocation id."""
102103
if self._invocation_id is None:
103-
self._invocation_id = await self.invocation_id_factory()
104-
104+
res = await self.invocation_id_factory()
105+
if isinstance(res, str):
106+
self._invocation_id = res
107+
else:
108+
raise ValueError(f"Unexpected notification type for handle {self.invocation_id_handle}: {type(res).__name__}. "
109+
"Expected str.")
105110
if self._invocation_id is None:
106111
raise ValueError("invocation_id is None")
107112
return self._invocation_id
@@ -147,6 +152,7 @@ async def await_point():
147152
res = self.server_context.must_take_notification(handle)
148153
if res is None:
149154
return None
155+
assert isinstance(res, bytes)
150156
return serde.deserialize(res)
151157

152158
return await_point()
@@ -185,6 +191,7 @@ async def await_point():
185191
res = self.server_context.must_take_notification(handle)
186192
if res is None:
187193
return None
194+
assert isinstance(res, bytes)
188195
return serde.deserialize(res)
189196

190197
return await_point()
@@ -278,11 +285,11 @@ async def take_and_send_output(self):
278285
'more_body': True,
279286
})
280287

281-
def must_take_notification(self, handle: int) -> bytes | None:
288+
def must_take_notification(self, handle: int) -> bytes | None | str | List[str]:
282289
"""Take notification, which must be present. It must be either bytes or None"""
283290
res = self.vm.take_notification(handle)
284291
if isinstance(res, NotReady):
285-
raise ValueError(f"Notification for handle {handle} is not ready. "
292+
raise RuntimeError(f"Notification for handle {handle} is not ready. "
286293
"This likely indicates an unexpected async state.")
287294

288295
if res is None:
@@ -316,17 +323,35 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
316323
await self.take_and_send_output()
317324

318325

319-
def create_df(self, handle: int, serde: Serde[T] | None = None) -> ServerDurableFuture[T]:
320-
"""Create a durable future."""
321-
322-
async def transform():
326+
def create_df(self, handle: int, serde: Serde[T] | None = None) -> ServerDurableFuture[Any]:
327+
"""Create a durable future for handling asynchronous state operations.
328+
329+
This is a general-purpose factory method that creates a future that will resolve
330+
to whatever the VM notification system returns for the given handle. The specific
331+
return type depends on the operation that was performed.
332+
333+
Args:
334+
handle: The notification handle from the VM
335+
serde: Optional serializer/deserializer for converting raw bytes to specific types
336+
337+
Returns:
338+
A durable future that will eventually resolve to the appropriate value.
339+
The caller is responsible for knowing what type to expect based on the context.
340+
"""
341+
342+
async def fetch_result():
323343
await self.create_poll_or_cancel_coroutine([handle])
324344
res = self.must_take_notification(handle)
325345
if res is None or serde is None:
326346
return res
327-
return serde.deserialize(res)
347+
if isinstance(res, bytes):
348+
return serde.deserialize(res)
349+
if isinstance(res, (str, list)):
350+
return res
351+
raise ValueError(f"Unexpected notification type for handle {handle}: {type(res).__name__}. "
352+
"Expected bytes or None.")
328353

329-
return ServerDurableFuture(self, handle, transform)
354+
return ServerDurableFuture(self, handle, fetch_result)
330355

331356

332357

@@ -338,9 +363,12 @@ async def transform():
338363
res = self.must_take_notification(handle)
339364
if res is None or serde is None:
340365
return res
341-
return serde.deserialize(res)
366+
if isinstance(res, bytes):
367+
return serde.deserialize(res)
368+
if isinstance(res, (str, list)):
369+
return res
342370

343-
async def inv_id_factory():
371+
async def inv_id_factory() -> NotificationType:
344372
await self.create_poll_or_cancel_coroutine([invocation_id_handle])
345373
return self.must_take_notification(invocation_id_handle)
346374

@@ -351,7 +379,7 @@ def get(self, name: str, serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[T]
351379
handle = self.vm.sys_get_state(name)
352380
return self.create_df(handle, serde) # type: ignore
353381

354-
def state_keys(self) -> Awaitable[List[str]]:
382+
def state_keys(self) -> RestateDurableFuture[List[str]]:
355383
return self.create_df(self.vm.sys_get_state_keys())
356384

357385
def set(self, name: str, value: T, serde: Serde[T] = JsonSerde()) -> None:

0 commit comments

Comments
 (0)