Skip to content

Commit cbc54fa

Browse files
Protocol alignment (#657)
1 parent 3a5f4b1 commit cbc54fa

File tree

3 files changed

+204
-56
lines changed

3 files changed

+204
-56
lines changed

jupyter_server/base/zmqhandlers.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,38 @@ def deserialize_binary_message(bmsg):
8282
return msg
8383

8484

85+
def serialize_msg_to_ws_v1(msg_or_list, channel, pack=None):
86+
if pack:
87+
msg_list = [
88+
pack(msg_or_list["header"]),
89+
pack(msg_or_list["parent_header"]),
90+
pack(msg_or_list["metadata"]),
91+
pack(msg_or_list["content"]),
92+
]
93+
else:
94+
msg_list = msg_or_list
95+
channel = channel.encode("utf-8")
96+
offsets = []
97+
offsets.append(8 * (1 + 1 + len(msg_list) + 1))
98+
offsets.append(len(channel) + offsets[-1])
99+
for msg in msg_list:
100+
offsets.append(len(msg) + offsets[-1])
101+
offset_number = len(offsets).to_bytes(8, byteorder="little")
102+
offsets = [offset.to_bytes(8, byteorder="little") for offset in offsets]
103+
bin_msg = b"".join([offset_number] + offsets + [channel] + msg_list)
104+
return bin_msg
105+
106+
107+
def deserialize_msg_from_ws_v1(ws_msg):
108+
offset_number = int.from_bytes(ws_msg[:8], "little")
109+
offsets = [
110+
int.from_bytes(ws_msg[8 * (i + 1) : 8 * (i + 2)], "little") for i in range(offset_number)
111+
]
112+
channel = ws_msg[offsets[0] : offsets[1]].decode("utf-8")
113+
msg_list = [ws_msg[offsets[i] : offsets[i + 1]] for i in range(1, offset_number - 1)]
114+
return channel, msg_list
115+
116+
85117
# ping interval for keeping websockets alive (30 seconds)
86118
WS_PING_INTERVAL = 30000
87119

@@ -239,6 +271,16 @@ def _reserialize_reply(self, msg_or_list, channel=None):
239271
smsg = json.dumps(msg, default=json_default)
240272
return cast_unicode(smsg)
241273

274+
def select_subprotocol(self, subprotocols):
275+
preferred_protocol = self.settings.get("kernel_ws_protocol")
276+
if preferred_protocol is None:
277+
preferred_protocol = "v1.kernel.websocket.jupyter.org"
278+
elif preferred_protocol == "":
279+
preferred_protocol = None
280+
selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None
281+
# None is the default, "legacy" protocol
282+
return selected_subprotocol
283+
242284
def _on_zmq_reply(self, stream, msg_list):
243285
# Sometimes this gets triggered when the on_close method is scheduled in the
244286
# eventloop but hasn't been called.
@@ -247,12 +289,16 @@ def _on_zmq_reply(self, stream, msg_list):
247289
self.close()
248290
return
249291
channel = getattr(stream, "channel", None)
250-
try:
251-
msg = self._reserialize_reply(msg_list, channel=channel)
252-
except Exception:
253-
self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
292+
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
293+
bin_msg = serialize_msg_to_ws_v1(msg_list, channel)
294+
self.write_message(bin_msg, binary=True)
254295
else:
255-
self.write_message(msg, binary=isinstance(msg, bytes))
296+
try:
297+
msg = self._reserialize_reply(msg_list, channel=channel)
298+
except Exception:
299+
self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
300+
else:
301+
self.write_message(msg, binary=isinstance(msg, bytes))
256302

257303

258304
class AuthenticatedZMQStreamHandler(ZMQStreamHandler, JupyterHandler):

jupyter_server/serverapp.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,10 @@ def init_settings(
314314
"no_cache_paths": [url_path_join(base_url, "static", "custom")],
315315
},
316316
version_hash=version_hash,
317+
# kernel message protocol over websoclet
318+
kernel_ws_protocol=jupyter_app.kernel_ws_protocol,
317319
# rate limits
320+
limit_rate=jupyter_app.limit_rate,
318321
iopub_msg_rate_limit=jupyter_app.iopub_msg_rate_limit,
319322
iopub_data_rate_limit=jupyter_app.iopub_data_rate_limit,
320323
rate_limit_window=jupyter_app.rate_limit_window,
@@ -1612,6 +1615,29 @@ def _update_server_extensions(self, change):
16121615
help=_i18n("Reraise exceptions encountered loading server extensions?"),
16131616
)
16141617

1618+
kernel_ws_protocol = Unicode(
1619+
None,
1620+
allow_none=True,
1621+
config=True,
1622+
help=_i18n(
1623+
"Preferred kernel message protocol over websocket to use (default: None). "
1624+
"If an empty string is passed, select the legacy protocol. If None, "
1625+
"the selected protocol will depend on what the front-end supports "
1626+
"(usually the most recent protocol supported by the back-end and the "
1627+
"front-end)."
1628+
),
1629+
)
1630+
1631+
limit_rate = Bool(
1632+
True,
1633+
config=True,
1634+
help=_i18n(
1635+
"Whether to limit the rate of IOPub messages (default: True). "
1636+
"If True, use iopub_msg_rate_limit, iopub_data_rate_limit and/or rate_limit_window "
1637+
"to tune the rate."
1638+
),
1639+
)
1640+
16151641
iopub_msg_rate_limit = Float(
16161642
1000,
16171643
config=True,

jupyter_server/services/kernels/handlers.py

Lines changed: 127 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@
2222

2323
from ...base.handlers import APIHandler
2424
from ...base.zmqhandlers import AuthenticatedZMQStreamHandler
25-
from ...base.zmqhandlers import deserialize_binary_message
25+
from ...base.zmqhandlers import (
26+
deserialize_binary_message,
27+
serialize_msg_to_ws_v1,
28+
deserialize_msg_from_ws_v1,
29+
)
2630
from jupyter_server.utils import ensure_async
2731
from jupyter_server.utils import url_escape
2832
from jupyter_server.utils import url_path_join
@@ -105,6 +109,10 @@ def kernel_info_timeout(self):
105109
km_default = self.kernel_manager.kernel_info_timeout
106110
return self.settings.get("kernel_info_timeout", km_default)
107111

112+
@property
113+
def limit_rate(self):
114+
return self.settings.get("limit_rate", True)
115+
108116
@property
109117
def iopub_msg_rate_limit(self):
110118
return self.settings.get("iopub_msg_rate_limit", 0)
@@ -452,64 +460,112 @@ def subscribe(value):
452460

453461
return connected
454462

455-
def on_message(self, msg):
463+
def on_message(self, ws_msg):
456464
if not self.channels:
457465
# already closed, ignore the message
458-
self.log.debug("Received message on closed websocket %r", msg)
466+
self.log.debug("Received message on closed websocket %r", ws_msg)
459467
return
460-
if isinstance(msg, bytes):
461-
msg = deserialize_binary_message(msg)
468+
469+
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
470+
channel, msg_list = deserialize_msg_from_ws_v1(ws_msg)
471+
msg = {
472+
"header": None,
473+
}
462474
else:
463-
msg = json.loads(msg)
464-
channel = msg.pop("channel", None)
475+
if isinstance(ws_msg, bytes):
476+
msg = deserialize_binary_message(ws_msg)
477+
else:
478+
msg = json.loads(ws_msg)
479+
msg_list = []
480+
channel = msg.pop("channel", None)
481+
465482
if channel is None:
466483
self.log.warning("No channel specified, assuming shell: %s", msg)
467484
channel = "shell"
468485
if channel not in self.channels:
469486
self.log.warning("No such channel: %r", channel)
470487
return
471488
am = self.kernel_manager.allowed_message_types
472-
mt = msg["header"]["msg_type"]
473-
if am and mt not in am:
474-
self.log.warning('Received message of type "%s", which is not allowed. Ignoring.' % mt)
475-
else:
489+
ignore_msg = False
490+
if am:
491+
msg["header"] = self.get_part("header", msg["header"], msg_list)
492+
if msg["header"]["msg_type"] not in am:
493+
self.log.warning(
494+
'Received message of type "%s", which is not allowed. Ignoring.'
495+
% msg["header"]["msg_type"]
496+
)
497+
ignore_msg = True
498+
if not ignore_msg:
476499
stream = self.channels[channel]
477-
self.session.send(stream, msg)
500+
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
501+
self.session.send_raw(stream, msg_list)
502+
else:
503+
self.session.send(stream, msg)
504+
505+
def get_part(self, field, value, msg_list):
506+
if value is None:
507+
field2idx = {
508+
"header": 0,
509+
"parent_header": 1,
510+
"content": 3,
511+
}
512+
value = self.session.unpack(msg_list[field2idx[field]])
513+
return value
478514

479515
def _on_zmq_reply(self, stream, msg_list):
480516
idents, fed_msg_list = self.session.feed_identities(msg_list)
481-
msg = self.session.deserialize(fed_msg_list)
482517

483-
parent = msg["parent_header"]
484-
485-
def write_stderr(error_message):
486-
self.log.warning(error_message)
487-
msg = self.session.msg(
488-
"stream", content={"text": error_message + "\n", "name": "stderr"}, parent=parent
489-
)
490-
msg["channel"] = "iopub"
491-
self.write_message(json.dumps(msg, default=json_default))
518+
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
519+
msg = {"header": None, "parent_header": None, "content": None}
520+
else:
521+
msg = self.session.deserialize(fed_msg_list)
492522

493523
channel = getattr(stream, "channel", None)
494-
msg_type = msg["header"]["msg_type"]
524+
parts = fed_msg_list[1:]
495525

496-
if channel == "iopub" and msg_type == "error":
497-
self._on_error(msg)
526+
self._on_error(channel, msg, parts)
498527

499-
if (
500-
channel == "iopub"
501-
and msg_type == "status"
502-
and msg["content"].get("execution_state") == "idle"
503-
):
504-
# reset rate limit counter on status=idle,
505-
# to avoid 'Run All' hitting limits prematurely.
506-
self._iopub_window_byte_queue = []
507-
self._iopub_window_msg_count = 0
508-
self._iopub_window_byte_count = 0
509-
self._iopub_msgs_exceeded = False
510-
self._iopub_data_exceeded = False
511-
512-
if channel == "iopub" and msg_type not in {"status", "comm_open", "execute_input"}:
528+
if self._limit_rate(channel, msg, parts):
529+
return
530+
531+
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
532+
super(ZMQChannelsHandler, self)._on_zmq_reply(stream, parts)
533+
else:
534+
super(ZMQChannelsHandler, self)._on_zmq_reply(stream, msg)
535+
536+
def write_stderr(self, error_message, parent_header):
537+
self.log.warning(error_message)
538+
err_msg = self.session.msg(
539+
"stream",
540+
content={"text": error_message + "\n", "name": "stderr"},
541+
parent=parent_header,
542+
)
543+
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
544+
bin_msg = serialize_msg_to_ws_v1(err_msg, "iopub", self.session.pack)
545+
self.write_message(bin_msg, binary=True)
546+
else:
547+
err_msg["channel"] = "iopub"
548+
self.write_message(json.dumps(err_msg, default=json_default))
549+
550+
def _limit_rate(self, channel, msg, msg_list):
551+
if not (self.limit_rate and channel == "iopub"):
552+
return False
553+
554+
msg["header"] = self.get_part("header", msg["header"], msg_list)
555+
556+
msg_type = msg["header"]["msg_type"]
557+
if msg_type == "status":
558+
msg["content"] = self.get_part("content", msg["content"], msg_list)
559+
if msg["content"].get("execution_state") == "idle":
560+
# reset rate limit counter on status=idle,
561+
# to avoid 'Run All' hitting limits prematurely.
562+
self._iopub_window_byte_queue = []
563+
self._iopub_window_msg_count = 0
564+
self._iopub_window_byte_count = 0
565+
self._iopub_msgs_exceeded = False
566+
self._iopub_data_exceeded = False
567+
568+
if msg_type not in {"status", "comm_open", "execute_input"}:
513569

514570
# Remove the counts queued for removal.
515571
now = IOLoop.current().time()
@@ -545,7 +601,10 @@ def write_stderr(error_message):
545601
if self.iopub_msg_rate_limit > 0 and msg_rate > self.iopub_msg_rate_limit:
546602
if not self._iopub_msgs_exceeded:
547603
self._iopub_msgs_exceeded = True
548-
write_stderr(
604+
msg["parent_header"] = self.get_part(
605+
"parent_header", msg["parent_header"], msg_list
606+
)
607+
self.write_stderr(
549608
dedent(
550609
"""\
551610
IOPub message rate exceeded.
@@ -560,7 +619,8 @@ def write_stderr(error_message):
560619
""".format(
561620
self.iopub_msg_rate_limit, self.rate_limit_window
562621
)
563-
)
622+
),
623+
msg["parent_header"],
564624
)
565625
else:
566626
# resume once we've got some headroom below the limit
@@ -573,7 +633,10 @@ def write_stderr(error_message):
573633
if self.iopub_data_rate_limit > 0 and data_rate > self.iopub_data_rate_limit:
574634
if not self._iopub_data_exceeded:
575635
self._iopub_data_exceeded = True
576-
write_stderr(
636+
msg["parent_header"] = self.get_part(
637+
"parent_header", msg["parent_header"], msg_list
638+
)
639+
self.write_stderr(
577640
dedent(
578641
"""\
579642
IOPub data rate exceeded.
@@ -588,7 +651,8 @@ def write_stderr(error_message):
588651
""".format(
589652
self.iopub_data_rate_limit, self.rate_limit_window
590653
)
591-
)
654+
),
655+
msg["parent_header"],
592656
)
593657
else:
594658
# resume once we've got some headroom below the limit
@@ -603,8 +667,9 @@ def write_stderr(error_message):
603667
self._iopub_window_msg_count -= 1
604668
self._iopub_window_byte_count -= byte_count
605669
self._iopub_window_byte_queue.pop(-1)
606-
return
607-
super(ZMQChannelsHandler, self)._on_zmq_reply(stream, msg)
670+
return True
671+
672+
return False
608673

609674
def close(self):
610675
super(ZMQChannelsHandler, self).close()
@@ -654,8 +719,12 @@ def _send_status_message(self, status):
654719
# that all messages from the stopped kernel have been delivered
655720
iopub.flush()
656721
msg = self.session.msg("status", {"execution_state": status})
657-
msg["channel"] = "iopub"
658-
self.write_message(json.dumps(msg, default=json_default))
722+
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
723+
bin_msg = serialize_msg_to_ws_v1(msg, "iopub", self.session.pack)
724+
self.write_message(bin_msg, binary=True)
725+
else:
726+
msg["channel"] = "iopub"
727+
self.write_message(json.dumps(msg, default=json_default))
659728

660729
def on_kernel_restarted(self):
661730
self.log.warning("kernel %s restarted", self.kernel_id)
@@ -665,12 +734,19 @@ def on_restart_failed(self):
665734
self.log.error("kernel %s restarted failed!", self.kernel_id)
666735
self._send_status_message("dead")
667736

668-
def _on_error(self, msg):
737+
def _on_error(self, channel, msg, msg_list):
669738
if self.kernel_manager.allow_tracebacks:
670739
return
671-
msg["content"]["ename"] = "ExecutionError"
672-
msg["content"]["evalue"] = "Execution error"
673-
msg["content"]["traceback"] = [self.kernel_manager.traceback_replacement_message]
740+
741+
if channel == "iopub":
742+
msg["header"] = self.get_part("header", msg["header"], msg_list)
743+
if msg["header"]["msg_type"] == "error":
744+
msg["content"] = self.get_part("content", msg["content"], msg_list)
745+
msg["content"]["ename"] = "ExecutionError"
746+
msg["content"]["evalue"] = "Execution error"
747+
msg["content"]["traceback"] = [self.kernel_manager.traceback_replacement_message]
748+
if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
749+
msg_list[3] = self.session.pack(msg["content"])
674750

675751

676752
# -----------------------------------------------------------------------------

0 commit comments

Comments
 (0)