rgerganov committed on
Commit
7571b13
·
1 Parent(s): f0ee71c

rpc : resource management rework (llama/7562)

Browse files

* rpc : resource management rework

* address review comments

Files changed (1) hide show
  1. ggml-rpc.cpp +75 -58
ggml-rpc.cpp CHANGED
@@ -6,6 +6,7 @@
6
  #include <string>
7
  #include <vector>
8
  #include <memory>
 
9
  #include <unordered_map>
10
  #include <unordered_set>
11
  #ifdef _WIN32
@@ -47,6 +48,7 @@ struct socket_t {
47
  sockfd_t fd;
48
  socket_t(sockfd_t fd) : fd(fd) {}
49
  ~socket_t() {
 
50
  #ifdef _WIN32
51
  closesocket(this->fd);
52
  #else
@@ -97,7 +99,7 @@ static ggml_guid_t ggml_backend_rpc_guid() {
97
  }
98
 
99
  struct ggml_backend_rpc_buffer_type_context {
100
- std::shared_ptr<socket_t> sock;
101
  std::string name;
102
  size_t alignment;
103
  size_t max_size;
@@ -106,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context {
106
  struct ggml_backend_rpc_context {
107
  std::string endpoint;
108
  std::string name;
109
- std::shared_ptr<socket_t> sock;
110
- ggml_backend_buffer_type_t buft;
111
  };
112
 
113
  struct ggml_backend_rpc_buffer_context {
@@ -231,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
231
  return true;
232
  }
233
 
234
- static bool parse_endpoint(const char * endpoint, std::string & host, int & port) {
235
- std::string str(endpoint);
236
- size_t pos = str.find(':');
237
  if (pos == std::string::npos) {
238
  return false;
239
  }
240
- host = str.substr(0, pos);
241
- port = std::stoi(str.substr(pos + 1));
242
  return true;
243
  }
244
 
@@ -273,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
273
 
274
  // RPC client-side implementation
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
277
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
278
  return ctx->name.c_str();
@@ -442,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
442
  std::vector<uint8_t> input(input_size, 0);
443
  memcpy(input.data(), &size, sizeof(size));
444
  std::vector<uint8_t> output;
445
- bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output);
 
446
  GGML_ASSERT(status);
447
  GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
448
  // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -453,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
453
  if (remote_ptr != 0) {
454
  ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
455
  ggml_backend_rpc_buffer_interface,
456
- new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
457
  remote_size);
458
  return buffer;
459
  } else {
@@ -508,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
508
  }
509
  ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
510
  ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
511
- return buft_ctx->sock == rpc_ctx->sock;
512
  }
513
 
514
  static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -521,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
521
  /* .is_host = */ NULL,
522
  };
523
 
524
-
525
  GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
526
  ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
527
 
@@ -530,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
530
 
531
  GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
532
  ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
533
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
534
- delete buft_ctx;
535
- delete rpc_ctx->buft;
536
  delete rpc_ctx;
537
  delete backend;
538
  }
539
 
540
  GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
541
  ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
542
- return ctx->buft;
543
  }
544
 
545
  GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
@@ -590,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
590
  std::vector<uint8_t> input;
591
  serialize_graph(cgraph, input);
592
  std::vector<uint8_t> output;
593
- bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output);
 
594
  GGML_ASSERT(status);
595
  GGML_ASSERT(output.size() == 1);
596
  return (enum ggml_status)output[0];
@@ -624,65 +659,48 @@ static ggml_backend_i ggml_backend_rpc_interface = {
624
  /* .event_synchronize = */ NULL,
625
  };
626
 
627
- static std::unordered_map<std::string, ggml_backend_t> instances;
628
-
629
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
630
- ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
631
- return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr;
632
- }
633
-
634
- GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
635
- std::string endpoint_str(endpoint);
636
- if (instances.find(endpoint_str) != instances.end()) {
637
- return instances[endpoint_str];
638
- }
639
- #ifdef _WIN32
640
- {
641
- WSADATA wsaData;
642
- int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
643
- if (res != 0) {
644
- return nullptr;
645
- }
646
- }
647
- #endif
648
- fprintf(stderr, "Connecting to %s\n", endpoint);
649
- std::string host;
650
- int port;
651
- if (!parse_endpoint(endpoint, host, port)) {
652
- return nullptr;
653
- }
654
- auto sock = socket_connect(host.c_str(), port);
655
  if (sock == nullptr) {
656
  return nullptr;
657
  }
658
  size_t alignment = get_alignment(sock);
659
  size_t max_size = get_max_size(sock);
660
  ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
661
- /* .sock = */ sock,
662
- /* .name = */ "RPC" + std::to_string(sock->fd),
663
  /* .alignment = */ alignment,
664
- /* .max_size = */ max_size
665
  };
666
 
667
  ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
668
  /* .iface = */ ggml_backend_rpc_buffer_type_interface,
669
  /* .context = */ buft_ctx
670
  };
 
 
 
671
 
 
672
  ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
673
- /* .endpoint = */ endpoint,
674
- /* .name = */ "RPC" + std::to_string(sock->fd),
675
- /* .sock = */ sock,
676
- /* .buft = */ buft
677
  };
678
 
679
- instances[endpoint] = new ggml_backend {
680
  /* .guid = */ ggml_backend_rpc_guid(),
681
  /* .interface = */ ggml_backend_rpc_interface,
682
  /* .context = */ ctx
683
  };
684
-
685
- return instances[endpoint];
686
  }
687
 
688
  GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
@@ -706,14 +724,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
706
  }
707
 
708
  GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
709
- ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
710
- if (backend == nullptr) {
711
  *free = 0;
712
  *total = 0;
713
  return;
714
  }
715
- ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
716
- get_device_memory(ctx->sock, free, total);
717
  }
718
 
719
  // RPC server-side implementation
 
6
  #include <string>
7
  #include <vector>
8
  #include <memory>
9
+ #include <mutex>
10
  #include <unordered_map>
11
  #include <unordered_set>
12
  #ifdef _WIN32
 
48
  sockfd_t fd;
49
  socket_t(sockfd_t fd) : fd(fd) {}
50
  ~socket_t() {
51
+ GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
52
  #ifdef _WIN32
53
  closesocket(this->fd);
54
  #else
 
99
  }
100
 
101
  struct ggml_backend_rpc_buffer_type_context {
102
+ std::string endpoint;
103
  std::string name;
104
  size_t alignment;
105
  size_t max_size;
 
108
  struct ggml_backend_rpc_context {
109
  std::string endpoint;
110
  std::string name;
 
 
111
  };
112
 
113
  struct ggml_backend_rpc_buffer_context {
 
231
  return true;
232
  }
233
 
234
// Splits an endpoint of the form "host:port" at its first ':' into `host` and `port`.
//
// Returns true and fills both outputs when the text after the ':' is a non-empty run of
// decimal digits whose value is in [0, 65535]. Returns false otherwise, in which case
// `host` and `port` are left untouched.
//
// The port is parsed by hand instead of with std::stoi: std::stoi throws on non-numeric or
// oversized input and silently accepts trailing garbage ("80abc"), signs and leading
// whitespace. This function never throws because of a malformed port.
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
    const size_t pos = endpoint.find(':');
    if (pos == std::string::npos) {
        return false;
    }
    const size_t port_begin = pos + 1;
    if (port_begin == endpoint.size()) {
        // "host:" - nothing after the separator
        return false;
    }
    int value = 0;
    for (size_t i = port_begin; i < endpoint.size(); ++i) {
        const char c = endpoint[i];
        if (c < '0' || c > '9') {
            return false;
        }
        value = value * 10 + (c - '0');
        // bail out early so `value` can never overflow, however long the digit run is
        if (value > 65535) {
            return false;
        }
    }
    host = endpoint.substr(0, pos);
    port = value;
    return true;
}
243
 
 
272
 
273
  // RPC client-side implementation
274
 
275
+ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
276
+ static std::mutex mutex;
277
+ std::lock_guard<std::mutex> lock(mutex);
278
+ static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
279
+ static bool initialized = false;
280
+
281
+ auto it = sockets.find(endpoint);
282
+ if (it != sockets.end()) {
283
+ if (auto sock = it->second.lock()) {
284
+ return sock;
285
+ }
286
+ }
287
+ std::string host;
288
+ int port;
289
+ if (!parse_endpoint(endpoint, host, port)) {
290
+ return nullptr;
291
+ }
292
+ #ifdef _WIN32
293
+ if (!initialized) {
294
+ WSADATA wsaData;
295
+ int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
296
+ if (res != 0) {
297
+ return nullptr;
298
+ }
299
+ initialized = true;
300
+ }
301
+ #else
302
+ UNUSED(initialized);
303
+ #endif
304
+ auto sock = socket_connect(host.c_str(), port);
305
+ if (sock == nullptr) {
306
+ return nullptr;
307
+ }
308
+ GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
309
+ sockets[endpoint] = sock;
310
+ return sock;
311
+ }
312
+
313
  GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
314
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
315
  return ctx->name.c_str();
 
479
  std::vector<uint8_t> input(input_size, 0);
480
  memcpy(input.data(), &size, sizeof(size));
481
  std::vector<uint8_t> output;
482
+ auto sock = get_socket(buft_ctx->endpoint);
483
+ bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
484
  GGML_ASSERT(status);
485
  GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
486
  // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
 
491
  if (remote_ptr != 0) {
492
  ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
493
  ggml_backend_rpc_buffer_interface,
494
+ new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
495
  remote_size);
496
  return buffer;
497
  } else {
 
546
  }
547
  ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
548
  ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
549
+ return buft_ctx->endpoint == rpc_ctx->endpoint;
550
  }
551
 
552
  static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
 
559
  /* .is_host = */ NULL,
560
  };
561
 
 
562
  GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
563
  ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
564
 
 
567
 
568
  GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
569
  ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
 
 
 
570
  delete rpc_ctx;
571
  delete backend;
572
  }
573
 
574
  GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
575
  ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
576
+ return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
577
  }
578
 
579
  GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
 
624
  std::vector<uint8_t> input;
625
  serialize_graph(cgraph, input);
626
  std::vector<uint8_t> output;
627
+ auto sock = get_socket(rpc_ctx->endpoint);
628
+ bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
629
  GGML_ASSERT(status);
630
  GGML_ASSERT(output.size() == 1);
631
  return (enum ggml_status)output[0];
 
659
  /* .event_synchronize = */ NULL,
660
  };
661
 
 
 
662
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
663
+ static std::mutex mutex;
664
+ std::lock_guard<std::mutex> lock(mutex);
665
+ // NOTE: buffer types are allocated and never freed; this is by design
666
+ static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
667
+ auto it = buft_map.find(endpoint);
668
+ if (it != buft_map.end()) {
669
+ return it->second;
670
+ }
671
+ auto sock = get_socket(endpoint);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
672
  if (sock == nullptr) {
673
  return nullptr;
674
  }
675
  size_t alignment = get_alignment(sock);
676
  size_t max_size = get_max_size(sock);
677
  ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
678
+ /* .endpoint = */ endpoint,
679
+ /* .name = */ "RPC[" + std::string(endpoint) + "]",
680
  /* .alignment = */ alignment,
681
+ /* .max_size = */ max_size
682
  };
683
 
684
  ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
685
  /* .iface = */ ggml_backend_rpc_buffer_type_interface,
686
  /* .context = */ buft_ctx
687
  };
688
+ buft_map[endpoint] = buft;
689
+ return buft;
690
+ }
691
 
692
+ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
693
  ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
694
+ /* .endpoint = */ endpoint,
695
+ /* .name = */ "RPC",
 
 
696
  };
697
 
698
+ ggml_backend_t backend = new ggml_backend {
699
  /* .guid = */ ggml_backend_rpc_guid(),
700
  /* .interface = */ ggml_backend_rpc_interface,
701
  /* .context = */ ctx
702
  };
703
+ return backend;
 
704
  }
705
 
706
  GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
 
724
  }
725
 
726
  GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
727
+ auto sock = get_socket(endpoint);
728
+ if (sock == nullptr) {
729
  *free = 0;
730
  *total = 0;
731
  return;
732
  }
733
+ get_device_memory(sock, free, total);
 
734
  }
735
 
736
  // RPC server-side implementation