Skip to content

Commit f9faa78

Browse files
committed
Pass Headers in call/send
1 parent 95138ed commit f9faa78

File tree

3 files changed

+93
-33
lines changed

3 files changed

+93
-33
lines changed

python/restate/server_context.py

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ async def create_poll_or_cancel_coroutine(self, handle) -> bytes | None:
219219
return self.must_take_notification(handle)
220220

221221
# Nothing ready yet, let's try to make some progress
222-
do_progress_response = self.vm.do_progress([handle, CANCEL_HANDLE])
222+
do_progress_response = self.vm.do_progress([handle])
223223
if isinstance(do_progress_response, PyDoProgressAnyCompleted):
224224
# One of the handles completed, we can continue
225225
continue
@@ -339,14 +339,16 @@ def do_call(self,
339339
key: Optional[str] = None,
340340
send_delay: Optional[timedelta] = None,
341341
send: bool = False,
342-
idempotency_key: str | None = None) -> Awaitable[O] | SendHandle:
342+
idempotency_key: str | None = None,
343+
headers: typing.List[typing.Tuple[str, str]] | None = None
344+
) -> Awaitable[O] | SendHandle:
343345
"""Make an RPC call to the given handler"""
344346
target_handler = handler_from_callable(tpe)
345347
service=target_handler.service_tag.name
346348
handler=target_handler.name
347349
input_serde = target_handler.handler_io.input_serde
348350
output_serde = target_handler.handler_io.output_serde
349-
return self.do_raw_call(service, handler, parameter, input_serde, output_serde, key, send_delay, send, idempotency_key)
351+
return self.do_raw_call(service, handler, parameter, input_serde, output_serde, key, send_delay, send, idempotency_key, headers)
350352

351353

352354
def do_raw_call(self,
@@ -358,23 +360,25 @@ def do_raw_call(self,
358360
key: Optional[str] = None,
359361
send_delay: Optional[timedelta] = None,
360362
send: bool = False,
361-
idempotency_key: str | None = None
363+
idempotency_key: str | None = None,
364+
headers: typing.List[typing.Tuple[str, str]] | None = None
362365
) -> Awaitable[O] | SendHandle:
363366
"""Make an RPC call to the given handler"""
364367
parameter = input_serde.serialize(input_param)
365368
if send_delay:
366369
ms = int(send_delay.total_seconds() * 1000)
367-
send_handle = self.vm.sys_send(service, handler, parameter, key, delay=ms, idempotency_key=idempotency_key)
370+
send_handle = self.vm.sys_send(service, handler, parameter, key, delay=ms, idempotency_key=idempotency_key, headers=headers)
368371
return ServerSendHandle(self, send_handle)
369372
if send:
370-
send_handle = self.vm.sys_send(service, handler, parameter, key, idempotency_key=idempotency_key)
373+
send_handle = self.vm.sys_send(service, handler, parameter, key, idempotency_key=idempotency_key, headers=headers)
371374
return ServerSendHandle(self, send_handle)
372375

373376
handle = self.vm.sys_call(service=service,
374377
handler=handler,
375378
parameter=parameter,
376379
key=key,
377-
idempotency_key=idempotency_key)
380+
idempotency_key=idempotency_key,
381+
headers=headers)
378382

379383
async def await_point(s: ServerInvocationContext, h, o: Serde[O]):
380384
"""Wait for this handle to be resolved, and deserialize the response."""
@@ -386,52 +390,75 @@ async def await_point(s: ServerInvocationContext, h, o: Serde[O]):
386390
def service_call(self,
387391
tpe: Callable[[Any, I], Awaitable[O]],
388392
arg: I,
389-
idempotency_key: str | None = None
393+
idempotency_key: str | None = None,
394+
headers: typing.List[typing.Tuple[str, str]] | None = None
390395
) -> Awaitable[O]:
391-
coro = self.do_call(tpe, arg, idempotency_key=idempotency_key)
396+
coro = self.do_call(tpe, arg, idempotency_key=idempotency_key, headers=headers)
392397
assert not isinstance(coro, SendHandle)
393398
return coro
394399

395-
def service_send(self, tpe: Callable[[Any, I], Awaitable[O]], arg: I, send_delay: timedelta | None = None, idempotency_key: str | None = None) -> SendHandle:
396-
send = self.do_call(tpe=tpe, parameter=arg, send_delay=send_delay, send=True, idempotency_key=idempotency_key)
400+
def service_send(self, tpe: Callable[[Any, I], Awaitable[O]], arg: I, send_delay: timedelta | None = None, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None) -> SendHandle:
401+
send = self.do_call(tpe=tpe, parameter=arg, send_delay=send_delay, send=True, idempotency_key=idempotency_key, headers=headers)
397402
assert isinstance(send, SendHandle)
398403
return send
399404

400405
def object_call(self,
401406
tpe: Callable[[Any, I],Awaitable[O]],
402407
key: str,
403408
arg: I,
404-
idempotency_key: str | None = None
409+
idempotency_key: str | None = None,
410+
headers: typing.List[typing.Tuple[str, str]] | None = None
405411
) -> Awaitable[O]:
406-
coro = self.do_call(tpe, arg, key, idempotency_key=idempotency_key)
412+
coro = self.do_call(tpe, arg, key, idempotency_key=idempotency_key, headers=headers)
407413
assert not isinstance(coro, SendHandle)
408414
return coro
409415

410-
def object_send(self, tpe: Callable[[Any, I], Awaitable[O]], key: str, arg: I, send_delay: timedelta | None = None, idempotency_key: str | None = None) -> SendHandle:
411-
send = self.do_call(tpe=tpe, key=key, parameter=arg, send_delay=send_delay, send=True, idempotency_key=idempotency_key)
416+
def object_send(self, tpe: Callable[[Any, I], Awaitable[O]], key: str, arg: I, send_delay: timedelta | None = None, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None) -> SendHandle:
417+
send = self.do_call(tpe=tpe, key=key, parameter=arg, send_delay=send_delay, send=True, idempotency_key=idempotency_key, headers=headers)
412418
assert isinstance(send, SendHandle)
413419
return send
414420

415421
def workflow_call(self,
416422
tpe: Callable[[Any, I], Awaitable[O]],
417423
key: str,
418424
arg: I,
419-
idempotency_key: str | None = None
425+
idempotency_key: str | None = None,
426+
headers: typing.List[typing.Tuple[str, str]] | None = None
420427
) -> Awaitable[O]:
421-
return self.object_call(tpe, key, arg, idempotency_key=idempotency_key)
428+
return self.object_call(tpe, key, arg, idempotency_key=idempotency_key, headers=headers)
422429

423-
def workflow_send(self, tpe: Callable[[Any, I], Awaitable[O]], key: str, arg: I, send_delay: timedelta | None = None, idempotency_key: str | None = None) -> SendHandle:
424-
send = self.object_send(tpe, key, arg, send_delay, idempotency_key=idempotency_key)
430+
def workflow_send(self, tpe: Callable[[Any, I], Awaitable[O]], key: str, arg: I, send_delay: timedelta | None = None, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None) -> SendHandle:
431+
send = self.object_send(tpe, key, arg, send_delay, idempotency_key=idempotency_key, headers=headers)
425432
assert isinstance(send, SendHandle)
426433
return send
427434

428-
def generic_call(self, service: str, handler: str, arg: bytes, key: str | None = None, idempotency_key: str | None = None) -> Awaitable[bytes]:
435+
def generic_call(self, service: str, handler: str, arg: bytes, key: str | None = None, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None) -> Awaitable[bytes]:
429436
serde = BytesSerde()
430-
return self.do_raw_call(service, handler, arg, serde, serde, key, idempotency_key) # type: ignore
431-
432-
def generic_send(self, service: str, handler: str, arg: bytes, key: str | None = None, send_delay: timedelta | None = None, idempotency_key: str | None = None) -> SendHandle:
437+
call_handle = self.do_raw_call(service=service,
438+
handler=handler,
439+
input_param=arg,
440+
input_serde=serde,
441+
output_serde=serde,
442+
key=key,
443+
idempotency_key=idempotency_key,
444+
headers=headers)
445+
assert not isinstance(call_handle, SendHandle)
446+
return call_handle
447+
448+
def generic_send(self, service: str, handler: str, arg: bytes, key: str | None = None, send_delay: timedelta | None = None, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None) -> SendHandle:
433449
serde = BytesSerde()
434-
return self.do_raw_call(service, handler, arg, serde, serde , key, send_delay, True, idempotency_key) # type: ignore
450+
send_handle = self.do_raw_call(service=service,
451+
handler=handler,
452+
input_param=arg,
453+
input_serde=serde,
454+
output_serde=serde,
455+
key=key,
456+
send_delay=send_delay,
457+
send=True,
458+
idempotency_key=idempotency_key,
459+
headers=headers)
460+
assert isinstance(send_handle, SendHandle)
461+
return send_handle
435462

436463
def awakeable(self,
437464
serde: typing.Optional[Serde[I]] = JsonSerde()) -> typing.Tuple[str, Awaitable[Any]]:

python/restate/vm.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from dataclasses import dataclass
1818
import typing
19-
from restate._internal import PyVM, PyFailure, PySuspended, PyVoid, PyStateKeys, PyExponentialRetryConfig, PyDoProgressAnyCompleted, PyDoProgressReadFromInput, PyDoProgressExecuteRun, PyDoProgressCancelSignalReceived, CANCEL_NOTIFICATION_HANDLE # pylint: disable=import-error,no-name-in-module,line-too-long
19+
from restate._internal import PyVM, PyHeader, PyFailure, PySuspended, PyVoid, PyStateKeys, PyExponentialRetryConfig, PyDoProgressAnyCompleted, PyDoProgressReadFromInput, PyDoProgressExecuteRun, PyDoProgressCancelSignalReceived, CANCEL_NOTIFICATION_HANDLE # pylint: disable=import-error,no-name-in-module,line-too-long
2020

2121
@dataclass
2222
class Invocation:
@@ -245,10 +245,13 @@ def sys_call(self,
245245
handler: str,
246246
parameter: bytes,
247247
key: typing.Optional[str] = None,
248-
idempotency_key: typing.Optional[str] = None
248+
idempotency_key: typing.Optional[str] = None,
249+
headers: typing.Optional[typing.List[typing.Tuple[str, str]]] = None
249250
):
250251
"""Call a service"""
251-
return self.vm.sys_call(service, handler, parameter, key, idempotency_key)
252+
if headers:
253+
headers = [PyHeader(key=h[0], value=h[1]) for h in headers]
254+
return self.vm.sys_call(service, handler, parameter, key, idempotency_key, headers)
252255

253256
# pylint: disable=too-many-arguments
254257
def sys_send(self,
@@ -257,13 +260,16 @@ def sys_send(self,
257260
parameter: bytes,
258261
key: typing.Optional[str] = None,
259262
delay: typing.Optional[int] = None,
260-
idempotency_key: typing.Optional[str] = None
263+
idempotency_key: typing.Optional[str] = None,
264+
headers: typing.Optional[typing.List[typing.Tuple[str, str]]] = None
261265
) -> int:
262266
"""
263267
send an invocation to a service, and return the handle
264268
to the promise that will resolve with the invocation id
265269
"""
266-
return self.vm.sys_send(service, handler, parameter, key, delay, idempotency_key)
270+
if headers:
271+
headers = [PyHeader(key=h[0], value=h[1]) for h in headers]
272+
return self.vm.sys_send(service, handler, parameter, key, delay, idempotency_key, headers)
267273

268274
def sys_run(self, name: str) -> int:
269275
"""

src/lib.rs

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ struct PyHeader {
2222
value: String,
2323
}
2424

25+
#[pymethods]
26+
impl PyHeader {
27+
#[new]
28+
fn new(key: String, value: String) -> PyHeader {
29+
Self { key, value }
30+
}
31+
}
32+
2533
impl From<Header> for PyHeader {
2634
fn from(h: Header) -> Self {
2735
PyHeader {
@@ -31,6 +39,15 @@ impl From<Header> for PyHeader {
3139
}
3240
}
3341

42+
impl From<PyHeader> for Header {
43+
fn from(h: PyHeader) -> Self {
44+
Header {
45+
key: h.key.into(),
46+
value: h.value.into(),
47+
}
48+
}
49+
}
50+
3451
#[pyclass]
3552
struct PyResponseHead {
3653
#[pyo3(get, set)]
@@ -433,14 +450,15 @@ impl PyVM {
433450
.map_err(Into::into)
434451
}
435452

436-
#[pyo3(signature = (service, handler, buffer, key=None, idempotency_key=None))]
453+
#[pyo3(signature = (service, handler, buffer, key=None, idempotency_key=None, headers=None))]
437454
fn sys_call(
438455
mut self_: PyRefMut<'_, Self>,
439456
service: String,
440457
handler: String,
441458
buffer: &Bound<'_, PyBytes>,
442459
key: Option<String>,
443460
idempotency_key: Option<String>,
461+
headers: Option<Vec<PyHeader>>,
444462
) -> Result<PyCallHandle, PyVMError> {
445463
self_
446464
.vm
@@ -450,15 +468,19 @@ impl PyVM {
450468
handler,
451469
key,
452470
idempotency_key,
453-
headers: vec![],
471+
headers: headers
472+
.unwrap_or_default()
473+
.into_iter()
474+
.map(Into::into)
475+
.collect(),
454476
},
455477
buffer.as_bytes().to_vec().into(),
456478
)
457479
.map(Into::into)
458480
.map_err(Into::into)
459481
}
460482

461-
#[pyo3(signature = (service, handler, buffer, key=None, delay=None, idempotency_key=None))]
483+
#[pyo3(signature = (service, handler, buffer, key=None, delay=None, idempotency_key=None, headers=None))]
462484
fn sys_send(
463485
mut self_: PyRefMut<'_, Self>,
464486
service: String,
@@ -467,6 +489,7 @@ impl PyVM {
467489
key: Option<String>,
468490
delay: Option<u64>,
469491
idempotency_key: Option<String>,
492+
headers: Option<Vec<PyHeader>>,
470493
) -> Result<PyNotificationHandle, PyVMError> {
471494
self_
472495
.vm
@@ -476,7 +499,11 @@ impl PyVM {
476499
handler,
477500
key,
478501
idempotency_key,
479-
headers: vec![],
502+
headers: headers
503+
.unwrap_or_default()
504+
.into_iter()
505+
.map(Into::into)
506+
.collect(),
480507
},
481508
buffer.as_bytes().to_vec().into(),
482509
delay.map(|millis| {

0 commit comments

Comments
 (0)