Skip to content

Commit 628d238

Browse files
authored
enhancement: Add support for function args & kwargs type hints in ctx.run. (#109)
* feat: add run_typed function * feat: depricate run function * feat: add typing_extensions as dependency * feat: add better message * feat: add run_typed implementation * feat: update docs for RunOptions * feat: add RunOptions to __init__.py * feat: update example * feat: requiriments.txt * fix: requriments.txt * feat: bump version to 0.9.0 * feat: update tests to use run_typed * I have read the CLA Document and I hereby sign the CLA * fix: test case * fix: remove version bump * fix: update cargo.lock
1 parent e810585 commit 628d238

File tree

9 files changed

+139
-15
lines changed

9 files changed

+139
-15
lines changed

examples/workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def payment_gateway():
3737
print("To decline use:")
3838
print(f"""curl http://localhost:8080/payment/{workflow_key}/payment_verified --json '"declined"' """)
3939

40-
await ctx.run("payment", payment_gateway)
40+
await ctx.run_typed("payment", payment_gateway)
4141

4242
ctx.set("status", "waiting for the payment provider to approve")
4343

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ license = { file = "LICENSE" }
1212
authors = [
1313
{ name = "Restate Developers", email = "[email protected]" }
1414
]
15+
dependencies = [
16+
"typing-extensions>=4.14.0"
17+
]
18+
1519

1620
[project.optional-dependencies]
1721
test = ["pytest", "hypercorn"]

python/restate/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .context import Context, ObjectContext, ObjectSharedContext
2121
from .context import WorkflowContext, WorkflowSharedContext
2222
# pylint: disable=line-too-long
23-
from .context import DurablePromise, RestateDurableFuture, RestateDurableCallFuture, RestateDurableSleepFuture, SendHandle
23+
from .context import DurablePromise, RestateDurableFuture, RestateDurableCallFuture, RestateDurableSleepFuture, SendHandle, RunOptions
2424
from .exceptions import TerminalError
2525
from .asyncio import as_completed, gather, wait_completed, select
2626

@@ -50,6 +50,7 @@ def test_harness(app, follow_logs = False, restate_image = ""): # type: ignore
5050
"RestateDurableCallFuture",
5151
"RestateDurableSleepFuture",
5252
"SendHandle",
53+
"RunOptions",
5354
"TerminalError",
5455
"app",
5556
"test_harness",

python/restate/context.py

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,39 @@
1515

1616
import abc
1717
from dataclasses import dataclass
18-
from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, Coroutine, overload
18+
from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, Coroutine, overload, ParamSpec
1919
import typing
2020
from datetime import timedelta
21+
22+
import typing_extensions
2123
from restate.serde import DefaultSerde, Serde
2224

2325
T = TypeVar('T')
2426
I = TypeVar('I')
2527
O = TypeVar('O')
28+
P = ParamSpec('P')
2629

27-
RunAction = Union[Callable[..., Coroutine[Any, Any, T]], Callable[..., T]]
2830
HandlerType = Union[Callable[[Any, I], Awaitable[O]], Callable[[Any], Awaitable[O]]]
31+
RunAction = Union[Callable[..., Coroutine[Any, Any, T]], Callable[..., T]]
32+
33+
@dataclass
34+
class RunOptions(typing.Generic[T]):
35+
"""
36+
Options for running an action.
37+
"""
38+
39+
serde: Serde[T] = DefaultSerde()
40+
"""The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type.
41+
See also 'type_hint'."""
42+
max_attempts: Optional[int] = None
43+
"""The maximum number of retry attempts to complete the action.
44+
If None, the action will be retried indefinitely, until it succeeds.
45+
Otherwise, the action will be retried until the maximum number of attempts is reached and then it will raise a TerminalError."""
46+
max_retry_duration: Optional[timedelta] = None
47+
"""The maximum duration for retrying. If None, the action will be retried indefinitely, until it succeeds.
48+
Otherwise, the action will be retried until the maximum duration is reached and then it will raise a TerminalError."""
49+
type_hint: Optional[typing.Type[T]] = None
50+
"""The type hint of the return value of the action. This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer."""
2951

3052
# pylint: disable=R0903
3153
class RestateDurableFuture(typing.Generic[T], Awaitable[T]):
@@ -197,6 +219,8 @@ def request(self) -> Request:
197219
Returns the request object.
198220
"""
199221

222+
223+
@typing_extensions.deprecated("`run` is deprecated, use `run_typed` instead for better type safety")
200224
@overload
201225
@abc.abstractmethod
202226
def run(self,
@@ -226,6 +250,7 @@ def run(self,
226250
227251
"""
228252

253+
@typing_extensions.deprecated("`run` is deprecated, use `run_typed` instead for better type safety")
229254
@overload
230255
@abc.abstractmethod
231256
def run(self,
@@ -255,6 +280,7 @@ def run(self,
255280
256281
"""
257282

283+
@typing_extensions.deprecated("`run` is deprecated, use `run_typed` instead for better type safety")
258284
@abc.abstractmethod
259285
def run(self,
260286
name: str,
@@ -283,6 +309,73 @@ def run(self,
283309
284310
"""
285311

312+
313+
@overload
314+
@abc.abstractmethod
315+
def run_typed(self,
316+
name: str,
317+
action: Callable[P, Coroutine[Any, Any,T]],
318+
options: RunOptions[T] = RunOptions(),
319+
/,
320+
*args: P.args,
321+
**kwargs: P.kwargs,
322+
) -> RestateDurableFuture[T]:
323+
"""
324+
Typed version of run that provides type hints for the function arguments.
325+
Runs the given action with the given name.
326+
327+
Args:
328+
name: The name of the action.
329+
action: The action to run.
330+
options: The options for the run.
331+
*args: The arguments to pass to the action.
332+
**kwargs: The keyword arguments to pass to the action.
333+
"""
334+
335+
@overload
336+
@abc.abstractmethod
337+
def run_typed(self,
338+
name: str,
339+
action: Callable[P, T],
340+
options: RunOptions[T] = RunOptions(),
341+
/,
342+
*args: P.args,
343+
**kwargs: P.kwargs,
344+
) -> RestateDurableFuture[T]:
345+
"""
346+
Typed version of run that provides type hints for the function arguments.
347+
Runs the given coroutine action with the given name.
348+
349+
Args:
350+
name: The name of the action.
351+
action: The action to run.
352+
options: The options for the run.
353+
*args: The arguments to pass to the action.
354+
**kwargs: The keyword arguments to pass to the action.
355+
"""
356+
357+
@abc.abstractmethod
358+
def run_typed(self,
359+
name: str,
360+
action: Union[Callable[P, Coroutine[Any, Any, T]], Callable[P, T]],
361+
options: RunOptions[T] = RunOptions(),
362+
/,
363+
*args: P.args,
364+
**kwargs: P.kwargs,
365+
) -> RestateDurableFuture[T]:
366+
"""
367+
Typed version of run that provides type hints for the function arguments.
368+
Runs the given action with the given name.
369+
370+
Args:
371+
name: The name of the action.
372+
action: The action to run.
373+
options: The options for the run.
374+
*args: The arguments to pass to the action.
375+
**kwargs: The keyword arguments to pass to the action.
376+
377+
"""
378+
286379
@abc.abstractmethod
287380
def sleep(self, delta: timedelta) -> RestateDurableSleepFuture:
288381
"""

python/restate/server_context.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,18 @@
2121
from datetime import timedelta
2222
import inspect
2323
import functools
24-
from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar
24+
from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, Coroutine
2525
import typing
2626
import traceback
2727

28-
from restate.context import DurablePromise, AttemptFinishedEvent, HandlerType, ObjectContext, Request, RestateDurableCallFuture, RestateDurableFuture, RunAction, SendHandle, RestateDurableSleepFuture
28+
from restate.context import DurablePromise, AttemptFinishedEvent, HandlerType, ObjectContext, Request, RestateDurableCallFuture, RestateDurableFuture, RunAction, SendHandle, RestateDurableSleepFuture, RunOptions, P
2929
from restate.exceptions import TerminalError
3030
from restate.handler import Handler, handler_from_callable, invoke_handler
3131
from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde
3232
from restate.server_types import ReceiveChannel, Send
3333
from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper, RunRetryConfig # pylint: disable=line-too-long
3434
from restate.vm import DoProgressAnyCompleted, DoProgressCancelSignalReceived, DoProgressReadFromInput, DoProgressExecuteRun, DoWaitPendingRun
35+
import typing_extensions
3536

3637

3738
T = TypeVar('T')
@@ -510,6 +511,7 @@ async def create_run_coroutine(self,
510511
self.vm.propose_run_completion_transient(handle, failure=failure, attempt_duration_ms=1, config=config)
511512
# pylint: disable=W0236
512513
# pylint: disable=R0914
514+
@typing_extensions.deprecated("`run` is deprecated, use `run_typed` instead for better type safety")
513515
def run(self,
514516
name: str,
515517
action: RunAction[T],
@@ -536,6 +538,25 @@ def run(self,
536538
self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, noargs_action, serde, max_attempts, max_retry_duration)
537539
return self.create_future(handle, serde) # type: ignore
538540

541+
def run_typed(
542+
self,
543+
name: str,
544+
action: Union[Callable[P, T], Callable[P, Coroutine[Any, Any, T]]],
545+
options: RunOptions[T] = RunOptions(),
546+
/,
547+
*args: P.args,
548+
**kwargs: P.kwargs,
549+
) -> RestateDurableFuture[T]:
550+
if isinstance(options.serde, DefaultSerde):
551+
if options.type_hint is None:
552+
signature = inspect.signature(action, eval_str=True)
553+
options.type_hint = signature.return_annotation
554+
options.serde = options.serde.with_maybe_type(options.type_hint)
555+
handle = self.vm.sys_run(name)
556+
557+
func = functools.partial(action, *args, **kwargs)
558+
self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, func, options.serde, options.max_attempts, options.max_retry_duration)
559+
return self.create_future(handle, options.serde)
539560

540561
def sleep(self, delta: timedelta) -> RestateDurableSleepFuture:
541562
# convert timedelta to milliseconds

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ pytest
66
pydantic
77
httpx
88
testcontainers
9+
typing-extensions>=4.14.0

test-services/services/failing.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from restate import VirtualObject, ObjectContext
1717
from restate.exceptions import TerminalError
18+
from restate import RunOptions
1819

1920
failing = VirtualObject("Failing")
2021

@@ -45,23 +46,23 @@ async def terminally_failing_side_effect(ctx: ObjectContext, error_message: str)
4546
def side_effect():
4647
raise TerminalError(message=error_message)
4748

48-
await ctx.run("sideEffect", side_effect)
49+
await ctx.run_typed("sideEffect", side_effect)
4950
raise ValueError("Should not reach here")
5051

5152

5253
eventual_success_side_effects = 0
5354

5455
@failing.handler(name="sideEffectSucceedsAfterGivenAttempts")
5556
async def side_effect_succeeds_after_given_attempts(ctx: ObjectContext, minimum_attempts: int) -> int:
56-
5757
def side_effect():
5858
global eventual_success_side_effects
5959
eventual_success_side_effects += 1
6060
if eventual_success_side_effects >= minimum_attempts:
6161
return eventual_success_side_effects
6262
raise ValueError(f"Failed at attempt: {eventual_success_side_effects}")
6363

64-
return await ctx.run("sideEffect", side_effect, max_attempts=minimum_attempts + 1) # type: ignore
64+
options: RunOptions[int] = RunOptions(max_attempts=minimum_attempts + 1)
65+
return await ctx.run_typed("sideEffect", side_effect, options)
6566

6667
eventual_failure_side_effects = 0
6768

@@ -74,7 +75,8 @@ def side_effect():
7475
raise ValueError(f"Failed at attempt: {eventual_failure_side_effects}")
7576

7677
try:
77-
await ctx.run("sideEffect", side_effect, max_attempts=retry_policy_max_retry_count)
78+
options: RunOptions[int] = RunOptions(max_attempts=retry_policy_max_retry_count)
79+
await ctx.run_typed("sideEffect", side_effect, options)
7880
raise ValueError("Side effect did not fail.")
7981
except TerminalError as t:
8082
global eventual_failure_side_effects

test-services/services/interpreter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ async def await_promise(index: int) -> None:
227227
coros[i] = (expected, service.echo_later(expected, command['sleep']))
228228
elif command_type == SIDE_EFFECT:
229229
expected = f"hello-{i}"
230-
result = await ctx.run("sideEffect", lambda : expected) # pylint: disable=W0640
230+
result = await ctx.run_typed("sideEffect", lambda: expected)
231231
if result != expected:
232232
raise TerminalError(f"Expected {expected} but got {result}")
233233
elif command_type == SLOW_SIDE_EFFECT:
@@ -246,7 +246,7 @@ async def side_effect():
246246
if bool(random.getrandbits(1)):
247247
raise ValueError("Random error")
248248

249-
await ctx.run("throwingSideEffect", side_effect)
249+
await ctx.run_typed("throwingSideEffect", side_effect)
250250
elif command_type == INCREMENT_STATE_COUNTER_INDIRECTLY:
251251
await service.increment_indirectly(layer=layer, key=ctx.key())
252252
elif command_type == AWAIT_PROMISE:

test-services/services/virtual_object_command_interpreter.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ def to_durable_future(ctx: ObjectContext, cmd: AwaitableCommand) -> RestateDurab
116116
elif cmd['type'] == "sleep":
117117
return ctx.sleep(timedelta(milliseconds=cmd['timeoutMillis']))
118118
elif cmd['type'] == "runThrowTerminalException":
119-
def side_effect(reason):
119+
def side_effect(reason: str):
120120
raise TerminalError(message=reason)
121-
res = ctx.run("run should fail command", side_effect, args=(cmd['reason'],))
121+
res = ctx.run_typed("run should fail command", side_effect, reason=cmd['reason'])
122122
return res
123123

124124
@virtual_object_command_interpreter.handler(name="interpretCommands")
@@ -142,7 +142,9 @@ async def interpret_commands(ctx: ObjectContext, req: InterpretRequest):
142142
result = ""
143143
elif cmd['type'] == "getEnvVariable":
144144
env_name = cmd['envName']
145-
result = await ctx.run("get_env", lambda e=env_name: os.environ.get(e, ""))
145+
def side_effect(env_name: str):
146+
return os.environ.get(env_name, "")
147+
result = await ctx.run_typed("get_env", side_effect, env_name=env_name)
146148
elif cmd['type'] == "awaitOne":
147149
awaitable = to_durable_future(ctx, cmd['command'])
148150
# We need this dance because the Python SDK doesn't support .map on futures

0 commit comments

Comments
 (0)