ggerganov committed on
Commit
afe3785
·
unverified ·
1 Parent(s): d678325

whisper : fix excessive memory usage (#2443)

Browse files

* whisper : fix KV cache allocation

* whisper : reduce memory overhead from unused input tensors

Files changed (1) hide show
  1. src/whisper.cpp +39 -32
src/whisper.cpp CHANGED
@@ -163,7 +163,6 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
163
  } \
164
  } while (0)
165
 
166
- //#define WHISPER_USE_FLASH_FF
167
  #define WHISPER_MAX_DECODERS 8
168
  #define WHISPER_MAX_NODES 4096
169
 
@@ -817,6 +816,9 @@ struct whisper_state {
817
  int32_t n_fail_p = 0; // number of logprob threshold failures
818
  int32_t n_fail_h = 0; // number of entropy threshold failures
819
 
 
 
 
820
  // unified self-attention KV cache for all decoders
821
  whisper_kv_cache kv_self;
822
 
@@ -2096,9 +2098,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
2096
 
2097
  struct ggml_tensor * Q =
2098
  ggml_permute(ctx0,
2099
- ggml_cpy(ctx0,
2100
- Qcur,
2101
- ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)),
2102
  0, 2, 1, 3);
2103
 
2104
  if (wctx.params.flash_attn) {
@@ -2125,9 +2125,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
2125
  } else {
2126
  struct ggml_tensor * K =
2127
  ggml_permute(ctx0,
2128
- ggml_cpy(ctx0,
2129
- Kcur,
2130
- ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)),
2131
  0, 2, 1, 3);
2132
 
2133
  // K * Q
@@ -2136,22 +2136,19 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
2136
  struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
2137
 
2138
  struct ggml_tensor * V =
2139
- ggml_cpy(ctx0,
2140
  ggml_permute(ctx0,
2141
  ggml_reshape_3d(ctx0,
2142
  Vcur,
2143
  n_state_head, n_head, n_ctx),
2144
  1, 2, 0, 3),
2145
- ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head)
2146
- );
2147
 
2148
  struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
2149
 
2150
  struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2151
 
2152
- cur = ggml_cpy(ctx0,
2153
- KQV_merged,
2154
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
2155
  }
2156
  }
2157
 
@@ -2181,11 +2178,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
2181
  layer.mlp_ln_b);
2182
  }
2183
 
2184
- #ifdef WHISPER_USE_FLASH_FF
2185
- cur = ggml_flash_ff(ctx0,
2186
- ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
2187
- layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
2188
- #else
2189
  // fully connected
2190
  cur = ggml_mul_mat(ctx0,
2191
  layer.mlp_0_w,
@@ -2202,7 +2194,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
2202
  cur);
2203
 
2204
  cur = ggml_add(ctx0, cur, layer.mlp_1_b);
2205
- #endif
2206
  }
2207
 
2208
  inpL = ggml_add(ctx0, cur, inpFF);
@@ -2578,9 +2569,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2578
 
2579
  struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2580
 
2581
- cur = ggml_cpy(ctx0,
2582
- KQV_merged,
2583
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
2584
  }
2585
  }
2586
 
@@ -2687,9 +2676,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2687
 
2688
  struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2689
 
2690
- cur = ggml_cpy(ctx0,
2691
- KQV_merged,
2692
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
2693
  }
2694
  }
2695
 
@@ -3403,14 +3390,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
3403
  whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel);
3404
  }
3405
 
3406
- // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
3407
- // in theory, there can be a case where this is not enough, but in practice it should always be enough
3408
- const int factor = 3;
3409
-
3410
  if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
3411
  ctx->model.hparams.n_text_state,
3412
  ctx->model.hparams.n_text_layer,
3413
- GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
3414
  WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
3415
  whisper_free_state(state);
3416
  return nullptr;
@@ -5775,13 +5761,34 @@ int whisper_full_with_state(
5775
  }
5776
  WHISPER_LOG_DEBUG("\n\n");
5777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5778
  whisper_kv_cache_clear(state->kv_self);
5779
 
5780
  whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
5781
 
5782
  if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
5783
  WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
5784
- return -7;
5785
  }
5786
 
5787
  {
@@ -6081,7 +6088,7 @@ int whisper_full_with_state(
6081
 
6082
  if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
6083
  WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
6084
- return -8;
6085
  }
6086
 
6087
  const int64_t t_start_sample_us = ggml_time_us();
 
163
  } \
164
  } while (0)
165
 
 
166
  #define WHISPER_MAX_DECODERS 8
167
  #define WHISPER_MAX_NODES 4096
168
 
 
816
  int32_t n_fail_p = 0; // number of logprob threshold failures
817
  int32_t n_fail_h = 0; // number of entropy threshold failures
818
 
819
+ // number of decoders for which we have constructed the KV cache
820
+ int32_t kv_self_n_dec = 0;
821
+
822
  // unified self-attention KV cache for all decoders
823
  whisper_kv_cache kv_self;
824
 
 
2098
 
2099
  struct ggml_tensor * Q =
2100
  ggml_permute(ctx0,
2101
+ ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
 
 
2102
  0, 2, 1, 3);
2103
 
2104
  if (wctx.params.flash_attn) {
 
2125
  } else {
2126
  struct ggml_tensor * K =
2127
  ggml_permute(ctx0,
2128
+ ggml_cast(ctx0,
2129
+ ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
2130
+ wctx.itype),
2131
  0, 2, 1, 3);
2132
 
2133
  // K * Q
 
2136
  struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
2137
 
2138
  struct ggml_tensor * V =
2139
+ ggml_cast(ctx0,
2140
  ggml_permute(ctx0,
2141
  ggml_reshape_3d(ctx0,
2142
  Vcur,
2143
  n_state_head, n_head, n_ctx),
2144
  1, 2, 0, 3),
2145
+ wctx.itype);
 
2146
 
2147
  struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
2148
 
2149
  struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2150
 
2151
+ cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
 
 
2152
  }
2153
  }
2154
 
 
2178
  layer.mlp_ln_b);
2179
  }
2180
 
 
 
 
 
 
2181
  // fully connected
2182
  cur = ggml_mul_mat(ctx0,
2183
  layer.mlp_0_w,
 
2194
  cur);
2195
 
2196
  cur = ggml_add(ctx0, cur, layer.mlp_1_b);
 
2197
  }
2198
 
2199
  inpL = ggml_add(ctx0, cur, inpFF);
 
2569
 
2570
  struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2571
 
2572
+ cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
 
 
2573
  }
2574
  }
2575
 
 
2676
 
2677
  struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2678
 
2679
+ cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
 
 
2680
  }
2681
  }
2682
 
 
3390
  whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel);
3391
  }
3392
 
3393
+ // at this point, we don't know yet how many decoders will be used
3394
+ // later during decoding, if more decoders are used, we will recreate the KV cache respectively
3395
+ state->kv_self_n_dec = 1;
 
3396
  if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
3397
  ctx->model.hparams.n_text_state,
3398
  ctx->model.hparams.n_text_layer,
3399
+ GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
3400
  WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
3401
  whisper_free_state(state);
3402
  return nullptr;
 
5761
  }
5762
  WHISPER_LOG_DEBUG("\n\n");
5763
 
5764
+ // recreate the KV cache if the number of decoders has changed
5765
+ if (state->kv_self_n_dec < n_decoders_cur) {
5766
+ WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
5767
+
5768
+ whisper_kv_cache_free(state->kv_self);
5769
+
5770
+ // overallocate to workaround KV cache fragmentation issues
5771
+ const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
5772
+
5773
+ if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
5774
+ ctx->model.hparams.n_text_state,
5775
+ ctx->model.hparams.n_text_layer,
5776
+ GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
5777
+ WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
5778
+ whisper_free_state(state);
5779
+ return -7;
5780
+ }
5781
+
5782
+ state->kv_self_n_dec = n_decoders_cur;
5783
+ }
5784
+
5785
  whisper_kv_cache_clear(state->kv_self);
5786
 
5787
  whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
5788
 
5789
  if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
5790
  WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
5791
+ return -8;
5792
  }
5793
 
5794
  {
 
6088
 
6089
  if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
6090
  WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
6091
+ return -9;
6092
  }
6093
 
6094
  const int64_t t_start_sample_us = ggml_time_us();