@@ -53,6 +53,9 @@ struct socket_t {
     }
 };
 
+// macro for nicer error messages on server crash
+#define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response")
+
 // all RPC structures must be packed
 #pragma pack(push, 1)
 // ggml_tensor is serialized into rpc_tensor
@@ -425,7 +428,7 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
 static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
     rpc_msg_hello_rsp response;
     bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
         fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
         return false;
@@ -481,7 +484,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
     rpc_msg_free_buffer_req request = {ctx->remote_ptr};
     bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     delete ctx;
 }
 
@@ -493,7 +496,7 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
     rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
     rpc_msg_buffer_get_base_rsp response;
     bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
     return ctx->base_ptr;
 }
@@ -545,7 +548,7 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
         request.tensor = serialize_tensor(tensor);
 
         bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
-        GGML_ASSERT(status);
+        RPC_STATUS_ASSERT(status);
     }
     return GGML_STATUS_SUCCESS;
 }
@@ -560,7 +563,7 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
         request.hash = fnv_hash((const uint8_t *)data, size);
         rpc_msg_set_tensor_hash_rsp response;
         bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
-        GGML_ASSERT(status);
+        RPC_STATUS_ASSERT(status);
         if (response.result) {
             // the server has the same data, no need to send it
             return;
@@ -573,7 +576,7 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
     memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
     memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
     bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
 }
 
 static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@@ -583,7 +586,7 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
     request.offset = offset;
     request.size = size;
     bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
 }
 
 static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
@@ -601,15 +604,15 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
     request.dst = serialize_tensor(dst);
     rpc_msg_copy_tensor_rsp response;
     bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     return response.result;
 }
 
 static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
     rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
     bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
 }
 
 static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
@@ -635,7 +638,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
     rpc_msg_alloc_buffer_rsp response;
     auto sock = get_socket(buft_ctx->endpoint);
     bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     if (response.remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
@@ -650,7 +653,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
 static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
     rpc_msg_get_alignment_rsp response;
     bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     return response.alignment;
 }
 
@@ -662,7 +665,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
 static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
     rpc_msg_get_max_size_rsp response;
     bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     return response.max_size;
 }
 
@@ -683,7 +686,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty
 
         rpc_msg_get_alloc_size_rsp response;
         bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
-        GGML_ASSERT(status);
+        RPC_STATUS_ASSERT(status);
 
         return response.alloc_size;
     } else {
@@ -761,7 +764,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
     rpc_msg_graph_compute_rsp response;
     auto sock = get_socket(rpc_ctx->endpoint);
     bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     return (enum ggml_status)response.result;
 }
 
@@ -835,7 +838,7 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
 static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
     rpc_msg_get_device_memory_rsp response;
     bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
-    GGML_ASSERT(status);
+    RPC_STATUS_ASSERT(status);
     *free = response.free_mem;
     *total = response.total_mem;
 }
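Below is a minimal, self-contained sketch (not part of the patch) of how the RPC_STATUS_ASSERT pattern behaves at a call site. GGML_ABORT comes from ggml.h in the real code; the GGML_ABORT_DEMO macro and send_rpc_cmd_stub helper here are hypothetical stand-ins so the example compiles on its own.

// rpc_status_assert_demo.cpp -- illustrative only
#include <cstdio>
#include <cstdlib>

// stand-in for GGML_ABORT from ggml.h: print the message and terminate
#define GGML_ABORT_DEMO(msg) do { fprintf(stderr, "fatal: %s\n", msg); abort(); } while (0)

// same shape as the macro added in this patch, wired to the demo abort
#define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT_DEMO("Remote RPC server crashed or returned malformed response")

// stand-in for send_rpc_cmd(): returns false when the connection to the server fails
static bool send_rpc_cmd_stub(bool simulate_ok) {
    return simulate_ok;
}

int main() {
    bool status = send_rpc_cmd_stub(true);
    RPC_STATUS_ASSERT(status);   // passes silently
    printf("RPC command succeeded\n");

    status = send_rpc_cmd_stub(false);
    RPC_STATUS_ASSERT(status);   // aborts with a descriptive error message
    return 0;
}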