Skip to content

Commit d006698

Browse files
authored
Use DefaultSerde in more places (#64)
* Add type hints in more places * Add docstr for get methods
1 parent 21d682e commit d006698

File tree

3 files changed

+73
-48
lines changed

3 files changed

+73
-48
lines changed

python/restate/context.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, Union
1919
import typing
2020
from datetime import timedelta
21-
from restate.serde import DefaultSerde, JsonSerde, Serde
21+
from restate.serde import DefaultSerde, Serde
2222

2323
T = TypeVar('T')
2424
I = TypeVar('I')
@@ -92,9 +92,17 @@ class KeyValueStore(abc.ABC):
9292
@abc.abstractmethod
9393
def get(self,
9494
name: str,
95-
serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[Any]]:
95+
serde: Serde[T] = DefaultSerde(),
96+
type_hint: Optional[typing.Type[T]] = None
97+
) -> Awaitable[Optional[Any]]:
9698
"""
9799
Retrieves the value associated with the given name.
100+
101+
Args:
102+
name: The state name
103+
serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type.
104+
See also 'type_hint'.
105+
type_hint: The type hint of the return value. This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer.
98106
"""
99107

100108
@abc.abstractmethod
@@ -105,7 +113,7 @@ def state_keys(self) -> Awaitable[List[str]]:
105113
def set(self,
106114
name: str,
107115
value: T,
108-
serde: Serde[T] = JsonSerde()) -> None:
116+
serde: Serde[T] = DefaultSerde()) -> None:
109117
"""set the value associated with the given name."""
110118

111119
@abc.abstractmethod
@@ -266,7 +274,9 @@ def generic_send(self,
266274

267275
@abc.abstractmethod
268276
def awakeable(self,
269-
serde: Serde[T] = JsonSerde()) -> typing.Tuple[str, RestateDurableFuture[Any]]:
277+
serde: Serde[T] = DefaultSerde(),
278+
type_hint: Optional[typing.Type[T]] = None
279+
) -> typing.Tuple[str, RestateDurableFuture[Any]]:
270280
"""
271281
Returns the name of the awakeable and the future to be awaited.
272282
"""
@@ -275,7 +285,7 @@ def awakeable(self,
275285
def resolve_awakeable(self,
276286
name: str,
277287
value: I,
278-
serde: Serde[I] = JsonSerde()) -> None:
288+
serde: Serde[I] = DefaultSerde()) -> None:
279289
"""
280290
Resolves the awakeable with the given name.
281291
"""
@@ -293,7 +303,9 @@ def cancel(self, invocation_id: str):
293303
"""
294304

295305
@abc.abstractmethod
296-
def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[T]:
306+
def attach_invocation(self, invocation_id: str, serde: Serde[T] = DefaultSerde(),
307+
type_hint: typing.Optional[typing.Type[T]] = None
308+
) -> RestateDurableFuture[T]:
297309
"""
298310
Attaches the invocation with the given id.
299311
"""
@@ -323,9 +335,17 @@ def key(self) -> str:
323335
@abc.abstractmethod
324336
def get(self,
325337
name: str,
326-
serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[Optional[Any]]:
338+
serde: Serde[T] = DefaultSerde(),
339+
type_hint: Optional[typing.Type[T]] = None
340+
) -> RestateDurableFuture[Optional[Any]]:
327341
"""
328342
Retrieves the value associated with the given name.
343+
344+
Args:
345+
name: The state name
346+
serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type.
347+
See also 'type_hint'.
348+
type_hint: The type hint of the return value. This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer.
329349
"""
330350

331351
@abc.abstractmethod
@@ -339,7 +359,7 @@ class DurablePromise(typing.Generic[T]):
339359
Represents a durable promise.
340360
"""
341361

342-
def __init__(self, name: str, serde: Serde[T] = JsonSerde()) -> None:
362+
def __init__(self, name: str, serde: Serde[T] = DefaultSerde()) -> None:
343363
self.name = name
344364
self.serde = serde
345365

@@ -373,7 +393,7 @@ class WorkflowContext(ObjectContext):
373393
"""
374394

375395
@abc.abstractmethod
376-
def promise(self, name: str, serde: Serde[T] = JsonSerde()) -> DurablePromise[Any]:
396+
def promise(self, name: str, serde: Serde[T] = DefaultSerde()) -> DurablePromise[Any]:
377397
"""
378398
Returns a durable promise with the given name.
379399
"""
@@ -384,7 +404,7 @@ class WorkflowSharedContext(ObjectSharedContext):
384404
"""
385405

386406
@abc.abstractmethod
387-
def promise(self, name: str, serde: Serde[T] = JsonSerde()) -> DurablePromise[Any]:
407+
def promise(self, name: str, serde: Serde[T] = DefaultSerde()) -> DurablePromise[Any]:
388408
"""
389409
Returns a durable promise with the given name.
390410
"""

python/restate/serde.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,17 @@ class DefaultSerde(Serde[I]):
145145
while allowing automatic serde selection based on type hints.
146146
"""
147147

148+
def __init__(self, type_hint: typing.Optional[typing.Type[I]] = None):
149+
super().__init__()
150+
self.type_hint = type_hint
151+
152+
def with_maybe_type(self, type_hint: typing.Type[I] | None = None) -> "DefaultSerde[I]":
153+
"""
154+
Sets the type hint for the serde.
155+
"""
156+
self.type_hint = type_hint
157+
return self
158+
148159
def deserialize(self, buf: bytes) -> typing.Optional[I]:
149160
"""
150161
Deserializes a byte array into a Python object.
@@ -157,6 +168,8 @@ def deserialize(self, buf: bytes) -> typing.Optional[I]:
157168
"""
158169
if not buf:
159170
return None
171+
if is_pydantic(self.type_hint):
172+
return self.type_hint.model_validate_json(buf) # type: ignore
160173
return json.loads(buf)
161174

162175
def serialize(self, obj: typing.Optional[I]) -> bytes:
@@ -174,11 +187,9 @@ def serialize(self, obj: typing.Optional[I]) -> bytes:
174187
if obj is None:
175188
return bytes()
176189

177-
if isinstance(obj, PydanticBaseModel):
178-
# Use the Pydantic-specific serialization
190+
if is_pydantic(self.type_hint):
179191
return obj.model_dump_json().encode("utf-8") # type: ignore[attr-defined]
180192

181-
# Fallback to standard JSON serialization
182193
return json.dumps(obj).encode("utf-8")
183194

184195

@@ -218,22 +229,3 @@ def serialize(self, obj: typing.Optional[I]) -> bytes:
218229
return bytes()
219230
json_str = obj.model_dump_json() # type: ignore[attr-defined]
220231
return json_str.encode("utf-8")
221-
222-
223-
def for_type(type_hint: typing.Type[T]) -> Serde[T]:
224-
"""
225-
Automatically selects a serde based on the type hint.
226-
227-
Args:
228-
type_hint (typing.Type[T]): The type hint to use for serde selection.
229-
230-
Returns:
231-
Serde[T]: The serde to use for the given type hint.
232-
"""
233-
if is_pydantic(type_hint):
234-
return PydanticJsonSerde(type_hint)
235-
if isinstance(type_hint, bytes):
236-
return BytesSerde()
237-
if isinstance(type_hint, (dict, list, int, float, str, bool)):
238-
return JsonSerde()
239-
return DefaultSerde()

python/restate/server_context.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from restate.context import DurablePromise, ObjectContext, Request, RestateDurableCallFuture, RestateDurableFuture, SendHandle, RestateDurableSleepFuture
2727
from restate.exceptions import TerminalError
2828
from restate.handler import Handler, handler_from_callable, invoke_handler
29-
from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde, for_type
29+
from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde
3030
from restate.server_types import Receive, Send
3131
from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper, RunRetryConfig # pylint: disable=line-too-long
3232
from restate.vm import DoProgressAnyCompleted, DoProgressCancelSignalReceived, DoProgressReadFromInput, DoProgressExecuteRun, DoWaitPendingRun
@@ -130,7 +130,7 @@ class ServerDurablePromise(DurablePromise):
130130
"""This class implements a durable promise API"""
131131

132132
def __init__(self, server_context: "ServerInvocationContext", name, serde) -> None:
133-
super().__init__(name=name, serde=JsonSerde() if serde is None else serde)
133+
super().__init__(name=name, serde=DefaultSerde() if serde is None else serde)
134134
self.server_context = server_context
135135

136136
def value(self) -> RestateDurableFuture[Any]:
@@ -359,15 +359,22 @@ async def inv_id_factory():
359359

360360
return ServerCallDurableFuture(self, handle, self._create_fetch_result_coroutine(handle, serde), inv_id_factory)
361361

362-
def get(self, name: str, serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[T]]:
362+
def get(self, name: str,
363+
serde: Serde[T] = DefaultSerde(),
364+
type_hint: Optional[typing.Type[T]] = None
365+
) -> Awaitable[Optional[T]]:
363366
handle = self.vm.sys_get_state(name)
367+
if isinstance(serde, DefaultSerde):
368+
serde = serde.with_maybe_type(type_hint)
364369
return self.create_future(handle, serde) # type: ignore
365370

366371
def state_keys(self) -> Awaitable[List[str]]:
367372
return self.create_future(self.vm.sys_get_state_keys())
368373

369-
def set(self, name: str, value: T, serde: Serde[T] = JsonSerde()) -> None:
374+
def set(self, name: str, value: T, serde: Serde[T] = DefaultSerde()) -> None:
370375
"""Set the value associated with the given name."""
376+
if isinstance(serde, DefaultSerde):
377+
serde = serde.with_maybe_type(type(value))
371378
buffer = serde.serialize(value)
372379
self.vm.sys_set_state(name, bytes(buffer))
373380

@@ -423,12 +430,11 @@ def run(self,
423430
type_hint: Optional[typing.Type[T]] = None
424431
) -> RestateDurableFuture[T]:
425432

426-
if type_hint is not None:
427-
serde = for_type(type_hint)
428-
elif isinstance(serde, DefaultSerde):
429-
signature = inspect.signature(action, eval_str=True)
430-
serde = for_type(signature.return_annotation)
431-
433+
if isinstance(serde, DefaultSerde):
434+
if type_hint is None:
435+
signature = inspect.signature(action, eval_str=True)
436+
type_hint = signature.return_annotation
437+
serde = serde.with_maybe_type(type_hint)
432438
handle = self.vm.sys_run(name)
433439
self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, action, serde, max_attempts, max_retry_duration)
434440
return self.create_future(handle, serde) # type: ignore
@@ -564,16 +570,20 @@ def generic_send(self, service: str, handler: str, arg: bytes, key: str | None =
564570
return send_handle
565571

566572
def awakeable(self,
567-
serde: typing.Optional[Serde[I]] = JsonSerde()) -> typing.Tuple[str, RestateDurableFuture[Any]]:
568-
assert serde is not None
573+
serde: Serde[I] = DefaultSerde(),
574+
type_hint: Optional[typing.Type[I]] = None
575+
) -> typing.Tuple[str, RestateDurableFuture[Any]]:
576+
if isinstance(serde, DefaultSerde):
577+
serde = serde.with_maybe_type(type_hint)
569578
name, handle = self.vm.sys_awakeable()
570579
return name, self.create_future(handle, serde)
571580

572581
def resolve_awakeable(self,
573582
name: str,
574583
value: I,
575-
serde: typing.Optional[Serde[I]] = JsonSerde()) -> None:
576-
assert serde is not None
584+
serde: Serde[I] = DefaultSerde()) -> None:
585+
if isinstance(serde, DefaultSerde):
586+
serde = serde.with_maybe_type(type(value))
577587
buf = serde.serialize(value)
578588
self.vm.sys_resolve_awakeable(name, buf)
579589

@@ -593,9 +603,12 @@ def cancel(self, invocation_id: str):
593603
raise ValueError("invocation_id cannot be None")
594604
self.vm.sys_cancel(invocation_id)
595605

596-
def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[T]:
606+
def attach_invocation(self, invocation_id: str, serde: Serde[T] = DefaultSerde(),
607+
type_hint: Optional[typing.Type[T]] = None
608+
) -> RestateDurableFuture[T]:
597609
if invocation_id is None:
598610
raise ValueError("invocation_id cannot be None")
599-
assert serde is not None
611+
if isinstance(serde, DefaultSerde):
612+
serde = serde.with_maybe_type(type_hint)
600613
handle = self.vm.attach_invocation(invocation_id)
601614
return self.create_future(handle, serde)

0 commit comments

Comments
 (0)