22
22
23
23
from ...base .handlers import APIHandler
24
24
from ...base .zmqhandlers import AuthenticatedZMQStreamHandler
25
- from ...base .zmqhandlers import deserialize_binary_message
25
+ from ...base .zmqhandlers import (
26
+ deserialize_binary_message ,
27
+ serialize_msg_to_ws_v1 ,
28
+ deserialize_msg_from_ws_v1 ,
29
+ )
26
30
from jupyter_server .utils import ensure_async
27
31
from jupyter_server .utils import url_escape
28
32
from jupyter_server .utils import url_path_join
@@ -105,6 +109,10 @@ def kernel_info_timeout(self):
105
109
km_default = self .kernel_manager .kernel_info_timeout
106
110
return self .settings .get ("kernel_info_timeout" , km_default )
107
111
112
+ @property
113
+ def limit_rate (self ):
114
+ return self .settings .get ("limit_rate" , True )
115
+
108
116
@property
109
117
def iopub_msg_rate_limit (self ):
110
118
return self .settings .get ("iopub_msg_rate_limit" , 0 )
@@ -452,64 +460,112 @@ def subscribe(value):
452
460
453
461
return connected
454
462
455
- def on_message (self , msg ):
463
+ def on_message (self , ws_msg ):
456
464
if not self .channels :
457
465
# already closed, ignore the message
458
- self .log .debug ("Received message on closed websocket %r" , msg )
466
+ self .log .debug ("Received message on closed websocket %r" , ws_msg )
459
467
return
460
- if isinstance (msg , bytes ):
461
- msg = deserialize_binary_message (msg )
468
+
469
+ if self .selected_subprotocol == "v1.kernel.websocket.jupyter.org" :
470
+ channel , msg_list = deserialize_msg_from_ws_v1 (ws_msg )
471
+ msg = {
472
+ "header" : None ,
473
+ }
462
474
else :
463
- msg = json .loads (msg )
464
- channel = msg .pop ("channel" , None )
475
+ if isinstance (ws_msg , bytes ):
476
+ msg = deserialize_binary_message (ws_msg )
477
+ else :
478
+ msg = json .loads (ws_msg )
479
+ msg_list = []
480
+ channel = msg .pop ("channel" , None )
481
+
465
482
if channel is None :
466
483
self .log .warning ("No channel specified, assuming shell: %s" , msg )
467
484
channel = "shell"
468
485
if channel not in self .channels :
469
486
self .log .warning ("No such channel: %r" , channel )
470
487
return
471
488
am = self .kernel_manager .allowed_message_types
472
- mt = msg ["header" ]["msg_type" ]
473
- if am and mt not in am :
474
- self .log .warning ('Received message of type "%s", which is not allowed. Ignoring.' % mt )
475
- else :
489
+ ignore_msg = False
490
+ if am :
491
+ msg ["header" ] = self .get_part ("header" , msg ["header" ], msg_list )
492
+ if msg ["header" ]["msg_type" ] not in am :
493
+ self .log .warning (
494
+ 'Received message of type "%s", which is not allowed. Ignoring.'
495
+ % msg ["header" ]["msg_type" ]
496
+ )
497
+ ignore_msg = True
498
+ if not ignore_msg :
476
499
stream = self .channels [channel ]
477
- self .session .send (stream , msg )
500
+ if self .selected_subprotocol == "v1.kernel.websocket.jupyter.org" :
501
+ self .session .send_raw (stream , msg_list )
502
+ else :
503
+ self .session .send (stream , msg )
504
+
505
+ def get_part (self , field , value , msg_list ):
506
+ if value is None :
507
+ field2idx = {
508
+ "header" : 0 ,
509
+ "parent_header" : 1 ,
510
+ "content" : 3 ,
511
+ }
512
+ value = self .session .unpack (msg_list [field2idx [field ]])
513
+ return value
478
514
479
515
def _on_zmq_reply (self , stream , msg_list ):
480
516
idents , fed_msg_list = self .session .feed_identities (msg_list )
481
- msg = self .session .deserialize (fed_msg_list )
482
517
483
- parent = msg ["parent_header" ]
484
-
485
- def write_stderr (error_message ):
486
- self .log .warning (error_message )
487
- msg = self .session .msg (
488
- "stream" , content = {"text" : error_message + "\n " , "name" : "stderr" }, parent = parent
489
- )
490
- msg ["channel" ] = "iopub"
491
- self .write_message (json .dumps (msg , default = json_default ))
518
+ if self .selected_subprotocol == "v1.kernel.websocket.jupyter.org" :
519
+ msg = {"header" : None , "parent_header" : None , "content" : None }
520
+ else :
521
+ msg = self .session .deserialize (fed_msg_list )
492
522
493
523
channel = getattr (stream , "channel" , None )
494
- msg_type = msg [ "header" ][ "msg_type" ]
524
+ parts = fed_msg_list [ 1 : ]
495
525
496
- if channel == "iopub" and msg_type == "error" :
497
- self ._on_error (msg )
526
+ self ._on_error (channel , msg , parts )
498
527
499
- if (
500
- channel == "iopub"
501
- and msg_type == "status"
502
- and msg ["content" ].get ("execution_state" ) == "idle"
503
- ):
504
- # reset rate limit counter on status=idle,
505
- # to avoid 'Run All' hitting limits prematurely.
506
- self ._iopub_window_byte_queue = []
507
- self ._iopub_window_msg_count = 0
508
- self ._iopub_window_byte_count = 0
509
- self ._iopub_msgs_exceeded = False
510
- self ._iopub_data_exceeded = False
511
-
512
- if channel == "iopub" and msg_type not in {"status" , "comm_open" , "execute_input" }:
528
+ if self ._limit_rate (channel , msg , parts ):
529
+ return
530
+
531
+ if self .selected_subprotocol == "v1.kernel.websocket.jupyter.org" :
532
+ super (ZMQChannelsHandler , self )._on_zmq_reply (stream , parts )
533
+ else :
534
+ super (ZMQChannelsHandler , self )._on_zmq_reply (stream , msg )
535
+
536
+ def write_stderr (self , error_message , parent_header ):
537
+ self .log .warning (error_message )
538
+ err_msg = self .session .msg (
539
+ "stream" ,
540
+ content = {"text" : error_message + "\n " , "name" : "stderr" },
541
+ parent = parent_header ,
542
+ )
543
+ if self .selected_subprotocol == "v1.kernel.websocket.jupyter.org" :
544
+ bin_msg = serialize_msg_to_ws_v1 (err_msg , "iopub" , self .session .pack )
545
+ self .write_message (bin_msg , binary = True )
546
+ else :
547
+ err_msg ["channel" ] = "iopub"
548
+ self .write_message (json .dumps (err_msg , default = json_default ))
549
+
550
+ def _limit_rate (self , channel , msg , msg_list ):
551
+ if not (self .limit_rate and channel == "iopub" ):
552
+ return False
553
+
554
+ msg ["header" ] = self .get_part ("header" , msg ["header" ], msg_list )
555
+
556
+ msg_type = msg ["header" ]["msg_type" ]
557
+ if msg_type == "status" :
558
+ msg ["content" ] = self .get_part ("content" , msg ["content" ], msg_list )
559
+ if msg ["content" ].get ("execution_state" ) == "idle" :
560
+ # reset rate limit counter on status=idle,
561
+ # to avoid 'Run All' hitting limits prematurely.
562
+ self ._iopub_window_byte_queue = []
563
+ self ._iopub_window_msg_count = 0
564
+ self ._iopub_window_byte_count = 0
565
+ self ._iopub_msgs_exceeded = False
566
+ self ._iopub_data_exceeded = False
567
+
568
+ if msg_type not in {"status" , "comm_open" , "execute_input" }:
513
569
514
570
# Remove the counts queued for removal.
515
571
now = IOLoop .current ().time ()
@@ -545,7 +601,10 @@ def write_stderr(error_message):
545
601
if self .iopub_msg_rate_limit > 0 and msg_rate > self .iopub_msg_rate_limit :
546
602
if not self ._iopub_msgs_exceeded :
547
603
self ._iopub_msgs_exceeded = True
548
- write_stderr (
604
+ msg ["parent_header" ] = self .get_part (
605
+ "parent_header" , msg ["parent_header" ], msg_list
606
+ )
607
+ self .write_stderr (
549
608
dedent (
550
609
"""\
551
610
IOPub message rate exceeded.
@@ -560,7 +619,8 @@ def write_stderr(error_message):
560
619
""" .format (
561
620
self .iopub_msg_rate_limit , self .rate_limit_window
562
621
)
563
- )
622
+ ),
623
+ msg ["parent_header" ],
564
624
)
565
625
else :
566
626
# resume once we've got some headroom below the limit
@@ -573,7 +633,10 @@ def write_stderr(error_message):
573
633
if self .iopub_data_rate_limit > 0 and data_rate > self .iopub_data_rate_limit :
574
634
if not self ._iopub_data_exceeded :
575
635
self ._iopub_data_exceeded = True
576
- write_stderr (
636
+ msg ["parent_header" ] = self .get_part (
637
+ "parent_header" , msg ["parent_header" ], msg_list
638
+ )
639
+ self .write_stderr (
577
640
dedent (
578
641
"""\
579
642
IOPub data rate exceeded.
@@ -588,7 +651,8 @@ def write_stderr(error_message):
588
651
""" .format (
589
652
self .iopub_data_rate_limit , self .rate_limit_window
590
653
)
591
- )
654
+ ),
655
+ msg ["parent_header" ],
592
656
)
593
657
else :
594
658
# resume once we've got some headroom below the limit
@@ -603,8 +667,9 @@ def write_stderr(error_message):
603
667
self ._iopub_window_msg_count -= 1
604
668
self ._iopub_window_byte_count -= byte_count
605
669
self ._iopub_window_byte_queue .pop (- 1 )
606
- return
607
- super (ZMQChannelsHandler , self )._on_zmq_reply (stream , msg )
670
+ return True
671
+
672
+ return False
608
673
609
674
def close (self ):
610
675
super (ZMQChannelsHandler , self ).close ()
@@ -654,8 +719,12 @@ def _send_status_message(self, status):
654
719
# that all messages from the stopped kernel have been delivered
655
720
iopub .flush ()
656
721
msg = self .session .msg ("status" , {"execution_state" : status })
657
- msg ["channel" ] = "iopub"
658
- self .write_message (json .dumps (msg , default = json_default ))
722
+ if self .selected_subprotocol == "v1.kernel.websocket.jupyter.org" :
723
+ bin_msg = serialize_msg_to_ws_v1 (msg , "iopub" , self .session .pack )
724
+ self .write_message (bin_msg , binary = True )
725
+ else :
726
+ msg ["channel" ] = "iopub"
727
+ self .write_message (json .dumps (msg , default = json_default ))
659
728
660
729
def on_kernel_restarted (self ):
661
730
self .log .warning ("kernel %s restarted" , self .kernel_id )
@@ -665,12 +734,19 @@ def on_restart_failed(self):
665
734
self .log .error ("kernel %s restarted failed!" , self .kernel_id )
666
735
self ._send_status_message ("dead" )
667
736
668
- def _on_error (self , msg ):
737
+ def _on_error (self , channel , msg , msg_list ):
669
738
if self .kernel_manager .allow_tracebacks :
670
739
return
671
- msg ["content" ]["ename" ] = "ExecutionError"
672
- msg ["content" ]["evalue" ] = "Execution error"
673
- msg ["content" ]["traceback" ] = [self .kernel_manager .traceback_replacement_message ]
740
+
741
+ if channel == "iopub" :
742
+ msg ["header" ] = self .get_part ("header" , msg ["header" ], msg_list )
743
+ if msg ["header" ]["msg_type" ] == "error" :
744
+ msg ["content" ] = self .get_part ("content" , msg ["content" ], msg_list )
745
+ msg ["content" ]["ename" ] = "ExecutionError"
746
+ msg ["content" ]["evalue" ] = "Execution error"
747
+ msg ["content" ]["traceback" ] = [self .kernel_manager .traceback_replacement_message ]
748
+ if self .selected_subprotocol == "v1.kernel.websocket.jupyter.org" :
749
+ msg_list [3 ] = self .session .pack (msg ["content" ])
674
750
675
751
676
752
# -----------------------------------------------------------------------------
0 commit comments