Skip to content

Commit b878470

Browse files
committed
Support side effect retries
1 parent 8e3d78f commit b878470

File tree

7 files changed

+158
-24
lines changed

7 files changed

+158
-24
lines changed

Cargo.lock

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ doc = false
1414
[dependencies]
1515
pyo3 = { version = "0.22.0", features = ["extension-module"] }
1616
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
17-
restate-sdk-shared-core = "0.0.5"
17+
restate-sdk-shared-core = "0.1.0"
18+
bytes = "1.6.0"

python/restate/context.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# directory of this repository or package, or at
99
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
1010
#
11+
# pylint: disable=R0913,C0301
1112
"""
1213
Restate Context
1314
"""
@@ -25,6 +26,7 @@
2526

2627
RunAction = Union[Callable[[], T], Callable[[], Awaitable[T]]]
2728

29+
2830
@dataclass
2931
class Request:
3032
"""
@@ -79,7 +81,6 @@ def clear(self, name: str) -> None:
7981
def clear_all(self) -> None:
8082
"""clear all the values in the store."""
8183

82-
8384
class Context(abc.ABC):
8485
"""
8586
Represents the context of the current invocation.
@@ -95,9 +96,21 @@ def request(self) -> Request:
9596
def run(self,
9697
name: str,
9798
action: RunAction[T],
98-
serde: Serde[T] = JsonSerde()) -> Awaitable[T]:
99+
serde: Serde[T] = JsonSerde(),
100+
max_attempts: typing.Optional[int] = None,
101+
max_retry_duration: typing.Optional[timedelta] = None) -> Awaitable[T]:
99102
"""
100103
Runs the given action with the given name.
104+
105+
Args:
106+
name: The name of the action.
107+
action: The action to run.
108+
serde: The serialization/deserialization mechanism.
109+
max_attempts: The maximum number of retry attempts to complete the action.
110+
If None, the action will be retried indefinitely, until it succeeds.
111+
Otherwise, the action will be retried until the maximum number of attempts is reached and then it will raise a TerminalError.
112+
max_retry_duration: The maximum duration for retrying. If None, the action will be retried indefinitely, until it succeeds.
113+
Otherwise, the action will be retried until the maximum duration is reached and then it will raise a TerminalError.
101114
"""
102115

103116
@abc.abstractmethod

python/restate/discovery.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,6 @@ def compute_discovery(endpoint: RestateEndpoint, discovered_as : typing.Literal[
147147
else:
148148
protocol_mode = PROTOCOL_MODES[discovered_as]
149149
return Endpoint(protocolMode=protocol_mode,
150-
minProtocolVersion=1,
151-
maxProtocolVersion=1,
150+
minProtocolVersion=2,
151+
maxProtocolVersion=2,
152152
services=services)

python/restate/server_context.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from restate.handler import Handler, handler_from_callable, invoke_handler
2222
from restate.serde import BytesSerde, JsonSerde, Serde
2323
from restate.server_types import Receive, Send
24-
from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper
24+
from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper, RunRetryConfig
2525

2626

2727
T = TypeVar('T')
@@ -227,10 +227,13 @@ def request(self) -> Request:
227227
)
228228

229229
# pylint: disable=W0236
230+
# pylint: disable=R0914
230231
async def run(self,
231232
name: str,
232233
action: Callable[[], T] | Callable[[], Awaitable[T]],
233-
serde: Optional[Serde[T]] = JsonSerde()) -> T:
234+
serde: Optional[Serde[T]] = JsonSerde(),
235+
max_attempts: Optional[int] = None,
236+
max_retry_duration: Optional[timedelta] = None) -> T:
234237
assert serde is not None
235238
res = self.vm.sys_run_enter(name)
236239
if isinstance(res, Failure):
@@ -254,6 +257,20 @@ async def run(self,
254257
await self.create_poll_coroutine(handle)
255258
# unreachable
256259
assert False
260+
# pylint: disable=W0718
261+
except Exception as e:
262+
if max_attempts is None and max_retry_duration is None:
263+
# no retry policy
264+
raise e
265+
failure = Failure(code=500, message=str(e))
266+
max_duration_ms = None if max_retry_duration is None else int(max_retry_duration.total_seconds() * 1000)
267+
config = RunRetryConfig(max_attempts=max_attempts, max_duration=max_duration_ms)
268+
exit_handle = self.vm.sys_run_exit_transient(failure=failure, attempt_duration_ms=1, config=config)
269+
if exit_handle is None:
270+
raise e from None # avoid the traceback that says exception was raised while handling another exception
271+
await self.create_poll_coroutine(exit_handle)
272+
# unreachable
273+
assert False
257274

258275
def sleep(self, delta: timedelta) -> Awaitable[None]:
259276
# convert timedelta to milliseconds

python/restate/vm.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from dataclasses import dataclass
1717
import typing
18-
from restate._internal import PyVM, PyFailure, PySuspended, PyVoid, PyStateKeys # pylint: disable=import-error,no-name-in-module
18+
from restate._internal import PyVM, PyFailure, PySuspended, PyVoid, PyStateKeys, PyExponentialRetryConfig # pylint: disable=import-error,no-name-in-module,line-too-long
1919

2020
@dataclass
2121
class Invocation:
@@ -28,6 +28,14 @@ class Invocation:
2828
input_buffer: bytes
2929
key: str
3030

31+
@dataclass
32+
class RunRetryConfig:
33+
"""
34+
Expo Retry Configuration
35+
"""
36+
initial_interval: typing.Optional[int] = None
37+
max_attempts: typing.Optional[int] = None
38+
max_duration: typing.Optional[int] = None
3139

3240
@dataclass
3341
class Failure:
@@ -312,6 +320,24 @@ def sys_run_exit_failure(self, output: Failure) -> int:
312320
res = PyFailure(output.code, output.message)
313321
return self.vm.sys_run_exit_failure(res)
314322

323+
# pylint: disable=line-too-long
324+
def sys_run_exit_transient(self, failure: Failure, attempt_duration_ms: int, config: RunRetryConfig) -> int | None:
325+
"""
326+
Exit a side effect with a transient Error.
327+
This requires a retry policy to be provided.
328+
"""
329+
py_failure = PyFailure(failure.code, failure.message)
330+
py_config = PyExponentialRetryConfig(config.initial_interval, config.max_attempts, config.max_duration)
331+
try:
332+
handle = self.vm.sys_run_exit_failure_transient(py_failure, attempt_duration_ms, py_config)
333+
# The VM decided not to retry, therefore we get back an handle that will be resolved
334+
# with a terminal failure.
335+
return handle
336+
# pylint: disable=bare-except
337+
except:
338+
# The VM decided to retry, therefore we tear down the current execution
339+
return None
340+
315341
def sys_end(self):
316342
"""
317343
This method is responsible for ending the system.

src/lib.rs

Lines changed: 90 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ use pyo3::prelude::*;
33
use pyo3::types::{PyBytes, PyNone};
44
use restate_sdk_shared_core::{
55
AsyncResultHandle, CoreVM, Failure, Header, IdentityVerifier, Input, NonEmptyValue,
6-
ResponseHead, RunEnterResult, SuspendedOrVMError, TakeOutputResult, Target, VMError, Value, VM,
6+
ResponseHead, RetryPolicy, RunEnterResult, RunExitResult, SuspendedOrVMError, TakeOutputResult,
7+
Target, VMError, Value, VM,
78
};
89
use std::borrow::Cow;
910
use std::time::Duration;
@@ -87,6 +88,46 @@ impl PyFailure {
8788
}
8889
}
8990

91+
#[pyclass]
92+
#[derive(Clone)]
93+
struct PyExponentialRetryConfig {
94+
#[pyo3(get, set)]
95+
initial_interval: Option<u64>,
96+
#[pyo3(get, set)]
97+
max_attempts: Option<u32>,
98+
#[pyo3(get, set)]
99+
max_duration: Option<u64>,
100+
}
101+
102+
#[pymethods]
103+
impl PyExponentialRetryConfig {
104+
#[pyo3(signature = (initial_interval=None, max_attempts=None, max_duration=None))]
105+
#[new]
106+
fn new(
107+
initial_interval: Option<u64>,
108+
max_attempts: Option<u32>,
109+
max_duration: Option<u64>,
110+
) -> Self {
111+
Self {
112+
initial_interval,
113+
max_attempts,
114+
max_duration,
115+
}
116+
}
117+
}
118+
119+
impl From<PyExponentialRetryConfig> for RetryPolicy {
120+
fn from(value: PyExponentialRetryConfig) -> Self {
121+
RetryPolicy::Exponential {
122+
initial_interval: Duration::from_millis(value.initial_interval.unwrap_or(10)),
123+
max_attempts: value.max_attempts,
124+
max_duration: value.max_duration.map(Duration::from_millis),
125+
factor: 2.0,
126+
max_interval: None,
127+
}
128+
}
129+
}
130+
90131
impl From<Failure> for PyFailure {
91132
fn from(value: Failure) -> Self {
92133
PyFailure {
@@ -133,7 +174,7 @@ impl From<Input> for PyInput {
133174
random_seed: value.random_seed,
134175
key: value.key,
135176
headers: value.headers.into_iter().map(Into::into).collect(),
136-
input: value.input,
177+
input: value.input.into(),
137178
}
138179
}
139180
}
@@ -186,7 +227,8 @@ impl PyVM {
186227
// Notifications
187228

188229
fn notify_input(mut self_: PyRefMut<'_, Self>, buffer: &Bound<'_, PyBytes>) {
189-
self_.vm.notify_input(buffer.as_bytes().to_vec());
230+
let buf = buffer.as_bytes().to_vec().into();
231+
self_.vm.notify_input(buf);
190232
}
191233

192234
fn notify_input_closed(mut self_: PyRefMut<'_, Self>) {
@@ -195,9 +237,11 @@ impl PyVM {
195237

196238
#[pyo3(signature = (error, description=None))]
197239
fn notify_error(mut self_: PyRefMut<'_, Self>, error: String, description: Option<String>) {
198-
self_.vm.notify_error(
240+
CoreVM::notify_error(
241+
&mut self_.vm,
199242
Cow::Owned(error),
200243
description.map(Cow::Owned).unwrap_or(Cow::Borrowed("")),
244+
None,
201245
);
202246
}
203247

@@ -280,7 +324,7 @@ impl PyVM {
280324
) -> Result<(), PyVMError> {
281325
self_
282326
.vm
283-
.sys_state_set(key, buffer.as_bytes().to_vec())
327+
.sys_state_set(key, buffer.as_bytes().to_vec().into())
284328
.map_err(Into::into)
285329
}
286330

@@ -319,7 +363,7 @@ impl PyVM {
319363
handler,
320364
key,
321365
},
322-
buffer.as_bytes().to_vec(),
366+
buffer.as_bytes().to_vec().into(),
323367
)
324368
.map(Into::into)
325369
.map_err(Into::into)
@@ -342,7 +386,7 @@ impl PyVM {
342386
handler,
343387
key,
344388
},
345-
buffer.as_bytes().to_vec(),
389+
buffer.as_bytes().to_vec().into(),
346390
delay.map(Duration::from_millis),
347391
)
348392
.map_err(Into::into)
@@ -365,7 +409,10 @@ impl PyVM {
365409
) -> Result<(), PyVMError> {
366410
self_
367411
.vm
368-
.sys_complete_awakeable(id, NonEmptyValue::Success(buffer.as_bytes().to_vec()))
412+
.sys_complete_awakeable(
413+
id,
414+
NonEmptyValue::Success(buffer.as_bytes().to_vec().into()),
415+
)
369416
.map_err(Into::into)
370417
}
371418

@@ -409,7 +456,10 @@ impl PyVM {
409456
) -> Result<PyAsyncResultHandle, PyVMError> {
410457
self_
411458
.vm
412-
.sys_complete_promise(key, NonEmptyValue::Success(buffer.as_bytes().to_vec()))
459+
.sys_complete_promise(
460+
key,
461+
NonEmptyValue::Success(buffer.as_bytes().to_vec().into()),
462+
)
413463
.map(Into::into)
414464
.map_err(Into::into)
415465
}
@@ -446,28 +496,52 @@ impl PyVM {
446496
RunEnterResult::Executed(NonEmptyValue::Failure(f)) => {
447497
PyFailure::from(f).into_py(py).into_bound(py).into_any()
448498
}
449-
RunEnterResult::NotExecuted => PyNone::get_bound(py).to_owned().into_any(),
499+
RunEnterResult::NotExecuted(_retry_info) => PyNone::get_bound(py).to_owned().into_any(),
450500
})
451501
}
452502

453503
fn sys_run_exit_success(
454504
mut self_: PyRefMut<'_, Self>,
455505
buffer: &Bound<'_, PyBytes>,
506+
) -> Result<PyAsyncResultHandle, PyVMError> {
507+
CoreVM::sys_run_exit(
508+
&mut self_.vm,
509+
RunExitResult::Success(buffer.as_bytes().to_vec().into()),
510+
RetryPolicy::None,
511+
)
512+
.map(Into::into)
513+
.map_err(Into::into)
514+
}
515+
516+
fn sys_run_exit_failure(
517+
mut self_: PyRefMut<'_, Self>,
518+
value: PyFailure,
456519
) -> Result<PyAsyncResultHandle, PyVMError> {
457520
self_
458521
.vm
459-
.sys_run_exit(NonEmptyValue::Success(buffer.as_bytes().to_vec()))
522+
.sys_run_exit(
523+
RunExitResult::TerminalFailure(value.into()),
524+
RetryPolicy::None,
525+
)
460526
.map(Into::into)
461527
.map_err(Into::into)
462528
}
463529

464-
fn sys_run_exit_failure(
530+
fn sys_run_exit_failure_transient(
465531
mut self_: PyRefMut<'_, Self>,
466532
value: PyFailure,
533+
attempt_duration: u64,
534+
config: PyExponentialRetryConfig,
467535
) -> Result<PyAsyncResultHandle, PyVMError> {
468536
self_
469537
.vm
470-
.sys_run_exit(NonEmptyValue::Failure(value.into()))
538+
.sys_run_exit(
539+
RunExitResult::RetryableFailure {
540+
attempt_duration: Duration::from_millis(attempt_duration),
541+
failure: value.into(),
542+
},
543+
config.into(),
544+
)
471545
.map(Into::into)
472546
.map_err(Into::into)
473547
}
@@ -478,7 +552,7 @@ impl PyVM {
478552
) -> Result<(), PyVMError> {
479553
self_
480554
.vm
481-
.sys_write_output(NonEmptyValue::Success(buffer.as_bytes().to_vec()))
555+
.sys_write_output(NonEmptyValue::Success(buffer.as_bytes().to_vec().into()))
482556
.map(Into::into)
483557
.map_err(Into::into)
484558
}
@@ -558,6 +632,8 @@ fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> {
558632
m.add_class::<PySuspended>()?;
559633
m.add_class::<PyVM>()?;
560634
m.add_class::<PyIdentityVerifier>()?;
635+
m.add_class::<PyExponentialRetryConfig>()?;
636+
561637
m.add("VMException", m.py().get_type_bound::<VMException>())?;
562638
m.add(
563639
"IdentityKeyException",

0 commit comments

Comments
 (0)