Diego Devesa committed

Commit 4ac768e · 1 Parent(s): 553b278

rpc : add backend registry / device interfaces (llama/9812)


* rpc : add backend registry / device interfaces

* llama : add llama_supports_rpc API

* ggml_backend_rpc_start_rpc_server -> ggml_backend_rpc_start_server
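
For reference, a minimal client-side sketch of how the new interfaces fit together, assuming a running rpc-server and the generic device helpers from ggml-backend.h (ggml_backend_dev_memory, ggml_backend_dev_name, ggml_backend_dev_init); the endpoint is illustrative:

#include "ggml-backend.h"
#include "ggml-rpc.h"

#include <stdio.h>

int main(void) {
    // create (or reuse) the device object for a running rpc-server; the endpoint is illustrative
    ggml_backend_dev_t dev = ggml_backend_rpc_add_device("localhost:50052");

    // query it through the generic device interface
    size_t free_mem, total_mem;
    ggml_backend_dev_memory(dev, &free_mem, &total_mem);
    printf("%s: %zu / %zu MB free\n", ggml_backend_dev_name(dev),
           free_mem / (1024 * 1024), total_mem / (1024 * 1024));

    // initialize a backend for the device; this ends up calling ggml_backend_rpc_init
    ggml_backend_t backend = ggml_backend_dev_init(dev, NULL);
    // ... build and compute graphs as with any other backend ...
    ggml_backend_free(backend);
    return 0;
}

ggml_backend_rpc_init remains available; initializing through the device interface reaches the same code path, as the device implementation below shows.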

ggml/include/ggml-rpc.h CHANGED
@@ -17,7 +17,11 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * en
 
 GGML_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
 
-GGML_API void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
+GGML_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
+
+GGML_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
+
+GGML_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
 
 #ifdef __cplusplus
 }
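
On the server side only the name of the entry point changes; a rough sketch of its use, assuming a local CPU backend is what gets exposed over RPC (ggml_backend_cpu_init from ggml-backend.h; endpoint and memory figures are illustrative):

#include "ggml-backend.h"
#include "ggml-rpc.h"

int main(void) {
    // back the server with the CPU backend (any local backend works)
    ggml_backend_t backend = ggml_backend_cpu_init();

    // memory the server advertises to clients (illustrative values)
    size_t free_mem  = 8ull * 1024 * 1024 * 1024;
    size_t total_mem = 8ull * 1024 * 1024 * 1024;

    // blocks and serves incoming clients on the endpoint
    ggml_backend_rpc_start_server(backend, "0.0.0.0:50052", free_mem, total_mem);

    ggml_backend_free(backend);
    return 0;
}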
ggml/src/ggml-backend.cpp CHANGED
@@ -542,6 +542,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
 #include "ggml-blas.h"
 #endif
 
+#ifdef GGML_USE_RPC
+#include "ggml-rpc.h"
+#endif
+
 struct ggml_backend_registry {
     std::vector<ggml_backend_reg_t> backends;
     std::vector<ggml_backend_dev_t> devices;
@@ -556,6 +560,9 @@
 #ifdef GGML_USE_BLAS
         register_backend(ggml_backend_blas_reg());
 #endif
+#ifdef GGML_USE_RPC
+        register_backend(ggml_backend_rpc_reg());
+#endif
 
         // TODO: sycl, vulkan, kompute, cann
 
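
Once registered here, the RPC registry is visible through the generic registry API in builds with GGML_USE_RPC; a small sketch, assuming ggml_backend_reg_count, ggml_backend_reg_get and ggml_backend_reg_name from ggml-backend.h:

#include "ggml-backend.h"

#include <stdio.h>

int main(void) {
    // with GGML_USE_RPC defined, one of these entries is the "RPC" registry
    for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
        ggml_backend_reg_t reg = ggml_backend_reg_get(i);
        printf("backend reg %zu: %s\n", i, ggml_backend_reg_name(reg));
    }
    return 0;
}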
ggml/src/ggml-rpc.cpp CHANGED
@@ -25,7 +25,7 @@
 # include <netdb.h>
 # include <unistd.h>
 #endif
-#include <string.h>
+#include <cstring>
 
 #define UNUSED GGML_UNUSED
 
@@ -630,22 +630,6 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
     return (enum ggml_status)output[0];
 }
 
-static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
-    UNUSED(backend);
-    UNUSED(op);
-    //TODO: call the remote backend and cache the results
-    return true;
-}
-
-static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
-        return false;
-    }
-    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
-    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    return buft_ctx->endpoint == rpc_ctx->endpoint;
-}
-
 static ggml_backend_i ggml_backend_rpc_interface = {
     /* .get_name                = */ ggml_backend_rpc_name,
     /* .free                    = */ ggml_backend_rpc_free,
@@ -659,8 +643,8 @@ static ggml_backend_i ggml_backend_rpc_interface = {
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_rpc_graph_compute,
-    /* .supports_op             = */ ggml_backend_rpc_supports_op,
-    /* .supports_buft           = */ ggml_backend_rpc_supports_buft,
+    /* .supports_op             = */ NULL,
+    /* .supports_buft           = */ NULL,
     /* .offload_op              = */ NULL,
     /* .event_record            = */ NULL,
     /* .event_wait              = */ NULL,
@@ -691,7 +675,7 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * en
 
     ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
         /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
-        /* .device  = */ nullptr,
+        /* .device  = */ ggml_backend_rpc_add_device(endpoint),
         /* .context = */ buft_ctx
     };
     buft_map[endpoint] = buft;
@@ -707,7 +691,7 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
     ggml_backend_t backend = new ggml_backend {
         /* .guid      = */ ggml_backend_rpc_guid(),
         /* .interface = */ ggml_backend_rpc_interface,
-        /* .device    = */ nullptr,
+        /* .device    = */ ggml_backend_rpc_add_device(endpoint),
         /* .context   = */ ctx
     };
     return backend;
@@ -1189,7 +1173,7 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
     }
 }
 
-void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
+void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
     std::string host;
     int port;
     if (!parse_endpoint(endpoint, host, port)) {
@@ -1226,3 +1210,179 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
     WSACleanup();
 #endif
 }
+
+// device interface
+
+struct ggml_backend_rpc_device_context {
+    std::string endpoint;
+    std::string name;
+};
+
+static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ctx->name.c_str();
+}
+
+static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ctx->name.c_str();
+}
+
+static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
+
+    UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
+    // TODO: obtain value from the server
+    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
+
+    UNUSED(dev);
+}
+
+static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_rpc_device_get_name(dev);
+    props->description = ggml_backend_rpc_device_get_description(dev);
+    props->type        = ggml_backend_rpc_device_get_type(dev);
+    ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ false,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ggml_backend_rpc_init(ctx->endpoint.c_str());
+
+    UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
+
+    UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_rpc_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
+
+    UNUSED(dev);
+    UNUSED(max_tensor_size);
+}
+
+static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    UNUSED(dev);
+    UNUSED(op);
+    //TODO: call the remote backend and cache the results
+    return true;
+}
+
+static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
+        return false;
+    }
+    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+    ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
+    return buft_ctx->endpoint == dev_ctx->endpoint;
+}
+
+static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
+    /* .get_name             = */ ggml_backend_rpc_device_get_name,
+    /* .get_description      = */ ggml_backend_rpc_device_get_description,
+    /* .get_memory           = */ ggml_backend_rpc_device_get_memory,
+    /* .get_type             = */ ggml_backend_rpc_device_get_type,
+    /* .get_props            = */ ggml_backend_rpc_device_get_props,
+    /* .init_backend         = */ ggml_backend_rpc_device_init,
+    /* .get_buffer_type      = */ ggml_backend_rpc_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_rpc_device_buffer_from_ptr,
+    /* .supports_op          = */ ggml_backend_rpc_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_rpc_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend reg interface
+
+static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
+    return "RPC";
+
+    UNUSED(reg);
+}
+
+static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
+    return 0;
+
+    UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
+
+    UNUSED(reg);
+    UNUSED(index);
+}
+
+static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
+        return (void *)ggml_backend_rpc_add_device;
+    }
+    return NULL;
+
+    UNUSED(reg);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
+    /* .get_name         = */ ggml_backend_rpc_reg_get_name,
+    /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_rpc_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_rpc_reg(void) {
+    static struct ggml_backend_reg ggml_backend_rpc_reg = {
+        /* .iface   = */ ggml_backend_rpc_reg_i,
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_rpc_reg;
+}
+
+ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
+    static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
+
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+
+    if (dev_map.find(endpoint) != dev_map.end()) {
+        return dev_map[endpoint];
+    }
+
+    ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
+        /* .endpoint = */ endpoint,
+        /* .name     = */ "RPC[" + std::string(endpoint) + "]",
+    };
+
+    ggml_backend_dev_t dev = new ggml_backend_device {
+        /* .iface   = */ ggml_backend_rpc_device_i,
+        /* .reg     = */ ggml_backend_rpc_reg(),
+        /* .context = */ ctx,
+    };
+
+    dev_map[endpoint] = dev;
+
+    return dev;
+}
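
Since ggml_backend_rpc_reg_get_device_count reports zero devices, endpoint-specific devices are reached through the registry's get_proc_address hook instead; a sketch of that lookup, assuming ggml_backend_reg_by_name and ggml_backend_dev_name from ggml-backend.h (the typedef name and endpoint are illustrative):

#include "ggml-backend.h"

#include <stdio.h>

// signature of the exported symbol; the typedef name is illustrative
typedef ggml_backend_dev_t (*rpc_add_device_fn)(const char * endpoint);

int main(void) {
    ggml_backend_reg_t reg = ggml_backend_reg_by_name("RPC");
    if (reg == NULL) {
        printf("RPC backend not available\n");
        return 1;
    }

    // resolve the endpoint-specific device constructor exported by the registry
    rpc_add_device_fn add_device =
        (rpc_add_device_fn) ggml_backend_reg_get_proc_address(reg, "ggml_backend_rpc_add_device");

    ggml_backend_dev_t dev = add_device("localhost:50052"); // endpoint is illustrative
    printf("added device: %s\n", ggml_backend_dev_name(dev));
    return 0;
}

This indirection keeps ggml-rpc.h out of code that only wants to probe for RPC support at runtime, which is what the new llama_supports_rpc API relies on.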