Skip to content

Commit 9ac204f

Browse files
authored
Use RestateDurableFuture in workflow methods (#61)
1 parent 79415e7 commit 9ac204f

File tree

2 files changed

+11
-28
lines changed

2 files changed

+11
-28
lines changed

python/restate/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def peek(self) -> Awaitable[typing.Optional[T]]:
355355
"""
356356

357357
@abc.abstractmethod
358-
def value(self) -> Awaitable[T]:
358+
def value(self) -> RestateDurableFuture[T]:
359359
"""
360360
Returns the value of the promise if it is resolved, None otherwise.
361361
"""

python/restate/server_context.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -133,21 +133,9 @@ def __init__(self, server_context: "ServerInvocationContext", name, serde) -> No
133133
super().__init__(name=name, serde=JsonSerde() if serde is None else serde)
134134
self.server_context = server_context
135135

136-
def value(self) -> Awaitable[Any]:
137-
vm: VMWrapper = self.server_context.vm
138-
handle = vm.sys_get_promise(self.name)
139-
coro = self.server_context.create_poll_or_cancel_coroutine([handle])
140-
serde = self.serde
141-
assert serde is not None
142-
143-
async def await_point():
144-
await coro
145-
res = self.server_context.must_take_notification(handle)
146-
if res is None:
147-
return None
148-
return serde.deserialize(res)
149-
150-
return await_point()
136+
def value(self) -> RestateDurableFuture[Any]:
137+
handle = self.server_context.vm.sys_get_promise(self.name)
138+
return self.server_context.create_future(handle, self.serde)
151139

152140
def resolve(self, value: Any) -> Awaitable[None]:
153141
vm: VMWrapper = self.server_context.vm
@@ -156,36 +144,31 @@ def resolve(self, value: Any) -> Awaitable[None]:
156144
handle = vm.sys_complete_promise_success(self.name, value_buffer)
157145

158146
async def await_point():
159-
await self.server_context.create_poll_or_cancel_coroutine([handle])
147+
if not self.server_context.vm.is_completed(handle):
148+
await self.server_context.create_poll_or_cancel_coroutine([handle])
160149
self.server_context.must_take_notification(handle)
161150

162-
return await_point()
151+
return ServerDurableFuture(self.server_context, handle, await_point)
163152

164153
def reject(self, message: str, code: int = 500) -> Awaitable[None]:
165154
vm: VMWrapper = self.server_context.vm
166155
py_failure = Failure(code=code, message=message)
167156
handle = vm.sys_complete_promise_failure(self.name, py_failure)
168157

169158
async def await_point():
170-
await self.server_context.create_poll_or_cancel_coroutine([handle])
159+
if not self.server_context.vm.is_completed(handle):
160+
await self.server_context.create_poll_or_cancel_coroutine([handle])
171161
self.server_context.must_take_notification(handle)
172162

173-
return await_point()
163+
return ServerDurableFuture(self.server_context, handle, await_point)
174164

175165
def peek(self) -> Awaitable[Any | None]:
176166
vm: VMWrapper = self.server_context.vm
177167
handle = vm.sys_peek_promise(self.name)
178168
serde = self.serde
179169
assert serde is not None
180170

181-
async def await_point():
182-
await self.server_context.create_poll_or_cancel_coroutine([handle])
183-
res = self.server_context.must_take_notification(handle)
184-
if res is None:
185-
return None
186-
return serde.deserialize(res)
187-
188-
return await_point()
171+
return self.server_context.create_future(handle, serde)
189172

190173

191174
# disable too many public method

0 commit comments

Comments
 (0)