Spaces:
Running
Running
rpc : prevent crashes on invalid input (llama/9040)
Browse filesAdd more checks which prevent RPC server from crashing if invalid input
is received from client
- ggml/src/ggml-rpc.cpp +47 -34
ggml/src/ggml-rpc.cpp
CHANGED
|
@@ -82,17 +82,18 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of
|
|
| 82 |
|
| 83 |
// RPC commands
|
| 84 |
enum rpc_cmd {
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
| 96 |
};
|
| 97 |
|
| 98 |
// RPC data structures
|
|
@@ -330,7 +331,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
|
|
| 330 |
uint64_t remote_ptr = ctx->remote_ptr;
|
| 331 |
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
|
| 332 |
std::vector<uint8_t> output;
|
| 333 |
-
bool status = send_rpc_cmd(ctx->sock,
|
| 334 |
GGML_ASSERT(status);
|
| 335 |
GGML_ASSERT(output.empty());
|
| 336 |
delete ctx;
|
|
@@ -346,7 +347,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
|
|
| 346 |
uint64_t remote_ptr = ctx->remote_ptr;
|
| 347 |
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
|
| 348 |
std::vector<uint8_t> output;
|
| 349 |
-
bool status = send_rpc_cmd(ctx->sock,
|
| 350 |
GGML_ASSERT(status);
|
| 351 |
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
| 352 |
// output serialization format: | base_ptr (8 bytes) |
|
|
@@ -405,7 +406,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
|
|
| 405 |
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
| 406 |
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
|
| 407 |
std::vector<uint8_t> output;
|
| 408 |
-
bool status = send_rpc_cmd(ctx->sock,
|
| 409 |
GGML_ASSERT(status);
|
| 410 |
}
|
| 411 |
|
|
@@ -419,7 +420,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
|
|
| 419 |
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
| 420 |
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
|
| 421 |
std::vector<uint8_t> output;
|
| 422 |
-
bool status = send_rpc_cmd(ctx->sock,
|
| 423 |
GGML_ASSERT(status);
|
| 424 |
GGML_ASSERT(output.size() == size);
|
| 425 |
// output serialization format: | data (size bytes) |
|
|
@@ -444,7 +445,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
|
|
| 444 |
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
|
| 445 |
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
|
| 446 |
std::vector<uint8_t> output;
|
| 447 |
-
bool status = send_rpc_cmd(ctx->sock,
|
| 448 |
GGML_ASSERT(status);
|
| 449 |
// output serialization format: | result (1 byte) |
|
| 450 |
GGML_ASSERT(output.size() == 1);
|
|
@@ -459,7 +460,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer
|
|
| 459 |
memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
|
| 460 |
memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
|
| 461 |
std::vector<uint8_t> output;
|
| 462 |
-
bool status = send_rpc_cmd(ctx->sock,
|
| 463 |
GGML_ASSERT(status);
|
| 464 |
}
|
| 465 |
|
|
@@ -488,7 +489,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
|
|
| 488 |
memcpy(input.data(), &size, sizeof(size));
|
| 489 |
std::vector<uint8_t> output;
|
| 490 |
auto sock = get_socket(buft_ctx->endpoint);
|
| 491 |
-
bool status = send_rpc_cmd(sock,
|
| 492 |
GGML_ASSERT(status);
|
| 493 |
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
| 494 |
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
|
|
@@ -511,7 +512,7 @@ static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
|
|
| 511 |
// input serialization format: | 0 bytes |
|
| 512 |
std::vector<uint8_t> input;
|
| 513 |
std::vector<uint8_t> output;
|
| 514 |
-
bool status = send_rpc_cmd(sock,
|
| 515 |
GGML_ASSERT(status);
|
| 516 |
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
| 517 |
// output serialization format: | alignment (8 bytes) |
|
|
@@ -529,7 +530,7 @@ static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
|
|
| 529 |
// input serialization format: | 0 bytes |
|
| 530 |
std::vector<uint8_t> input;
|
| 531 |
std::vector<uint8_t> output;
|
| 532 |
-
bool status = send_rpc_cmd(sock,
|
| 533 |
GGML_ASSERT(status);
|
| 534 |
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
| 535 |
// output serialization format: | max_size (8 bytes) |
|
|
@@ -622,7 +623,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
|
|
| 622 |
serialize_graph(cgraph, input);
|
| 623 |
std::vector<uint8_t> output;
|
| 624 |
auto sock = get_socket(rpc_ctx->endpoint);
|
| 625 |
-
bool status = send_rpc_cmd(sock,
|
| 626 |
GGML_ASSERT(status);
|
| 627 |
GGML_ASSERT(output.size() == 1);
|
| 628 |
return (enum ggml_status)output[0];
|
|
@@ -719,7 +720,7 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
|
|
| 719 |
// input serialization format: | 0 bytes |
|
| 720 |
std::vector<uint8_t> input;
|
| 721 |
std::vector<uint8_t> output;
|
| 722 |
-
bool status = send_rpc_cmd(sock,
|
| 723 |
GGML_ASSERT(status);
|
| 724 |
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
| 725 |
// output serialization format: | free (8 bytes) | total (8 bytes) |
|
|
@@ -1098,59 +1099,69 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
|
|
| 1098 |
if (!recv_data(sockfd, &cmd, 1)) {
|
| 1099 |
break;
|
| 1100 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1101 |
std::vector<uint8_t> input;
|
| 1102 |
std::vector<uint8_t> output;
|
| 1103 |
uint64_t input_size;
|
| 1104 |
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
|
| 1105 |
break;
|
| 1106 |
}
|
| 1107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1108 |
if (!recv_data(sockfd, input.data(), input_size)) {
|
| 1109 |
break;
|
| 1110 |
}
|
| 1111 |
bool ok = true;
|
| 1112 |
switch (cmd) {
|
| 1113 |
-
case
|
| 1114 |
ok = server.alloc_buffer(input, output);
|
| 1115 |
break;
|
| 1116 |
}
|
| 1117 |
-
case
|
| 1118 |
server.get_alignment(output);
|
| 1119 |
break;
|
| 1120 |
}
|
| 1121 |
-
case
|
| 1122 |
server.get_max_size(output);
|
| 1123 |
break;
|
| 1124 |
}
|
| 1125 |
-
case
|
| 1126 |
ok = server.buffer_get_base(input, output);
|
| 1127 |
break;
|
| 1128 |
}
|
| 1129 |
-
case
|
| 1130 |
ok = server.free_buffer(input);
|
| 1131 |
break;
|
| 1132 |
}
|
| 1133 |
-
case
|
| 1134 |
ok = server.buffer_clear(input);
|
| 1135 |
break;
|
| 1136 |
}
|
| 1137 |
-
case
|
| 1138 |
ok = server.set_tensor(input);
|
| 1139 |
break;
|
| 1140 |
}
|
| 1141 |
-
case
|
| 1142 |
ok = server.get_tensor(input, output);
|
| 1143 |
break;
|
| 1144 |
}
|
| 1145 |
-
case
|
| 1146 |
ok = server.copy_tensor(input, output);
|
| 1147 |
break;
|
| 1148 |
}
|
| 1149 |
-
case
|
| 1150 |
ok = server.graph_compute(input, output);
|
| 1151 |
break;
|
| 1152 |
}
|
| 1153 |
-
case
|
| 1154 |
// output serialization format: | free (8 bytes) | total (8 bytes) |
|
| 1155 |
output.resize(2*sizeof(uint64_t), 0);
|
| 1156 |
memcpy(output.data(), &free_mem, sizeof(free_mem));
|
|
@@ -1203,8 +1214,10 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
|
|
| 1203 |
return;
|
| 1204 |
}
|
| 1205 |
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
|
|
|
|
| 1206 |
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
|
| 1207 |
printf("Client connection closed\n");
|
|
|
|
| 1208 |
}
|
| 1209 |
#ifdef _WIN32
|
| 1210 |
WSACleanup();
|
|
|
|
| 82 |
|
| 83 |
// RPC commands
|
| 84 |
enum rpc_cmd {
|
| 85 |
+
RPC_CMD_ALLOC_BUFFER = 0,
|
| 86 |
+
RPC_CMD_GET_ALIGNMENT,
|
| 87 |
+
RPC_CMD_GET_MAX_SIZE,
|
| 88 |
+
RPC_CMD_BUFFER_GET_BASE,
|
| 89 |
+
RPC_CMD_FREE_BUFFER,
|
| 90 |
+
RPC_CMD_BUFFER_CLEAR,
|
| 91 |
+
RPC_CMD_SET_TENSOR,
|
| 92 |
+
RPC_CMD_GET_TENSOR,
|
| 93 |
+
RPC_CMD_COPY_TENSOR,
|
| 94 |
+
RPC_CMD_GRAPH_COMPUTE,
|
| 95 |
+
RPC_CMD_GET_DEVICE_MEMORY,
|
| 96 |
+
RPC_CMD_COUNT,
|
| 97 |
};
|
| 98 |
|
| 99 |
// RPC data structures
|
|
|
|
| 331 |
uint64_t remote_ptr = ctx->remote_ptr;
|
| 332 |
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
|
| 333 |
std::vector<uint8_t> output;
|
| 334 |
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output);
|
| 335 |
GGML_ASSERT(status);
|
| 336 |
GGML_ASSERT(output.empty());
|
| 337 |
delete ctx;
|
|
|
|
| 347 |
uint64_t remote_ptr = ctx->remote_ptr;
|
| 348 |
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
|
| 349 |
std::vector<uint8_t> output;
|
| 350 |
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output);
|
| 351 |
GGML_ASSERT(status);
|
| 352 |
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
| 353 |
// output serialization format: | base_ptr (8 bytes) |
|
|
|
|
| 406 |
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
| 407 |
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
|
| 408 |
std::vector<uint8_t> output;
|
| 409 |
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output);
|
| 410 |
GGML_ASSERT(status);
|
| 411 |
}
|
| 412 |
|
|
|
|
| 420 |
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
| 421 |
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
|
| 422 |
std::vector<uint8_t> output;
|
| 423 |
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output);
|
| 424 |
GGML_ASSERT(status);
|
| 425 |
GGML_ASSERT(output.size() == size);
|
| 426 |
// output serialization format: | data (size bytes) |
|
|
|
|
| 445 |
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
|
| 446 |
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
|
| 447 |
std::vector<uint8_t> output;
|
| 448 |
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output);
|
| 449 |
GGML_ASSERT(status);
|
| 450 |
// output serialization format: | result (1 byte) |
|
| 451 |
GGML_ASSERT(output.size() == 1);
|
|
|
|
| 460 |
memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
|
| 461 |
memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
|
| 462 |
std::vector<uint8_t> output;
|
| 463 |
+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output);
|
| 464 |
GGML_ASSERT(status);
|
| 465 |
}
|
| 466 |
|
|
|
|
| 489 |
memcpy(input.data(), &size, sizeof(size));
|
| 490 |
std::vector<uint8_t> output;
|
| 491 |
auto sock = get_socket(buft_ctx->endpoint);
|
| 492 |
+
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output);
|
| 493 |
GGML_ASSERT(status);
|
| 494 |
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
| 495 |
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
|
|
|
|
| 512 |
// input serialization format: | 0 bytes |
|
| 513 |
std::vector<uint8_t> input;
|
| 514 |
std::vector<uint8_t> output;
|
| 515 |
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output);
|
| 516 |
GGML_ASSERT(status);
|
| 517 |
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
| 518 |
// output serialization format: | alignment (8 bytes) |
|
|
|
|
| 530 |
// input serialization format: | 0 bytes |
|
| 531 |
std::vector<uint8_t> input;
|
| 532 |
std::vector<uint8_t> output;
|
| 533 |
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output);
|
| 534 |
GGML_ASSERT(status);
|
| 535 |
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
| 536 |
// output serialization format: | max_size (8 bytes) |
|
|
|
|
| 623 |
serialize_graph(cgraph, input);
|
| 624 |
std::vector<uint8_t> output;
|
| 625 |
auto sock = get_socket(rpc_ctx->endpoint);
|
| 626 |
+
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output);
|
| 627 |
GGML_ASSERT(status);
|
| 628 |
GGML_ASSERT(output.size() == 1);
|
| 629 |
return (enum ggml_status)output[0];
|
|
|
|
| 720 |
// input serialization format: | 0 bytes |
|
| 721 |
std::vector<uint8_t> input;
|
| 722 |
std::vector<uint8_t> output;
|
| 723 |
+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
|
| 724 |
GGML_ASSERT(status);
|
| 725 |
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
| 726 |
// output serialization format: | free (8 bytes) | total (8 bytes) |
|
|
|
|
| 1099 |
if (!recv_data(sockfd, &cmd, 1)) {
|
| 1100 |
break;
|
| 1101 |
}
|
| 1102 |
+
if (cmd >= RPC_CMD_COUNT) {
|
| 1103 |
+
// fail fast if the command is invalid
|
| 1104 |
+
fprintf(stderr, "Unknown command: %d\n", cmd);
|
| 1105 |
+
break;
|
| 1106 |
+
}
|
| 1107 |
std::vector<uint8_t> input;
|
| 1108 |
std::vector<uint8_t> output;
|
| 1109 |
uint64_t input_size;
|
| 1110 |
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
|
| 1111 |
break;
|
| 1112 |
}
|
| 1113 |
+
try {
|
| 1114 |
+
input.resize(input_size);
|
| 1115 |
+
} catch (const std::bad_alloc & e) {
|
| 1116 |
+
fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
|
| 1117 |
+
break;
|
| 1118 |
+
}
|
| 1119 |
if (!recv_data(sockfd, input.data(), input_size)) {
|
| 1120 |
break;
|
| 1121 |
}
|
| 1122 |
bool ok = true;
|
| 1123 |
switch (cmd) {
|
| 1124 |
+
case RPC_CMD_ALLOC_BUFFER: {
|
| 1125 |
ok = server.alloc_buffer(input, output);
|
| 1126 |
break;
|
| 1127 |
}
|
| 1128 |
+
case RPC_CMD_GET_ALIGNMENT: {
|
| 1129 |
server.get_alignment(output);
|
| 1130 |
break;
|
| 1131 |
}
|
| 1132 |
+
case RPC_CMD_GET_MAX_SIZE: {
|
| 1133 |
server.get_max_size(output);
|
| 1134 |
break;
|
| 1135 |
}
|
| 1136 |
+
case RPC_CMD_BUFFER_GET_BASE: {
|
| 1137 |
ok = server.buffer_get_base(input, output);
|
| 1138 |
break;
|
| 1139 |
}
|
| 1140 |
+
case RPC_CMD_FREE_BUFFER: {
|
| 1141 |
ok = server.free_buffer(input);
|
| 1142 |
break;
|
| 1143 |
}
|
| 1144 |
+
case RPC_CMD_BUFFER_CLEAR: {
|
| 1145 |
ok = server.buffer_clear(input);
|
| 1146 |
break;
|
| 1147 |
}
|
| 1148 |
+
case RPC_CMD_SET_TENSOR: {
|
| 1149 |
ok = server.set_tensor(input);
|
| 1150 |
break;
|
| 1151 |
}
|
| 1152 |
+
case RPC_CMD_GET_TENSOR: {
|
| 1153 |
ok = server.get_tensor(input, output);
|
| 1154 |
break;
|
| 1155 |
}
|
| 1156 |
+
case RPC_CMD_COPY_TENSOR: {
|
| 1157 |
ok = server.copy_tensor(input, output);
|
| 1158 |
break;
|
| 1159 |
}
|
| 1160 |
+
case RPC_CMD_GRAPH_COMPUTE: {
|
| 1161 |
ok = server.graph_compute(input, output);
|
| 1162 |
break;
|
| 1163 |
}
|
| 1164 |
+
case RPC_CMD_GET_DEVICE_MEMORY: {
|
| 1165 |
// output serialization format: | free (8 bytes) | total (8 bytes) |
|
| 1166 |
output.resize(2*sizeof(uint64_t), 0);
|
| 1167 |
memcpy(output.data(), &free_mem, sizeof(free_mem));
|
|
|
|
| 1214 |
return;
|
| 1215 |
}
|
| 1216 |
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
|
| 1217 |
+
fflush(stdout);
|
| 1218 |
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
|
| 1219 |
printf("Client connection closed\n");
|
| 1220 |
+
fflush(stdout);
|
| 1221 |
}
|
| 1222 |
#ifdef _WIN32
|
| 1223 |
WSACleanup();
|