ggerganov and slaren committed
Commit bfa5a95 · 1 Parent(s): 1b0dec0

whisper : use ggml_backend_sched (#2239)


* whisper : use ggml_backend_sched (wip)

* use sched in whisper_allocr

* whisper : single backend in whisper_context

* whisper : remove whisper_state->backends_used

* whisper : remove whisper_context->backend

* whisper : reset scheduler after init

* whisper : fix external encoder (e.g. CoreML)

* whisper : cleanup

* whisper : handle null GPU buffer types + fix sycl

---------

Co-authored-by: slaren <[email protected]>

Files changed (3)
  1. ggml-backend.c +13 -2
  2. ggml-backend.h +3 -0
  3. whisper.cpp +173 -108
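
For orientation, here is a minimal sketch of the pattern this commit adopts in whisper.cpp: one ggml_backend_sched per graph type, created over all available backends, then allocate, compute, and reset on every run. This is illustrative only; `backends` and `build_graph()` stand in for the corresponding whisper.cpp code.

    // illustrative sketch, not part of the diff
    ggml_backend_sched_t sched = ggml_backend_sched_new(
        backends.data(), nullptr, (int) backends.size(), WHISPER_MAX_NODES, false);

    struct ggml_cgraph * gf = build_graph();            // e.g. whisper_build_graph_conv(...)

    if (!ggml_backend_sched_alloc_graph(sched, gf)) {   // allocate compute buffers across backends
        // handle allocation failure
    }

    bool ok = ggml_backend_sched_graph_compute(sched, gf) == GGML_STATUS_SUCCESS;
    ggml_backend_sched_reset(sched);                    // ready for the next graph
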
ggml-backend.c CHANGED
@@ -1706,14 +1706,16 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
 static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
     bool backend_ids_changed = false;
     for (int i = 0; i < sched->graph->n_nodes; i++) {
-        if (sched->node_backend_ids[i] != sched->prev_node_backend_ids[i]) {
+        if (sched->node_backend_ids[i] != sched->prev_node_backend_ids[i] &&
+            sched->bufts[sched->node_backend_ids[i]] != sched->bufts[sched->prev_node_backend_ids[i]]) {
             backend_ids_changed = true;
             break;
         }
     }
     if (!backend_ids_changed) {
         for (int i = 0; i < sched->graph->n_leafs; i++) {
-            if (sched->leaf_backend_ids[i] != sched->prev_leaf_backend_ids[i]) {
+            if (sched->leaf_backend_ids[i] != sched->prev_leaf_backend_ids[i] &&
+                sched->bufts[sched->leaf_backend_ids[i]] != sched->bufts[sched->prev_leaf_backend_ids[i]]) {
                 backend_ids_changed = true;
                 break;
             }
@@ -1977,6 +1979,15 @@ int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) {
     return sched->n_copies;
 }
 
+int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) {
+    return sched->n_backends;
+}
+
+ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) {
+    GGML_ASSERT(i >= 0 && i < sched->n_backends);
+    return sched->backends[i];
+}
+
 size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
     int backend_index = ggml_backend_sched_backend_id(sched, backend);
     GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
ggml-backend.h CHANGED
@@ -182,6 +182,9 @@ extern "C" {
     // Initialize backend buffers from a measure graph
     GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
 
+    GGML_API int            ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
+    GGML_API ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);
+
     // Get the number of splits of the last graph
     GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
     GGML_API int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched);
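
The two getters added above let client code enumerate the scheduler's backends, for example to sum the compute buffer reserved on each one (this is what the new whisper_sched_size() helper in whisper.cpp does below). A minimal sketch, assuming `sched` is an initialized scheduler:

    size_t total = 0;
    for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
        ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
        total += ggml_backend_sched_get_buffer_size(sched, backend);
    }
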
whisper.cpp CHANGED
@@ -17,6 +17,10 @@
 #include "ggml-sycl.h"
 #endif
 
+#ifdef GGML_USE_BLAS
+#include "ggml-blas.h"
+#endif
+
 #ifdef WHISPER_USE_OPENVINO
 #include "openvino/whisper-openvino-encoder.h"
 #endif
@@ -179,18 +183,30 @@ static bool ggml_graph_compute_helper(
 }
 
 static bool ggml_graph_compute_helper(
-        struct ggml_backend * backend,
+       ggml_backend_sched_t   sched,
         struct ggml_cgraph * graph,
                        int   n_threads) {
-    if (ggml_backend_is_cpu(backend)) {
-        ggml_backend_cpu_set_n_threads(backend, n_threads);
-    }
+
+    for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
+        ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
+        if (ggml_backend_is_cpu(backend)) {
+            ggml_backend_cpu_set_n_threads(backend, n_threads);
+        }
+#ifdef GGML_USE_BLAS
+        if (ggml_backend_is_blas(backend)) {
+            ggml_backend_blas_set_n_threads(backend, n_threads);
+        }
+#endif
 #ifdef GGML_USE_METAL
-    if (ggml_backend_is_metal(backend)) {
-        ggml_backend_metal_set_n_cb(backend, n_threads);
-    }
+        if (ggml_backend_is_metal(backend)) {
+            ggml_backend_metal_set_n_cb(backend, n_threads);
+        }
 #endif
-    return ggml_backend_graph_compute(backend, graph) == GGML_STATUS_SUCCESS;
+    }
+
+    bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
+    ggml_backend_sched_reset(sched);
+    return t;
 }
 
 // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
@@ -490,33 +506,41 @@ struct whisper_pair {
     whisper_pair() : first(A()), second(B()) {}
 };
 
-// ggml_allocr wrapper for whisper usage
-struct whisper_allocr {
-    ggml_gallocr_t alloc = nullptr;
+// ggml_backend_sched wrapper for whisper usage
+struct whisper_sched {
+    ggml_backend_sched_t sched = nullptr;
 
     std::vector<uint8_t> meta;
 };
 
-static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
-    return allocr.meta.size() + ggml_gallocr_get_buffer_size(allocr.alloc, 0);
+static size_t whisper_sched_size(struct whisper_sched & allocr) {
+    size_t size = allocr.meta.size();
+    for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) {
+        ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i);
+        size += ggml_backend_sched_get_buffer_size(allocr.sched, backend);
+    }
+    return size;
 }
 
 // measure the memory usage of a graph and prepare the allocr's internal data buffer
-static bool whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
-    auto & alloc = allocr.alloc;
+static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<ggml_backend_t> backends, std::function<struct ggml_cgraph *()> && get_graph) {
+    auto & sched = allocr.sched;
     auto & meta = allocr.meta;
 
-    alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
+    sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);
 
     meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
 
     // since there are dependencies between the different graphs,
    // we need to allocate them instead of only reserving to get the correct compute buffer size
-    if (!ggml_gallocr_alloc_graph(alloc, get_graph())) {
+    if (!ggml_backend_sched_alloc_graph(sched, get_graph())) {
         // failed to allocate the compute buffer
         WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
         return false;
     }
+
+    ggml_backend_sched_reset(sched);
+
     return true;
 }
@@ -808,15 +832,13 @@ struct whisper_state {
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS];
 
-    ggml_backend_t backend = nullptr;
+    std::vector<ggml_backend_t> backends;
 
-    // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers
-    // - stores the actual tensor data into the `data` buffers
-    whisper_allocr alloc_conv;
-    whisper_allocr alloc_encode;
-    whisper_allocr alloc_cross;
-    whisper_allocr alloc_decode;
+    whisper_sched sched_conv;
+    whisper_sched sched_encode;
+    whisper_sched sched_cross;
+    whisper_sched sched_decode;
 
     // result of the encoder
     struct ggml_tensor * embd_conv = nullptr;
@@ -874,8 +896,6 @@ struct whisper_context {
 
     whisper_state * state = nullptr;
 
-    ggml_backend_t backend = nullptr;
-
     std::string path_model; // populated by whisper_init_from_file_with_params()
 };
 
@@ -1061,20 +1081,16 @@ static void whisper_kv_cache_seq_cp(
 }
 
 static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
-    if (!wctx.params.flash_attn) {
+    if (!wctx.params.flash_attn || !wctx.params.use_gpu) {
         return 1u;
     }
 
 #ifdef GGML_USE_METAL
-    if (ggml_backend_is_metal(wctx.backend)) {
-        return 32u;
-    }
+    return 32u;
 #endif
 
 #ifdef GGML_USE_CUDA
-    if (ggml_backend_is_cuda(wctx.backend)) {
-        return 256u;
-    }
+    return 256u;
 #endif
 
     return 1u;
@@ -1211,15 +1227,14 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
     return size;
 }
 
-static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
-    ggml_backend_t backend_gpu = NULL;
+static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
+    ggml_backend_t result = NULL;
 
-    // initialize the backends
 #ifdef GGML_USE_CUDA
     if (params.use_gpu) {
         WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
-        backend_gpu = ggml_backend_cuda_init(params.gpu_device);
-        if (!backend_gpu) {
+        result = ggml_backend_cuda_init(params.gpu_device);
+        if (!result) {
             WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
         }
     }
@@ -1229,13 +1244,13 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
     if (params.use_gpu) {
         WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
         ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
-        backend_gpu = ggml_backend_metal_init();
-        if (!backend_gpu) {
+        result = ggml_backend_metal_init();
+        if (!result) {
             WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
-        } else if (!ggml_backend_metal_supports_family(backend_gpu, 7)) {
+        } else if (!ggml_backend_metal_supports_family(result, 7)) {
             WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
-            ggml_backend_free(backend_gpu);
-            backend_gpu = NULL;
+            ggml_backend_free(result);
+            result = NULL;
         }
     }
 #endif
@@ -1243,20 +1258,64 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
 #ifdef GGML_USE_SYCL
     if (params.use_gpu) {
         WHISPER_LOG_INFO("%s: using SYCL backend\n", __func__);
-        backend_gpu = ggml_backend_sycl_init(params.gpu_device);
-        if (!backend_gpu) {
+        result = ggml_backend_sycl_init(params.gpu_device);
+        if (!result) {
             WHISPER_LOG_ERROR("%s: ggml_backend_sycl_init() failed\n", __func__);
         }
     }
 #endif
 
-    GGML_UNUSED(params);
+    return result;
+}
+
+static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
+    std::vector<ggml_backend_t> result;
+
+    ggml_backend_t backend_gpu = whisper_backend_init_gpu(params);
 
     if (backend_gpu) {
-        return backend_gpu;
+        result.push_back(backend_gpu);
+    }
+
+#ifdef GGML_USE_BLAS
+    {
+        WHISPER_LOG_INFO("%s: using BLAS backend\n", __func__);
+        ggml_backend_t backend_blas = ggml_backend_blas_init();
+        if (!backend_blas) {
+            WHISPER_LOG_ERROR("%s: ggml_backend_blas_init() failed\n", __func__);
+        } else {
+            result.push_back(backend_blas);
+        }
     }
+#endif
+
+    GGML_UNUSED(params);
+
+    result.push_back(ggml_backend_cpu_init());
+
+    return result;
+}
+
+static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
+    ggml_backend_buffer_type_t result = nullptr;
 
-    return ggml_backend_cpu_init();
+    params.use_gpu || (result = ggml_backend_cpu_buffer_type());
+
+#ifdef GGML_USE_CUDA
+    result || (result = ggml_backend_cuda_buffer_type(params.gpu_device));
+#endif
+
+#ifdef GGML_USE_METAL
+    result || (result = ggml_backend_metal_buffer_type());
+#endif
+
+#ifdef GGML_USE_SYCL
+    result || (result = ggml_backend_sycl_buffer_type(params.gpu_device));
+#endif
+
+    result || (result = ggml_backend_cpu_buffer_type());
+
+    return result;
 }
 
 // load the model from a ggml file
@@ -1683,21 +1742,15 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
-    wctx.backend = whisper_backend_init(wctx.params);
-    if (!wctx.backend) {
-        WHISPER_LOG_ERROR("%s: failed to initialize the backend\n", __func__);
-        return false;
-    }
-
     // allocate tensors in the backend buffers
-    model.buffer = ggml_backend_alloc_ctx_tensors(model.ctx, wctx.backend);
+    model.buffer = ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params));
     if (!model.buffer) {
         WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__);
         return false;
     }
 
     size_t size_main = ggml_backend_buffer_get_size(model.buffer);
-    WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1e6);
+    WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(model.buffer), size_main / 1e6);
 
     // load weights
     {
@@ -1792,6 +1845,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
+    ggml_backend_buffer_set_usage(model.buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+
     wctx.t_load_us = ggml_time_us() - t_start_us;
 
     return true;
@@ -1828,8 +1883,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
     const int n_mels = hparams.n_mels;
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.alloc_conv.meta.size(),
-        /*.mem_buffer =*/ wstate.alloc_conv.meta.data(),
+        /*.mem_size   =*/ wstate.sched_conv.meta.size(),
+        /*.mem_buffer =*/ wstate.sched_conv.meta.data(),
         /*.no_alloc   =*/ true,
     };
 
@@ -1837,9 +1892,13 @@ static struct ggml_cgraph * whisper_build_graph_conv(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
+    GGML_ASSERT(wstate.mel.tensor);
+
     ggml_tensor * mel_inp = wstate.mel.tensor;
+    ggml_set_input(mel_inp);
+
     ggml_tensor * mel;
-    if (mel_inp) {
+    {
         const int n_len = int(mel_inp->ne[0]);
         const int out_s = 2 * n_ctx;
         const int i0 = std::min(mel_offset, n_len);
@@ -1853,16 +1912,12 @@ static struct ggml_cgraph * whisper_build_graph_conv(
 
         if (mel_s < out_s) {
             mel = ggml_pad(ctx0, cur, out_s - mel_s, 0, 0, 0);
-        }
-        else {
+        } else {
             mel = ggml_cont(ctx0, cur);
         }
     }
-    else {
-        // just create some tensor so that the graph/buffer size estimation is correct
-        mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2 * n_ctx, n_mels);
-    }
-    ggml_set_name(mel, "mel"); // used with external encoding
+
+    ggml_set_name(mel, "mel");
 
     struct ggml_tensor * cur = nullptr;
 
@@ -1886,6 +1941,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
         ggml_build_forward_expand(gf, mel);
 
         cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+        ggml_set_input(cur); // the external encoder will write into this tensor
 
         ggml_set_name(cur, "embd_enc");
         wstate.embd_enc = cur;
@@ -1920,8 +1976,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     const int n_ctx_pad = GGML_PAD(n_ctx, 256);
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.alloc_encode.meta.size(),
-        /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
+        /*.mem_size   =*/ wstate.sched_encode.meta.size(),
+        /*.mem_buffer =*/ wstate.sched_encode.meta.data(),
         /*.no_alloc   =*/ true,
     };
 
@@ -2160,8 +2216,8 @@ static struct ggml_cgraph * whisper_build_graph_cross(
     const int n_ctx_pad = GGML_PAD(n_ctx, 256);
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.alloc_cross.meta.size(),
-        /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
+        /*.mem_size   =*/ wstate.sched_cross.meta.size(),
+        /*.mem_buffer =*/ wstate.sched_cross.meta.data(),
        /*.no_alloc   =*/ true,
     };
 
@@ -2242,16 +2298,16 @@ static bool whisper_encode_internal(
 
     // conv
     {
-        auto & alloc = wstate.alloc_conv.alloc;
+        auto & sched = wstate.sched_conv.sched;
 
         ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
 
-        if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
             // should never happen as we pre-allocate the memory
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
             return false;
         }
 
@@ -2269,32 +2325,32 @@
 
     // encoder
     if (!whisper_encode_external(wstate)) {
-        auto & alloc = wstate.alloc_encode.alloc;
+        auto & sched = wstate.sched_encode.sched;
 
         ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
 
-        if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
             // should never happen as we pre-allocate the memory
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
            return false;
        }
     }
 
     // cross
     {
-        auto & alloc = wstate.alloc_cross.alloc;
+        auto & sched = wstate.sched_cross.sched;
 
         ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
 
-        if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
             // should never happen as we pre-allocate the memory
             return false;
         }
 
-        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
             return false;
         }
     }
@@ -2336,8 +2392,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.alloc_decode.meta.size(),
-        /*.mem_buffer =*/ wstate.alloc_decode.meta.data(),
+        /*.mem_size   =*/ wstate.sched_decode.meta.size(),
+        /*.mem_buffer =*/ wstate.sched_decode.meta.data(),
         /*.no_alloc   =*/ true,
     };
 
@@ -2736,11 +2792,11 @@ static bool whisper_decode_internal(
 
     // decoder
     {
-        auto & alloc = wstate.alloc_decode.alloc;
+        auto & sched = wstate.sched_decode.sched;
 
         ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false);
 
-        if (!ggml_gallocr_alloc_graph(alloc, gf)) {
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
             // should never happen as we pre-allocate the memory
             return false;
         }
@@ -2795,7 +2851,7 @@ static bool whisper_decode_internal(
 
         logits = gf->nodes[gf->n_nodes - 1];
 
-        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+        if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
             return false;
         }
     }
@@ -3299,20 +3355,29 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
 struct whisper_state * whisper_init_state(whisper_context * ctx) {
     whisper_state * state = new whisper_state;
 
-    state->backend = whisper_backend_init(ctx->params);
-    if (!state->backend) {
+    state->backends = whisper_backend_init(ctx->params);
+    if (state->backends.empty()) {
         WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
         whisper_free_state(state);
         return nullptr;
     }
 
-    state->mel_calc = whisper_mel_calc_create(state->backend, ctx->model.filters);
+    state->mel_calc = whisper_mel_calc_create(state->backends[0], ctx->model.filters);
+
+    // init 60s of random mel data
+    {
+        const int n_len = 2*100*WHISPER_CHUNK_SIZE;
+        const int n_mel = ctx->model.filters.n_mel;
+
+        whisper_mel_free(state->mel);
+        whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel);
+    }
 
     // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
     // in theory, there can be a case where this is not enough, but in practice it should always be enough
     const int factor = 3;
 
-    if (!whisper_kv_cache_init(state->kv_self, state->backend, ctx->itype,
+    if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
             ctx->model.hparams.n_text_state,
             ctx->model.hparams.n_text_layer,
             GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
@@ -3326,7 +3391,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
-    if (!whisper_kv_cache_init(state->kv_cross, state->backend, ctx->itype,
+    if (!whisper_kv_cache_init(state->kv_cross, state->backends[0], ctx->itype,
             ctx->model.hparams.n_text_state,
             ctx->model.hparams.n_text_layer,
             GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
@@ -3340,7 +3405,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
-    if (!whisper_kv_cache_init(state->kv_pad, state->backend, ctx->itype,
+    if (!whisper_kv_cache_init(state->kv_pad, state->backends[0], ctx->itype,
             ctx->model.hparams.n_audio_state,
             1,
             GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
@@ -3356,7 +3421,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // [EXPERIMENTAL] Token-level timestamps with DTW
     if (ctx->params.dtw_token_timestamps) {
-        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backend)) {
+        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backends[0])) {
             WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
             whisper_free_state(state);
             return nullptr;
@@ -3399,7 +3464,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     // conv allocator
     {
-        bool ok = whisper_allocr_graph_init(state->alloc_conv, state->backend,
+        bool ok = whisper_sched_graph_init(state->sched_conv, state->backends,
                 [&]() {
                     return whisper_build_graph_conv(*ctx, *state, 0);
                 });
@@ -3410,12 +3475,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             return nullptr;
         }
 
-        WHISPER_LOG_INFO("%s: compute buffer (conv)   = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1e6);
+        WHISPER_LOG_INFO("%s: compute buffer (conv)   = %7.2f MB\n", __func__, whisper_sched_size(state->sched_conv) / 1e6);
     }
 
     // encoder allocator
     if (!whisper_encode_external(*state)) {
-        bool ok = whisper_allocr_graph_init(state->alloc_encode, state->backend,
+        bool ok = whisper_sched_graph_init(state->sched_encode, state->backends,
                 [&]() {
                     return whisper_build_graph_encoder(*ctx, *state);
                 });
@@ -3426,12 +3491,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             return nullptr;
         }
 
-        WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1e6);
+        WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_encode) / 1e6);
     }
 
     // cross allocator
     {
-        bool ok = whisper_allocr_graph_init(state->alloc_cross, state->backend,
+        bool ok = whisper_sched_graph_init(state->sched_cross, state->backends,
                 [&]() {
                     return whisper_build_graph_cross(*ctx, *state);
                 });
@@ -3442,12 +3507,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             return nullptr;
         }
 
-        WHISPER_LOG_INFO("%s: compute buffer (cross)  = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1e6);
+        WHISPER_LOG_INFO("%s: compute buffer (cross)  = %7.2f MB\n", __func__, whisper_sched_size(state->sched_cross) / 1e6);
     }
 
     // decoder allocator
     {
-        bool ok = whisper_allocr_graph_init(state->alloc_decode, state->backend,
+        bool ok = whisper_sched_graph_init(state->sched_decode, state->backends,
                 [&]() {
                     const auto & hparams = ctx->model.hparams;
 
@@ -3466,7 +3531,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             return nullptr;
         }
 
-        WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1e6);
+        WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_decode) / 1e6);
     }
 
     return state;
@@ -3746,12 +3811,14 @@ void whisper_free_state(struct whisper_state * state) {
 
     whisper_batch_free(state->batch);
 
-    ggml_gallocr_free(state->alloc_conv.alloc);
-    ggml_gallocr_free(state->alloc_encode.alloc);
-    ggml_gallocr_free(state->alloc_cross.alloc);
-    ggml_gallocr_free(state->alloc_decode.alloc);
+    ggml_backend_sched_free(state->sched_conv.sched);
+    ggml_backend_sched_free(state->sched_encode.sched);
+    ggml_backend_sched_free(state->sched_cross.sched);
+    ggml_backend_sched_free(state->sched_decode.sched);
 
-    ggml_backend_free(state->backend);
+    for (auto & backend : state->backends) {
+        ggml_backend_free(backend);
+    }
 
     // [EXPERIMENTAL] Token-level timestamps with DTW
     aheads_masks_free(state->aheads_masks);
@@ -3768,8 +3835,6 @@ void whisper_free(struct whisper_context * ctx) {
 
         whisper_free_state(ctx->state);
 
-        ggml_backend_free(ctx->backend);
-
         delete ctx;
     }
 }
@@ -3800,7 +3865,7 @@ int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_s
         // 2. the time to transcribe audios this long will be dominated by the decoding time, so the mel calculation
         //    taking longer is not a major concern
         if (!state->mel_calc_fallback) {
-            state->mel_calc_fallback = new mel_calc_cpu(state->backend, ctx->model.filters);
+            state->mel_calc_fallback = new mel_calc_cpu(state->backends[0], ctx->model.filters);
        }
        state->mel = state->mel_calc_fallback->calculate({samples, n_samples}, n_threads);
    }
@@ -3837,7 +3902,7 @@ int whisper_set_mel_with_state(
     }
 
     whisper_mel_free(state->mel);
-    whisper_mel_init(state->mel, ctx->backend, n_len, n_len, n_mel);
+    whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel);
 
     ggml_backend_tensor_set(state->mel.tensor, data, 0, ggml_nbytes(state->mel.tensor));
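
A note on the new whisper_default_buffer_type() helper above: the `result || (result = ...)` lines rely on short-circuit evaluation to keep the first buffer type that becomes available. An equivalent, purely illustrative spelling with explicit branches (Metal and SYCL handled the same way as CUDA):

    // illustrative equivalent of the short-circuit chain, not part of the diff
    ggml_backend_buffer_type_t result = nullptr;
    if (!params.use_gpu) {
        result = ggml_backend_cpu_buffer_type();
    }
#ifdef GGML_USE_CUDA
    if (result == nullptr) {
        result = ggml_backend_cuda_buffer_type(params.gpu_device);
    }
#endif
    if (result == nullptr) {
        result = ggml_backend_cpu_buffer_type(); // final CPU fallback
    }
    return result;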