Skip to content

Commit 10d2cc6

Browse files
authored
Merge pull request #540 from nats-io/fetch-hb
js: add fetch heartbeat option
2 parents 9e24c40 + 2bb6e8d commit 10d2cc6

File tree

3 files changed

+177
-13
lines changed

3 files changed

+177
-13
lines changed

nats/js/client.py

Lines changed: 70 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from nats.aio.msg import Msg
2626
from nats.aio.subscription import Subscription
2727
from nats.js import api
28-
from nats.js.errors import BadBucketError, BucketNotFoundError, InvalidBucketNameError, NotFoundError
28+
from nats.js.errors import BadBucketError, BucketNotFoundError, InvalidBucketNameError, NotFoundError, FetchTimeoutError
2929
from nats.js.kv import KeyValue
3030
from nats.js.manager import JetStreamManager
3131
from nats.js.object_store import (
@@ -547,6 +547,13 @@ def _is_temporary_error(cls, status: Optional[str]) -> bool:
547547
else:
548548
return False
549549

550+
@classmethod
551+
def _is_heartbeat(cls, status: Optional[str]) -> bool:
552+
if status == api.StatusCode.CONTROL_MESSAGE:
553+
return True
554+
else:
555+
return False
556+
550557
@classmethod
551558
def _time_until(cls, timeout: Optional[float],
552559
start_time: float) -> Optional[float]:
@@ -620,9 +627,7 @@ async def activity_check(self):
620627
self._active = False
621628
if not active:
622629
if self._ordered:
623-
await self.reset_ordered_consumer(
624-
self._sseq + 1
625-
)
630+
await self.reset_ordered_consumer(self._sseq + 1)
626631
except asyncio.CancelledError:
627632
break
628633

@@ -882,14 +887,18 @@ async def consumer_info(self) -> api.ConsumerInfo:
882887
)
883888
return info
884889

885-
async def fetch(self,
886-
batch: int = 1,
887-
timeout: Optional[float] = 5) -> List[Msg]:
890+
async def fetch(
891+
self,
892+
batch: int = 1,
893+
timeout: Optional[float] = 5,
894+
heartbeat: Optional[float] = None
895+
) -> List[Msg]:
888896
"""
889897
fetch makes a request to JetStream to be delivered a set of messages.
890898
891899
:param batch: Number of messages to fetch from server.
892900
:param timeout: Max duration of the fetch request before it expires.
901+
:param heartbeat: Idle Heartbeat interval in seconds for the fetch request.
893902
894903
::
895904
@@ -925,15 +934,16 @@ async def main():
925934
timeout * 1_000_000_000
926935
) - 100_000 if timeout else None
927936
if batch == 1:
928-
msg = await self._fetch_one(expires, timeout)
937+
msg = await self._fetch_one(expires, timeout, heartbeat)
929938
return [msg]
930-
msgs = await self._fetch_n(batch, expires, timeout)
939+
msgs = await self._fetch_n(batch, expires, timeout, heartbeat)
931940
return msgs
932941

933942
async def _fetch_one(
934943
self,
935944
expires: Optional[int],
936945
timeout: Optional[float],
946+
heartbeat: Optional[float] = None
937947
) -> Msg:
938948
queue = self._sub._pending_queue
939949

@@ -957,6 +967,10 @@ async def _fetch_one(
957967
next_req['batch'] = 1
958968
if expires:
959969
next_req['expires'] = int(expires)
970+
if heartbeat:
971+
next_req['idle_heartbeat'] = int(
972+
heartbeat * 1_000_000_000
973+
) # to nanoseconds
960974

961975
await self._nc.publish(
962976
self._nms,
@@ -965,6 +979,7 @@ async def _fetch_one(
965979
)
966980

967981
start_time = time.monotonic()
982+
got_any_response = False
968983
while True:
969984
try:
970985
deadline = JetStreamContext._time_until(
@@ -976,6 +991,10 @@ async def _fetch_one(
976991
# Should have received at least a processable message at this point,
977992
status = JetStreamContext.is_status_msg(msg)
978993
if status:
994+
if JetStreamContext._is_heartbeat(status):
995+
got_any_response = True
996+
continue
997+
979998
# In case of a temporary error, treat it as a timeout to retry.
980999
if JetStreamContext._is_temporary_error(status):
9811000
raise nats.errors.TimeoutError
@@ -993,17 +1012,21 @@ async def _fetch_one(
9931012
# due to a reconnect while the fetch request,
9941013
# the JS API not responding on time, or maybe
9951014
# there were no messages yet.
1015+
if got_any_response:
1016+
raise FetchTimeoutError
9961017
raise
9971018

9981019
async def _fetch_n(
9991020
self,
10001021
batch: int,
10011022
expires: Optional[int],
10021023
timeout: Optional[float],
1024+
heartbeat: Optional[float] = None
10031025
) -> List[Msg]:
10041026
msgs = []
10051027
queue = self._sub._pending_queue
10061028
start_time = time.monotonic()
1029+
got_any_response = False
10071030
needed = batch
10081031

10091032
# Fetch as many as needed from the internal pending queue.
@@ -1029,6 +1052,10 @@ async def _fetch_n(
10291052
next_req['batch'] = needed
10301053
if expires:
10311054
next_req['expires'] = expires
1055+
if heartbeat:
1056+
next_req['idle_heartbeat'] = int(
1057+
heartbeat * 1_000_000_000
1058+
) # to nanoseconds
10321059
next_req['no_wait'] = True
10331060
await self._nc.publish(
10341061
self._nms,
@@ -1040,12 +1067,20 @@ async def _fetch_n(
10401067
try:
10411068
msg = await self._sub.next_msg(timeout)
10421069
except asyncio.TimeoutError:
1070+
# Return any message that was already available in the internal queue.
10431071
if msgs:
10441072
return msgs
10451073
raise
10461074

1075+
got_any_response = False
1076+
10471077
status = JetStreamContext.is_status_msg(msg)
1048-
if JetStreamContext._is_processable_msg(status, msg):
1078+
if JetStreamContext._is_heartbeat(status):
1079+
# Mark that we got any response from the server so this is not
1080+
# a possible i/o timeout error or due to a disconnection.
1081+
got_any_response = True
1082+
pass
1083+
elif JetStreamContext._is_processable_msg(status, msg):
10491084
# First processable message received, do not raise error from now.
10501085
msgs.append(msg)
10511086
needed -= 1
@@ -1061,6 +1096,10 @@ async def _fetch_n(
10611096
# No more messages after this so fallthrough
10621097
# after receiving the rest.
10631098
break
1099+
elif JetStreamContext._is_heartbeat(status):
1100+
# Skip heartbeats.
1101+
got_any_response = True
1102+
continue
10641103
elif JetStreamContext._is_processable_msg(status, msg):
10651104
needed -= 1
10661105
msgs.append(msg)
@@ -1079,6 +1118,11 @@ async def _fetch_n(
10791118
next_req['batch'] = needed
10801119
if expires:
10811120
next_req['expires'] = expires
1121+
if heartbeat:
1122+
next_req['idle_heartbeat'] = int(
1123+
heartbeat * 1_000_000_000
1124+
) # to nanoseconds
1125+
10821126
await self._nc.publish(
10831127
self._nms,
10841128
json.dumps(next_req).encode(),
@@ -1099,7 +1143,12 @@ async def _fetch_n(
10991143
if len(msgs) == 0:
11001144
# Not a single processable message has been received so far,
11011145
# if this timed out then let the error be raised.
1102-
msg = await self._sub.next_msg(timeout=deadline)
1146+
try:
1147+
msg = await self._sub.next_msg(timeout=deadline)
1148+
except asyncio.TimeoutError:
1149+
if got_any_response:
1150+
raise FetchTimeoutError
1151+
raise
11031152
else:
11041153
try:
11051154
msg = await self._sub.next_msg(timeout=deadline)
@@ -1109,6 +1158,10 @@ async def _fetch_n(
11091158

11101159
if msg:
11111160
status = JetStreamContext.is_status_msg(msg)
1161+
if JetStreamContext._is_heartbeat(status):
1162+
got_any_response = True
1163+
continue
1164+
11121165
if not status:
11131166
needed -= 1
11141167
msgs.append(msg)
@@ -1132,6 +1185,9 @@ async def _fetch_n(
11321185

11331186
msg = await self._sub.next_msg(timeout=deadline)
11341187
status = JetStreamContext.is_status_msg(msg)
1188+
if JetStreamContext._is_heartbeat(status):
1189+
got_any_response = True
1190+
continue
11351191
if JetStreamContext._is_processable_msg(status, msg):
11361192
needed -= 1
11371193
msgs.append(msg)
@@ -1140,6 +1196,9 @@ async def _fetch_n(
11401196
# at least one message has already arrived.
11411197
pass
11421198

1199+
if len(msgs) == 0 and got_any_response:
1200+
raise FetchTimeoutError
1201+
11431202
return msgs
11441203

11451204
######################

nats/js/errors.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2016-2022 The NATS Authors
1+
# Copyright 2016-2024 The NATS Authors
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -133,6 +133,15 @@ def __str__(self) -> str:
133133
return "nats: no response from stream"
134134

135135

136+
class FetchTimeoutError(nats.errors.TimeoutError):
137+
"""
138+
Raised if the consumer timed out waiting for messages.
139+
"""
140+
141+
def __str__(self) -> str:
142+
return "nats: fetch timeout"
143+
144+
136145
class ConsumerSequenceMismatchError(Error):
137146
"""
138147
Async error raised by the client with idle_heartbeat mode enabled

tests/test_js.py

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -842,7 +842,7 @@ async def test_consumer_with_multiple_filters(self):
842842
ok = await msgs[0].ack_sync()
843843
assert ok
844844

845-
@async_debug_test
845+
@async_long_test
846846
async def test_add_consumer_with_backoff(self):
847847
nc = NATS()
848848
await nc.connect()
@@ -901,6 +901,102 @@ async def cb(msg):
901901
assert info.config.backoff == [1, 2]
902902
await nc.close()
903903

904+
@async_long_test
905+
async def test_fetch_heartbeats(self):
906+
nc = NATS()
907+
await nc.connect()
908+
909+
js = nc.jetstream()
910+
911+
await js.add_stream(name="events", subjects=["events.>"])
912+
await js.add_consumer(
913+
"events",
914+
durable_name="a",
915+
max_deliver=2,
916+
max_waiting=5,
917+
ack_wait=30,
918+
max_ack_pending=5,
919+
filter_subject="events.>",
920+
)
921+
sub = await js.pull_subscribe_bind("a", stream="events")
922+
923+
with pytest.raises(nats.js.errors.FetchTimeoutError):
924+
await sub.fetch(1, timeout=1, heartbeat=0.1)
925+
926+
with pytest.raises(asyncio.TimeoutError):
927+
await sub.fetch(1, timeout=1, heartbeat=0.1)
928+
929+
with pytest.raises(nats.errors.TimeoutError):
930+
await sub.fetch(1, timeout=1, heartbeat=0.1)
931+
932+
for i in range(0, 15):
933+
await js.publish("events.%d" % i, b'i:%d' % i)
934+
935+
# Fetch(n)
936+
msgs = await sub.fetch(5, timeout=5, heartbeat=0.1)
937+
assert len(msgs) == 5
938+
for msg in msgs:
939+
await msg.ack_sync()
940+
info = await js.consumer_info("events", "a")
941+
assert info.num_pending == 10
942+
943+
# Fetch(1)
944+
msgs = await sub.fetch(1, timeout=1, heartbeat=0.1)
945+
assert len(msgs) == 1
946+
for msg in msgs:
947+
await msg.ack_sync()
948+
949+
# Receive some messages.
950+
msgs = await sub.fetch(20, timeout=2, heartbeat=0.1)
951+
for msg in msgs:
952+
await msg.ack_sync()
953+
msgs = await sub.fetch(4, timeout=2, heartbeat=0.1)
954+
for msg in msgs:
955+
await msg.ack_sync()
956+
957+
# Check that messages were removed from being pending.
958+
info = await js.consumer_info("events", "a")
959+
assert info.num_pending == 0
960+
961+
# Ask for more messages but there aren't any.
962+
with pytest.raises(nats.js.errors.FetchTimeoutError):
963+
await sub.fetch(4, timeout=1, heartbeat=0.1)
964+
965+
with pytest.raises(asyncio.TimeoutError):
966+
msgs = await sub.fetch(4, timeout=1, heartbeat=0.1)
967+
968+
with pytest.raises(nats.errors.TimeoutError):
969+
msgs = await sub.fetch(4, timeout=1, heartbeat=0.1)
970+
971+
with pytest.raises(nats.js.errors.APIError) as err:
972+
await sub.fetch(1, timeout=1, heartbeat=0.5)
973+
assert err.value.description == 'Bad Request - heartbeat value too large'
974+
975+
# Example of catching fetch timeout instead first.
976+
got_fetch_timeout = False
977+
got_io_timeout = False
978+
try:
979+
await sub.fetch(1, timeout=1, heartbeat=0.2)
980+
except nats.js.errors.FetchTimeoutError:
981+
got_fetch_timeout = True
982+
except nats.errors.TimeoutError:
983+
got_io_timeout = True
984+
assert got_fetch_timeout == True
985+
assert got_io_timeout == False
986+
987+
got_fetch_timeout = False
988+
got_io_timeout = False
989+
try:
990+
await sub.fetch(1, timeout=1, heartbeat=0.2)
991+
except nats.errors.TimeoutError:
992+
got_io_timeout = True
993+
except nats.js.errors.FetchTimeoutError:
994+
got_fetch_timeout = True
995+
assert got_fetch_timeout == False
996+
assert got_io_timeout == True
997+
998+
await nc.close()
999+
9041000

9051001
class JSMTest(SingleJetStreamServerTestCase):
9061002

0 commit comments

Comments
 (0)