Skip to content

Commit f02ba1b

Browse files
committed
Add select future combinator
1 parent 9ac204f commit f02ba1b

File tree

3 files changed

+49
-8
lines changed

3 files changed

+49
-8
lines changed

examples/workflow.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,13 @@
1313
# pylint: disable=W0613
1414
# pylint: disable=C0301
1515

16+
from datetime import timedelta
1617

1718
from restate import Workflow, WorkflowContext, WorkflowSharedContext
18-
from restate.exceptions import TerminalError
19+
from restate import select
20+
from restate import TerminalError
21+
22+
TIMEOUT = timedelta(seconds=10)
1923

2024
payment = Workflow("payment")
2125

@@ -38,13 +42,17 @@ def payment_gateway():
3842
ctx.set("status", "waiting for the payment provider to approve")
3943

4044
# Wait for the payment to be verified
41-
result = await ctx.promise("verify.payment").value()
42-
if result == "approved":
43-
ctx.set("status", "payment approved")
44-
return { "success" : True }
4545

46-
ctx.set("status", "payment declined")
47-
raise TerminalError(message="Payment declined", status_code=401)
46+
match await select(result=ctx.promise("verify.payment").value(), timeout=ctx.sleep(TIMEOUT)):
47+
case ['result', "approved"]:
48+
ctx.set("status", "payment approved")
49+
return { "success" : True }
50+
case ['result', "declined"]:
51+
ctx.set("status", "payment declined")
52+
raise TerminalError(message="Payment declined", status_code=401)
53+
case ['timeout', _]:
54+
ctx.set("status", "payment verification timed out")
55+
raise TerminalError(message="Payment verification timed out", status_code=410)
4856

4957
@payment.handler()
5058
async def payment_verified(ctx: WorkflowSharedContext, result: str):

python/restate/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
# pylint: disable=line-too-long
2323
from .context import DurablePromise, RestateDurableFuture, RestateDurableCallFuture, RestateDurableSleepFuture, SendHandle
2424
from .exceptions import TerminalError
25-
from .asyncio import as_completed, gather, wait_completed
25+
from .asyncio import as_completed, gather, wait_completed, select
2626

2727
from .endpoint import app
2828

@@ -56,4 +56,5 @@ def test_harness(app, follow_logs = False, restate_image = ""): # type: ignore
5656
"gather",
5757
"as_completed",
5858
"wait_completed",
59+
"select"
5960
]

python/restate/asyncio.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,38 @@ async def gather(*futures: RestateDurableFuture[Any]) -> List[RestateDurableFutu
2727
pass
2828
return list(futures)
2929

30+
async def select(**kws: RestateDurableFuture[Any]) -> List[Any]:
31+
"""
32+
Blocks until one of the futures is completed.
33+
34+
Example:
35+
36+
who, what = await select(car=f1, hotel=f2, flight=f3)
37+
if who == "car":
38+
print(what)
39+
elif who == "hotel":
40+
print(what)
41+
elif who == "flight":
42+
print(what)
43+
44+
works the best with matching:
45+
46+
match await select(result=ctx.promise("verify.payment"), timeout=ctx.sleep(timedelta(seconds=10))):
47+
case ['result', "approved"]:
48+
return { "success" : True }
49+
case ['result', "declined"]:
50+
raise TerminalError(message="Payment declined", status_code=401)
51+
case ['timeout', _]:
52+
raise TerminalError(message="Payment verification timed out", status_code=410)
53+
54+
"""
55+
if not kws:
56+
raise ValueError("At least one future must be passed.")
57+
reverse = { f: key for key, f in kws.items() }
58+
async for f in as_completed(*kws.values()):
59+
return [reverse[f], await f]
60+
assert False, "unreachable"
61+
3062
async def as_completed(*futures: RestateDurableFuture[Any]):
3163
"""
3264
Returns an iterator that yields the futures as they are completed.

0 commit comments

Comments
 (0)