mkiol commited on
Commit
08ba486
·
unverified ·
1 Parent(s): 6c20dfb

whisper : add abort callback (#1335)

Browse files
Files changed (2) hide show
  1. whisper.cpp +31 -19
  2. whisper.h +9 -0
whisper.cpp CHANGED
@@ -125,9 +125,17 @@ static void byteswap_tensor(ggml_tensor * tensor) {
125
  // ggml helpers
126
  //
127
 
128
- static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
 
 
 
 
 
129
  struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
130
 
 
 
 
131
  if (plan.work_size > 0) {
132
  buf.resize(plan.work_size);
133
  plan.work_data = buf.data();
@@ -1922,7 +1930,9 @@ static bool whisper_encode_internal(
1922
  whisper_context & wctx,
1923
  whisper_state & wstate,
1924
  const int mel_offset,
1925
- const int n_threads) {
 
 
1926
  const int64_t t_start_us = ggml_time_us();
1927
 
1928
  // conv
@@ -1936,7 +1946,7 @@ static bool whisper_encode_internal(
1936
  ggml_allocr_alloc_graph(alloc, gf);
1937
 
1938
  if (!whisper_encode_external(wstate)) {
1939
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
1940
  }
1941
  }
1942
 
@@ -1955,10 +1965,10 @@ static bool whisper_encode_internal(
1955
  ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
1956
  ggml_metal_graph_compute(wstate.ctx_metal, gf);
1957
  } else {
1958
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
1959
  }
1960
  #else
1961
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
1962
  #endif
1963
  }
1964
 
@@ -1977,10 +1987,10 @@ static bool whisper_encode_internal(
1977
  ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
1978
  ggml_metal_graph_compute(wstate.ctx_metal, gf);
1979
  } else {
1980
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
1981
  }
1982
  #else
1983
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
1984
  #endif
1985
  }
1986
 
@@ -2346,7 +2356,9 @@ static bool whisper_decode_internal(
2346
  const whisper_token * tokens,
2347
  const int n_tokens,
2348
  const int n_past,
2349
- const int n_threads) {
 
 
2350
  const int64_t t_start_us = ggml_time_us();
2351
 
2352
  const auto & model = wctx.model;
@@ -2375,10 +2387,10 @@ static bool whisper_decode_internal(
2375
  ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
2376
  ggml_metal_graph_compute(wstate.ctx_metal, gf);
2377
  } else {
2378
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
2379
  }
2380
  #else
2381
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
2382
  #endif
2383
  }
2384
 
@@ -3290,7 +3302,7 @@ int whisper_set_mel(
3290
  }
3291
 
3292
  int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
3293
- if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
3294
  log("%s: failed to eval\n", __func__);
3295
  return -1;
3296
  }
@@ -3299,7 +3311,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
3299
  }
3300
 
3301
  int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
3302
- if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
3303
  log("%s: failed to eval\n", __func__);
3304
  return -1;
3305
  }
@@ -3310,7 +3322,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
3310
  int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
3311
  const int selected_decoder_id = 0;
3312
 
3313
- if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
3314
  log("%s: failed to eval\n", __func__);
3315
  return 1;
3316
  }
@@ -3327,7 +3339,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
3327
  return false;
3328
  }
3329
 
3330
- if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
3331
  log("%s: failed to eval\n", __func__);
3332
  return 1;
3333
  }
@@ -4594,7 +4606,7 @@ int whisper_full_with_state(
4594
  }
4595
 
4596
  // encode audio features starting at offset seek
4597
- if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
4598
  log("%s: failed to encode\n", __func__);
4599
  return -6;
4600
  }
@@ -4677,7 +4689,7 @@ int whisper_full_with_state(
4677
  }
4678
  WHISPER_PRINT_DEBUG("\n\n");
4679
 
4680
- if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
4681
  log("%s: failed to decode\n", __func__);
4682
  return -7;
4683
  }
@@ -4901,7 +4913,7 @@ int whisper_full_with_state(
4901
 
4902
  //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
4903
 
4904
- if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
4905
  log("%s: failed to decode\n", __func__);
4906
  return -8;
4907
  }
@@ -5473,12 +5485,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
5473
  double tsum = 0.0;
5474
 
5475
  // heat-up
5476
- ggml_graph_compute_helper(work, &gf, n_threads);
5477
 
5478
  for (int i = 0; i < n_max; ++i) {
5479
  const int64_t t0 = ggml_time_us();
5480
 
5481
- ggml_graph_compute_helper(work, &gf, n_threads);
5482
 
5483
  const int64_t t1 = ggml_time_us();
5484
 
 
125
  // ggml helpers
126
  //
127
 
128
+ static void ggml_graph_compute_helper(
129
+ std::vector<uint8_t> & buf,
130
+ ggml_cgraph * graph,
131
+ int n_threads,
132
+ whisper_abort_callback abort_callback,
133
+ void * abort_callback_data) {
134
  struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
135
 
136
+ plan.abort_callback = abort_callback;
137
+ plan.abort_callback_data = abort_callback_data;
138
+
139
  if (plan.work_size > 0) {
140
  buf.resize(plan.work_size);
141
  plan.work_data = buf.data();
 
1930
  whisper_context & wctx,
1931
  whisper_state & wstate,
1932
  const int mel_offset,
1933
+ const int n_threads,
1934
+ whisper_abort_callback abort_callback,
1935
+ void * abort_callback_data) {
1936
  const int64_t t_start_us = ggml_time_us();
1937
 
1938
  // conv
 
1946
  ggml_allocr_alloc_graph(alloc, gf);
1947
 
1948
  if (!whisper_encode_external(wstate)) {
1949
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1950
  }
1951
  }
1952
 
 
1965
  ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
1966
  ggml_metal_graph_compute(wstate.ctx_metal, gf);
1967
  } else {
1968
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1969
  }
1970
  #else
1971
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1972
  #endif
1973
  }
1974
 
 
1987
  ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
1988
  ggml_metal_graph_compute(wstate.ctx_metal, gf);
1989
  } else {
1990
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1991
  }
1992
  #else
1993
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1994
  #endif
1995
  }
1996
 
 
2356
  const whisper_token * tokens,
2357
  const int n_tokens,
2358
  const int n_past,
2359
+ const int n_threads,
2360
+ whisper_abort_callback abort_callback,
2361
+ void * abort_callback_data) {
2362
  const int64_t t_start_us = ggml_time_us();
2363
 
2364
  const auto & model = wctx.model;
 
2387
  ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
2388
  ggml_metal_graph_compute(wstate.ctx_metal, gf);
2389
  } else {
2390
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2391
  }
2392
  #else
2393
+ ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2394
  #endif
2395
  }
2396
 
 
3302
  }
3303
 
3304
  int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
3305
+ if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
3306
  log("%s: failed to eval\n", __func__);
3307
  return -1;
3308
  }
 
3311
  }
3312
 
3313
  int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
3314
+ if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
3315
  log("%s: failed to eval\n", __func__);
3316
  return -1;
3317
  }
 
3322
  int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
3323
  const int selected_decoder_id = 0;
3324
 
3325
+ if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
3326
  log("%s: failed to eval\n", __func__);
3327
  return 1;
3328
  }
 
3339
  return false;
3340
  }
3341
 
3342
+ if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
3343
  log("%s: failed to eval\n", __func__);
3344
  return 1;
3345
  }
 
4606
  }
4607
 
4608
  // encode audio features starting at offset seek
4609
+ if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
4610
  log("%s: failed to encode\n", __func__);
4611
  return -6;
4612
  }
 
4689
  }
4690
  WHISPER_PRINT_DEBUG("\n\n");
4691
 
4692
+ if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
4693
  log("%s: failed to decode\n", __func__);
4694
  return -7;
4695
  }
 
4913
 
4914
  //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
4915
 
4916
+ if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
4917
  log("%s: failed to decode\n", __func__);
4918
  return -8;
4919
  }
 
5485
  double tsum = 0.0;
5486
 
5487
  // heat-up
5488
+ ggml_graph_compute_helper(work, &gf, n_threads, nullptr , nullptr);
5489
 
5490
  for (int i = 0; i < n_max; ++i) {
5491
  const int64_t t0 = ggml_time_us();
5492
 
5493
+ ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr);
5494
 
5495
  const int64_t t1 = ggml_time_us();
5496
 
whisper.h CHANGED
@@ -334,6 +334,11 @@ extern "C" {
334
  // If it returns false, the computation is aborted
335
  typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
336
 
 
 
 
 
 
337
  // Logits filter callback
338
  // Can be used to modify the logits before sampling
339
  // If not NULL, called after applying temperature to logits
@@ -428,6 +433,10 @@ extern "C" {
428
  whisper_encoder_begin_callback encoder_begin_callback;
429
  void * encoder_begin_callback_user_data;
430
 
 
 
 
 
431
  // called by each decoder to filter obtained logits
432
  whisper_logits_filter_callback logits_filter_callback;
433
  void * logits_filter_callback_user_data;
 
334
  // If it returns false, the computation is aborted
335
  typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
336
 
337
+ // Abort callback
338
+ // If not NULL, called before ggml computation
339
+ // If it returns true, the computation is aborted
340
+ typedef bool (*whisper_abort_callback)(void * user_data);
341
+
342
  // Logits filter callback
343
  // Can be used to modify the logits before sampling
344
  // If not NULL, called after applying temperature to logits
 
433
  whisper_encoder_begin_callback encoder_begin_callback;
434
  void * encoder_begin_callback_user_data;
435
 
436
+ // called each time before ggml computation starts
437
+ whisper_abort_callback abort_callback;
438
+ void * abort_callback_user_data;
439
+
440
  // called by each decoder to filter obtained logits
441
  whisper_logits_filter_callback logits_filter_callback;
442
  void * logits_filter_callback_user_data;