rgerganov committed on
Commit
656ae00
·
1 Parent(s): e0dc1ad

rpc : prevent crashes on invalid input (llama/9040)

Browse files

Add more checks that prevent the RPC server from crashing when invalid
input is received from a client

Files changed (1) hide show
  1. ggml/src/ggml-rpc.cpp +47 -34
ggml/src/ggml-rpc.cpp CHANGED
@@ -82,17 +82,18 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of
82
 
83
  // RPC commands
84
  enum rpc_cmd {
85
- ALLOC_BUFFER = 0,
86
- GET_ALIGNMENT,
87
- GET_MAX_SIZE,
88
- BUFFER_GET_BASE,
89
- FREE_BUFFER,
90
- BUFFER_CLEAR,
91
- SET_TENSOR,
92
- GET_TENSOR,
93
- COPY_TENSOR,
94
- GRAPH_COMPUTE,
95
- GET_DEVICE_MEMORY,
 
96
  };
97
 
98
  // RPC data structures
@@ -330,7 +331,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
330
  uint64_t remote_ptr = ctx->remote_ptr;
331
  memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
332
  std::vector<uint8_t> output;
333
- bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
334
  GGML_ASSERT(status);
335
  GGML_ASSERT(output.empty());
336
  delete ctx;
@@ -346,7 +347,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
346
  uint64_t remote_ptr = ctx->remote_ptr;
347
  memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
348
  std::vector<uint8_t> output;
349
- bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
350
  GGML_ASSERT(status);
351
  GGML_ASSERT(output.size() == sizeof(uint64_t));
352
  // output serialization format: | base_ptr (8 bytes) |
@@ -405,7 +406,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
405
  memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
406
  memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
407
  std::vector<uint8_t> output;
408
- bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
409
  GGML_ASSERT(status);
410
  }
411
 
@@ -419,7 +420,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
419
  memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
420
  memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
421
  std::vector<uint8_t> output;
422
- bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
423
  GGML_ASSERT(status);
424
  GGML_ASSERT(output.size() == size);
425
  // output serialization format: | data (size bytes) |
@@ -444,7 +445,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
444
  memcpy(input.data(), &rpc_src, sizeof(rpc_src));
445
  memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
446
  std::vector<uint8_t> output;
447
- bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
448
  GGML_ASSERT(status);
449
  // output serialization format: | result (1 byte) |
450
  GGML_ASSERT(output.size() == 1);
@@ -459,7 +460,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer
459
  memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
460
  memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
461
  std::vector<uint8_t> output;
462
- bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
463
  GGML_ASSERT(status);
464
  }
465
 
@@ -488,7 +489,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
488
  memcpy(input.data(), &size, sizeof(size));
489
  std::vector<uint8_t> output;
490
  auto sock = get_socket(buft_ctx->endpoint);
491
- bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
492
  GGML_ASSERT(status);
493
  GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
494
  // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -511,7 +512,7 @@ static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
511
  // input serialization format: | 0 bytes |
512
  std::vector<uint8_t> input;
513
  std::vector<uint8_t> output;
514
- bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
515
  GGML_ASSERT(status);
516
  GGML_ASSERT(output.size() == sizeof(uint64_t));
517
  // output serialization format: | alignment (8 bytes) |
@@ -529,7 +530,7 @@ static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
529
  // input serialization format: | 0 bytes |
530
  std::vector<uint8_t> input;
531
  std::vector<uint8_t> output;
532
- bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
533
  GGML_ASSERT(status);
534
  GGML_ASSERT(output.size() == sizeof(uint64_t));
535
  // output serialization format: | max_size (8 bytes) |
@@ -622,7 +623,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
622
  serialize_graph(cgraph, input);
623
  std::vector<uint8_t> output;
624
  auto sock = get_socket(rpc_ctx->endpoint);
625
- bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
626
  GGML_ASSERT(status);
627
  GGML_ASSERT(output.size() == 1);
628
  return (enum ggml_status)output[0];
@@ -719,7 +720,7 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
719
  // input serialization format: | 0 bytes |
720
  std::vector<uint8_t> input;
721
  std::vector<uint8_t> output;
722
- bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
723
  GGML_ASSERT(status);
724
  GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
725
  // output serialization format: | free (8 bytes) | total (8 bytes) |
@@ -1098,59 +1099,69 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
1098
  if (!recv_data(sockfd, &cmd, 1)) {
1099
  break;
1100
  }
 
 
 
 
 
1101
  std::vector<uint8_t> input;
1102
  std::vector<uint8_t> output;
1103
  uint64_t input_size;
1104
  if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
1105
  break;
1106
  }
1107
- input.resize(input_size);
 
 
 
 
 
1108
  if (!recv_data(sockfd, input.data(), input_size)) {
1109
  break;
1110
  }
1111
  bool ok = true;
1112
  switch (cmd) {
1113
- case ALLOC_BUFFER: {
1114
  ok = server.alloc_buffer(input, output);
1115
  break;
1116
  }
1117
- case GET_ALIGNMENT: {
1118
  server.get_alignment(output);
1119
  break;
1120
  }
1121
- case GET_MAX_SIZE: {
1122
  server.get_max_size(output);
1123
  break;
1124
  }
1125
- case BUFFER_GET_BASE: {
1126
  ok = server.buffer_get_base(input, output);
1127
  break;
1128
  }
1129
- case FREE_BUFFER: {
1130
  ok = server.free_buffer(input);
1131
  break;
1132
  }
1133
- case BUFFER_CLEAR: {
1134
  ok = server.buffer_clear(input);
1135
  break;
1136
  }
1137
- case SET_TENSOR: {
1138
  ok = server.set_tensor(input);
1139
  break;
1140
  }
1141
- case GET_TENSOR: {
1142
  ok = server.get_tensor(input, output);
1143
  break;
1144
  }
1145
- case COPY_TENSOR: {
1146
  ok = server.copy_tensor(input, output);
1147
  break;
1148
  }
1149
- case GRAPH_COMPUTE: {
1150
  ok = server.graph_compute(input, output);
1151
  break;
1152
  }
1153
- case GET_DEVICE_MEMORY: {
1154
  // output serialization format: | free (8 bytes) | total (8 bytes) |
1155
  output.resize(2*sizeof(uint64_t), 0);
1156
  memcpy(output.data(), &free_mem, sizeof(free_mem));
@@ -1203,8 +1214,10 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
1203
  return;
1204
  }
1205
  printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
 
1206
  rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
1207
  printf("Client connection closed\n");
 
1208
  }
1209
  #ifdef _WIN32
1210
  WSACleanup();
 
82
 
83
  // RPC commands
84
  enum rpc_cmd {
85
+ RPC_CMD_ALLOC_BUFFER = 0,
86
+ RPC_CMD_GET_ALIGNMENT,
87
+ RPC_CMD_GET_MAX_SIZE,
88
+ RPC_CMD_BUFFER_GET_BASE,
89
+ RPC_CMD_FREE_BUFFER,
90
+ RPC_CMD_BUFFER_CLEAR,
91
+ RPC_CMD_SET_TENSOR,
92
+ RPC_CMD_GET_TENSOR,
93
+ RPC_CMD_COPY_TENSOR,
94
+ RPC_CMD_GRAPH_COMPUTE,
95
+ RPC_CMD_GET_DEVICE_MEMORY,
96
+ RPC_CMD_COUNT,
97
  };
98
 
99
  // RPC data structures
 
331
  uint64_t remote_ptr = ctx->remote_ptr;
332
  memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
333
  std::vector<uint8_t> output;
334
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output);
335
  GGML_ASSERT(status);
336
  GGML_ASSERT(output.empty());
337
  delete ctx;
 
347
  uint64_t remote_ptr = ctx->remote_ptr;
348
  memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
349
  std::vector<uint8_t> output;
350
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output);
351
  GGML_ASSERT(status);
352
  GGML_ASSERT(output.size() == sizeof(uint64_t));
353
  // output serialization format: | base_ptr (8 bytes) |
 
406
  memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
407
  memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
408
  std::vector<uint8_t> output;
409
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output);
410
  GGML_ASSERT(status);
411
  }
412
 
 
420
  memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
421
  memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
422
  std::vector<uint8_t> output;
423
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output);
424
  GGML_ASSERT(status);
425
  GGML_ASSERT(output.size() == size);
426
  // output serialization format: | data (size bytes) |
 
445
  memcpy(input.data(), &rpc_src, sizeof(rpc_src));
446
  memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
447
  std::vector<uint8_t> output;
448
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output);
449
  GGML_ASSERT(status);
450
  // output serialization format: | result (1 byte) |
451
  GGML_ASSERT(output.size() == 1);
 
460
  memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
461
  memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
462
  std::vector<uint8_t> output;
463
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output);
464
  GGML_ASSERT(status);
465
  }
466
 
 
489
  memcpy(input.data(), &size, sizeof(size));
490
  std::vector<uint8_t> output;
491
  auto sock = get_socket(buft_ctx->endpoint);
492
+ bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output);
493
  GGML_ASSERT(status);
494
  GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
495
  // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
 
512
  // input serialization format: | 0 bytes |
513
  std::vector<uint8_t> input;
514
  std::vector<uint8_t> output;
515
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output);
516
  GGML_ASSERT(status);
517
  GGML_ASSERT(output.size() == sizeof(uint64_t));
518
  // output serialization format: | alignment (8 bytes) |
 
530
  // input serialization format: | 0 bytes |
531
  std::vector<uint8_t> input;
532
  std::vector<uint8_t> output;
533
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output);
534
  GGML_ASSERT(status);
535
  GGML_ASSERT(output.size() == sizeof(uint64_t));
536
  // output serialization format: | max_size (8 bytes) |
 
623
  serialize_graph(cgraph, input);
624
  std::vector<uint8_t> output;
625
  auto sock = get_socket(rpc_ctx->endpoint);
626
+ bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output);
627
  GGML_ASSERT(status);
628
  GGML_ASSERT(output.size() == 1);
629
  return (enum ggml_status)output[0];
 
720
  // input serialization format: | 0 bytes |
721
  std::vector<uint8_t> input;
722
  std::vector<uint8_t> output;
723
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
724
  GGML_ASSERT(status);
725
  GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
726
  // output serialization format: | free (8 bytes) | total (8 bytes) |
 
1099
  if (!recv_data(sockfd, &cmd, 1)) {
1100
  break;
1101
  }
1102
+ if (cmd >= RPC_CMD_COUNT) {
1103
+ // fail fast if the command is invalid
1104
+ fprintf(stderr, "Unknown command: %d\n", cmd);
1105
+ break;
1106
+ }
1107
  std::vector<uint8_t> input;
1108
  std::vector<uint8_t> output;
1109
  uint64_t input_size;
1110
  if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
1111
  break;
1112
  }
1113
+ try {
1114
+ input.resize(input_size);
1115
+ } catch (const std::bad_alloc & e) {
1116
+ fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
1117
+ break;
1118
+ }
1119
  if (!recv_data(sockfd, input.data(), input_size)) {
1120
  break;
1121
  }
1122
  bool ok = true;
1123
  switch (cmd) {
1124
+ case RPC_CMD_ALLOC_BUFFER: {
1125
  ok = server.alloc_buffer(input, output);
1126
  break;
1127
  }
1128
+ case RPC_CMD_GET_ALIGNMENT: {
1129
  server.get_alignment(output);
1130
  break;
1131
  }
1132
+ case RPC_CMD_GET_MAX_SIZE: {
1133
  server.get_max_size(output);
1134
  break;
1135
  }
1136
+ case RPC_CMD_BUFFER_GET_BASE: {
1137
  ok = server.buffer_get_base(input, output);
1138
  break;
1139
  }
1140
+ case RPC_CMD_FREE_BUFFER: {
1141
  ok = server.free_buffer(input);
1142
  break;
1143
  }
1144
+ case RPC_CMD_BUFFER_CLEAR: {
1145
  ok = server.buffer_clear(input);
1146
  break;
1147
  }
1148
+ case RPC_CMD_SET_TENSOR: {
1149
  ok = server.set_tensor(input);
1150
  break;
1151
  }
1152
+ case RPC_CMD_GET_TENSOR: {
1153
  ok = server.get_tensor(input, output);
1154
  break;
1155
  }
1156
+ case RPC_CMD_COPY_TENSOR: {
1157
  ok = server.copy_tensor(input, output);
1158
  break;
1159
  }
1160
+ case RPC_CMD_GRAPH_COMPUTE: {
1161
  ok = server.graph_compute(input, output);
1162
  break;
1163
  }
1164
+ case RPC_CMD_GET_DEVICE_MEMORY: {
1165
  // output serialization format: | free (8 bytes) | total (8 bytes) |
1166
  output.resize(2*sizeof(uint64_t), 0);
1167
  memcpy(output.data(), &free_mem, sizeof(free_mem));
 
1214
  return;
1215
  }
1216
  printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
1217
+ fflush(stdout);
1218
  rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
1219
  printf("Client connection closed\n");
1220
+ fflush(stdout);
1221
  }
1222
  #ifdef _WIN32
1223
  WSACleanup();