Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/restate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .endpoint import app

try:
from .harness import test_harness
from .harness import test_harness # type: ignore
except ImportError:
# we don't have the appropriate dependencies installed

Expand Down
7 changes: 4 additions & 3 deletions python/restate/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ class RestateDurableFuture(typing.Generic[T], Awaitable[T]):
"""

@abc.abstractmethod
def __await__(self):
def __await__(self) -> typing.Generator[Any, Any, T]:
pass



# pylint: disable=R0903
class RestateDurableCallFuture(RestateDurableFuture[T]):
"""
Expand All @@ -57,7 +58,7 @@ class RestateDurableSleepFuture(RestateDurableFuture[None]):
"""

@abc.abstractmethod
def __await__(self):
def __await__(self) -> typing.Generator[Any, Any, None]:
pass


Expand Down Expand Up @@ -97,7 +98,7 @@ def get(self,
"""

@abc.abstractmethod
def state_keys(self) -> Awaitable[List[str]]:
def state_keys(self) -> RestateDurableFuture[List[str]]:
"""Returns the list of keys in the store."""

@abc.abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion python/restate/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def json_schema_from_type_hint(type_hint: Optional[TypeHint[Any]]) -> Any:
if not type_hint.annotation:
return None
if type_hint.is_pydantic:
return type_hint.annotation.model_json_schema(mode='serialization') # type: ignore
return type_hint.annotation.model_json_schema(mode='serialization')
return type_hint_to_json_schema(type_hint.annotation)


Expand Down
8 changes: 4 additions & 4 deletions python/restate/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,15 @@ def update_handler_io_with_type_hints(handler_io: HandlerIO[I, O], signature: Si

if is_pydantic(annotation):
handler_io.input_type.is_pydantic = True
if isinstance(handler_io.input_serde, DefaultSerde): # type: ignore
if isinstance(handler_io.input_serde, DefaultSerde):
handler_io.input_serde = PydanticJsonSerde(annotation)

annotation = signature.return_annotation
handler_io.output_type = TypeHint(annotation=annotation, is_pydantic=False)

if is_pydantic(annotation):
handler_io.output_type.is_pydantic=True
if isinstance(handler_io.output_serde, DefaultSerde): # type: ignore
if isinstance(handler_io.output_serde, DefaultSerde):
handler_io.output_serde = PydanticJsonSerde(annotation)

# pylint: disable=R0902
Expand Down Expand Up @@ -170,11 +170,11 @@ async def invoke_handler(handler: Handler[I, O], ctx: Any, in_buffer: bytes) ->
"""
if handler.arity == 2:
try:
in_arg = handler.handler_io.input_serde.deserialize(in_buffer) # type: ignore
in_arg = handler.handler_io.input_serde.deserialize(in_buffer)
except Exception as e:
raise TerminalError(message=f"Unable to parse an input argument. {e}") from e
out_arg = await handler.fn(ctx, in_arg) # type: ignore [call-arg, arg-type]
else:
out_arg = await handler.fn(ctx) # type: ignore [call-arg]
out_buffer = handler.handler_io.output_serde.serialize(out_arg) # type: ignore
out_buffer = handler.handler_io.output_serde.serialize(out_arg)
return bytes(out_buffer)
17 changes: 12 additions & 5 deletions python/restate/harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#
"""Test containers based wrapper for the restate server"""

import abc
import asyncio
from dataclasses import dataclass
import threading
Expand Down Expand Up @@ -40,15 +41,18 @@ def runner():
return thread


class BindAddress:
class BindAddress(abc.ABC):
"""A bind address for the ASGI server"""

def get_local_bind_address(self):
@abc.abstractmethod
def get_local_bind_address(self) -> str:
"""return the local bind address for hypercorn to bind to"""

def get_endpoint_connection_string(self):
@abc.abstractmethod
def get_endpoint_connection_string(self) -> str:
"""return the SDK connection string to be used by restate"""

@abc.abstractmethod
def cleanup(self):
"""cleanup any resources used by the bind address"""

Expand All @@ -58,12 +62,15 @@ class TcpSocketBindAddress(BindAddress):
def __init__(self):
self.port = find_free_port()

def get_local_bind_address(self):
def get_local_bind_address(self) -> str:
return f"0.0.0.0:{self.port}"

def get_endpoint_connection_string(self):
def get_endpoint_connection_string(self) -> str:
return f"http://host.docker.internal:{self.port}"

def cleanup(self):
pass


class AsgiServer:
"""A simple ASGI server that listens on a unix domain socket"""
Expand Down
4 changes: 2 additions & 2 deletions python/restate/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def handler(self,
kind: typing.Optional[typing.Literal["exclusive", "shared"]] = "exclusive",
accept: str = "application/json",
content_type: str = "application/json",
input_serde: Serde[I] = DefaultSerde[I](), # type: ignore
output_serde: Serde[O] = DefaultSerde[O](), # type: ignore
input_serde: Serde[I] = DefaultSerde(),
output_serde: Serde[O] = DefaultSerde(),
metadata: typing.Optional[dict] = None) -> typing.Callable:
"""
Decorator for defining a handler function.
Expand Down
44 changes: 18 additions & 26 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __await__(self):
class ServerDurableSleepFuture(RestateDurableSleepFuture, ServerDurableFuture[None]):
"""This class implements a durable sleep future API"""

def __await__(self):
def __await__(self) -> typing.Generator[Any, Any, None]:
return self.future.__await__()

class ServerCallDurableFuture(RestateDurableCallFuture[T], ServerDurableFuture[T]):
Expand Down Expand Up @@ -210,7 +210,7 @@ def __init__(self,
self.attempt_headers = attempt_headers
self.send = send
self.receive = receive
self.run_coros_to_execute: dict[int, Callable[[], Awaitable[typing.Union[bytes | Failure]]]] = {}
self.run_coros_to_execute: dict[int, Callable[[], Awaitable[bytes | Failure]]] = {}
self.sync_point = SyncPoint()

async def enter(self):
Expand Down Expand Up @@ -324,54 +324,44 @@ async def wrapper(f):
if isinstance(do_progress_response, DoWaitPendingRun):
await self.sync_point.wait()

def create_future(self, handle: int, serde: Serde[T] | None = None) -> ServerDurableFuture[T]:
"""Create a durable future."""

async def transform():
def _create_fetch_result_coroutine(self, handle: int, serde: Serde[T] | None = None):
"""Create a coroutine that fetches a result from a notification handle."""
async def fetch_result():
if not self.vm.is_completed(handle):
await self.create_poll_or_cancel_coroutine([handle])
res = self.must_take_notification(handle)
if res is None or serde is None:
if res is None or serde is None or not isinstance(res, bytes):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about this check. I'd need to review this in depth, can you please revert this check back for now?

Copy link
Contributor Author

@ouatu-ro ouatu-ro Mar 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently the function runs serde.deserialize on any object, I made it so it runs deserialisation only on bytes. will rewrite it to this for clarity:

            if res is None or serde is None:
                return res
            if isinstance(res, bytes):
                return serde.deserialize(res)
            return res

Let me know if you want it changed back, but I think it would be a bug? must_take_notification should never return a JSON, it should only return str, list[str], bytes, None.

return res
return serde.deserialize(res)

return ServerDurableFuture(self, handle, transform)
return fetch_result

def create_future(self, handle: int, serde: Serde[T] | None = None) -> ServerDurableFuture[T]:
"""Create a durable future for handling asynchronous state operations."""
return ServerDurableFuture(self, handle, self._create_fetch_result_coroutine(handle, serde))

def create_sleep_future(self, handle: int) -> ServerDurableSleepFuture:
"""Create a durable sleep future."""

async def transform():
if not self.vm.is_completed(handle):
await self.create_poll_or_cancel_coroutine([handle])
self.must_take_notification(handle)

return ServerDurableSleepFuture(self, handle, transform)


def create_call_future(self, handle: int, invocation_id_handle: int, serde: Serde[T] | None = None) -> ServerCallDurableFuture[T]:
"""Create a durable future."""

async def transform():
if not self.vm.is_completed(handle):
await self.create_poll_or_cancel_coroutine([handle])
res = self.must_take_notification(handle)
if res is None or serde is None:
return res
return serde.deserialize(res)

async def inv_id_factory():
if not self.vm.is_completed(invocation_id_handle):
await self.create_poll_or_cancel_coroutine([invocation_id_handle])
return self.must_take_notification(invocation_id_handle)

return ServerCallDurableFuture(self, handle, transform, inv_id_factory)

return ServerCallDurableFuture(self, handle, self._create_fetch_result_coroutine(handle, serde), inv_id_factory)

def get(self, name: str, serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[T]]:
handle = self.vm.sys_get_state(name)
return self.create_future(handle, serde) # type: ignore

def state_keys(self) -> Awaitable[List[str]]:
def state_keys(self) -> RestateDurableFuture[List[str]]:
return self.create_future(self.vm.sys_get_state_keys()) # type: ignore

def set(self, name: str, value: T, serde: Serde[T] = JsonSerde()) -> None:
Expand Down Expand Up @@ -402,15 +392,17 @@ async def create_run_coroutine(self,
"""Create a coroutine to poll the handle."""
try:
if inspect.iscoroutinefunction(action):
action_result = await action() # type: ignore
action_result: T = await action() # type: ignore
else:
action_result = await asyncio.to_thread(action)
action_result = typing.cast(T, await asyncio.to_thread(action))

buffer = serde.serialize(action_result)
self.vm.propose_run_completion_success(handle, buffer)
return buffer
except TerminalError as t:
failure = Failure(code=t.status_code, message=t.message)
self.vm.propose_run_completion_failure(handle, failure)
return failure
# pylint: disable=W0718
except Exception as e:
if max_attempts is None and max_retry_duration is None:
Expand All @@ -420,7 +412,7 @@ async def create_run_coroutine(self,
max_duration_ms = None if max_retry_duration is None else int(max_retry_duration.total_seconds() * 1000)
config = RunRetryConfig(max_attempts=max_attempts, max_duration=max_duration_ms)
self.vm.propose_run_completion_transient(handle, failure=failure, attempt_duration_ms=1, config=config)

return failure
# pylint: disable=W0236
# pylint: disable=R0914
def run(self,
Expand Down
4 changes: 1 addition & 3 deletions python/restate/server_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ class HTTPResponseBodyEvent(TypedDict):
more_body: bool


ASGIReceiveEvent = Union[
HTTPRequestEvent
]
ASGIReceiveEvent = HTTPRequestEvent


ASGISendEvent = Union[
Expand Down
6 changes: 3 additions & 3 deletions python/restate/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def handler(self,
name: typing.Optional[str] = None,
accept: str = "application/json",
content_type: str = "application/json",
input_serde: Serde[I] = DefaultSerde[I](), # type: ignore
output_serde: Serde[O] = DefaultSerde[O](), # type: ignore
metadata: typing.Optional[typing.Dict[str, str]] = None) -> typing.Callable: # type: ignore
input_serde: Serde[I] = DefaultSerde(),
output_serde: Serde[O] = DefaultSerde(),
metadata: typing.Optional[typing.Dict[str, str]] = None) -> typing.Callable:

"""
Decorator for defining a handler function.
Expand Down