whisper : fix excessive memory usage (#2443)

* whisper : fix KV cache allocation
* whisper : reduce memory overhead from unused input tensors

src/whisper.cpp CHANGED (+39 -32)
@@ -163,7 +163,6 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
         } \
     } while (0)
 
-//#define WHISPER_USE_FLASH_FF
 #define WHISPER_MAX_DECODERS 8
 #define WHISPER_MAX_NODES 4096
 
@@ -817,6 +816,9 @@ struct whisper_state {
     int32_t n_fail_p = 0; // number of logprob threshold failures
     int32_t n_fail_h = 0; // number of entropy threshold failures
 
+    // number of decoders for which we have constructed the KV cache
+    int32_t kv_self_n_dec = 0;
+
     // unified self-attention KV cache for all decoders
     whisper_kv_cache kv_self;
 
@@ -2096,9 +2098,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
                 struct ggml_tensor * Q =
                     ggml_permute(ctx0,
-                            ggml_cpy(ctx0,
-                                Qcur,
-                                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)),
+                            ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
                             0, 2, 1, 3);
 
                 if (wctx.params.flash_attn) {
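Note: the removed pattern copied Qcur into a freshly created ggml_new_tensor_3d; such pre-created destinations are leafs that the graph allocator treats as inputs and cannot reuse, which is the "unused input tensors" overhead the commit message refers to. Below is a minimal standalone sketch of the difference, not code from this PR; the tiny-model dimensions (384 states, 6 heads, 64 per head, 1500 audio positions) are illustrative.

    // sketch.cpp — why ggml_reshape_3d avoids an allocation (hypothetical demo)
    #include "ggml.h"
    #include <cstdio>

    int main() {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 32*1024*1024,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // stand-in for Qcur: [n_state, n_ctx] = [384, 1500]
        struct ggml_tensor * qcur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 384, 1500);

        // old pattern: create a second tensor and copy into it
        struct ggml_tensor * dst   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 6, 1500);
        struct ggml_tensor * q_old = ggml_cpy(ctx, qcur, dst);

        // new pattern: a zero-copy view over the same data — no destination buffer
        struct ggml_tensor * q_new = ggml_reshape_3d(ctx, qcur, 64, 6, 1500);

        printf("copy destination costs %zu extra bytes; the view costs none\n",
               ggml_nbytes(dst));
        (void) q_old; (void) q_new;

        ggml_free(ctx);
        return 0;
    }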
@@ -2125,9 +2125,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                 } else {
                     struct ggml_tensor * K =
                         ggml_permute(ctx0,
-                                ggml_cpy(ctx0,
-                                    Kcur,
-                                    ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)),
+                                ggml_cast(ctx0,
+                                    ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
+                                    wctx.itype),
                                 0, 2, 1, 3);
 
                     // K * Q
@@ -2136,22 +2136,19 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                     struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
 
                     struct ggml_tensor * V =
-                        ggml_cpy(ctx0,
+                        ggml_cast(ctx0,
                             ggml_permute(ctx0,
                                 ggml_reshape_3d(ctx0,
                                     Vcur,
                                     n_state_head, n_head, n_ctx),
                                 1, 2, 0, 3),
-                            ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head)
-                        );
+                            wctx.itype);
 
                     struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
                     struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-                    cur = ggml_cpy(ctx0,
-                        KQV_merged,
-                        ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
+                    cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
                 }
             }
 
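The same idea carries through K, V, and the merged attention output: ggml_cast performs the type conversion as a regular graph op, and ggml_cont_2d makes a permuted view contiguous in one fused op, so in both cases the allocator places the destination instead of pinning a dedicated tensor. A small self-contained example of the ggml_cont_2d idiom follows; the 4x8 size is hypothetical, not Whisper's.

    // cont2d.cpp — the ggml_cont_2d idiom used for KQV_merged (hypothetical demo)
    #include "ggml.h"

    int main() {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 4*1024*1024,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * x  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 8);
        struct ggml_tensor * xt = ggml_permute(ctx, x, 1, 0, 2, 3); // transposed view
        struct ggml_tensor * y  = ggml_cont_2d(ctx, xt, 8, 4);      // one fused contiguous copy

        // build and run the graph; previously this step needed an explicit
        // ggml_new_tensor_2d destination plus a ggml_cpy node
        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, y);
        ggml_graph_compute_with_ctx(ctx, gf, 1);

        ggml_free(ctx);
        return 0;
    }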
@@ -2181,11 +2178,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                             layer.mlp_ln_b);
                 }
 
-#ifdef WHISPER_USE_FLASH_FF
-                cur = ggml_flash_ff(ctx0,
-                        ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
-                        layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
-#else
                 // fully connected
                 cur = ggml_mul_mat(ctx0,
                         layer.mlp_0_w,
@@ -2202,7 +2194,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                         cur);
 
                 cur = ggml_add(ctx0, cur, layer.mlp_1_b);
-#endif
             }
 
             inpL = ggml_add(ctx0, cur, inpFF);
@@ -2578,9 +2569,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
                 struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-                cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+                cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
             }
         }
 
@@ -2687,9 +2676,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
                 struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-                cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
+                cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
             }
         }
 
@@ -3403,14 +3390,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel);
     }
 
-    // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
-    // in theory, there can be a case where this is not enough, but in practice it should always be enough
-    const int factor = 3;
-
+    // at this point, we don't know yet how many decoders will be used
+    // later during decoding, if more decoders are used, we will recreate the KV cache respectively
+    state->kv_self_n_dec = 1;
     if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
                 ctx->model.hparams.n_text_state,
                 ctx->model.hparams.n_text_layer,
-                GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
+                GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
         WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
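For reference, GGML_PAD(x, n) rounds x up to a multiple of n, and Whisper's text context is 448 tokens for all model sizes, so the initial single-decoder cache now covers 512 positions per layer instead of the previous unconditional 3x overallocation (1536). A quick arithmetic check; the macro body below mirrors the one in ggml.h.

    // pad.cpp — initial KV cache sizing after this change (illustrative)
    #include <cstdio>

    // round x up to a multiple of n (mirrors GGML_PAD in ggml.h)
    #define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

    int main() {
        const int n_text_ctx = 448; // Whisper decoder context, all model sizes

        printf("before: %d cells (always 3x)\n", GGML_PAD(n_text_ctx, 256)*3); // 1536
        printf("after : %d cells (1 decoder)\n", GGML_PAD(n_text_ctx, 256));   // 512
        return 0;
    }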
@@ -5775,13 +5761,34 @@ int whisper_full_with_state(
         }
         WHISPER_LOG_DEBUG("\n\n");
 
+        // recreate the KV cache if the number of decoders has changed
+        if (state->kv_self_n_dec < n_decoders_cur) {
+            WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
+
+            whisper_kv_cache_free(state->kv_self);
+
+            // overallocate to workaround KV cache fragmentation issues
+            const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
+
+            if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
+                        ctx->model.hparams.n_text_state,
+                        ctx->model.hparams.n_text_layer,
+                        GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
+                WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
+                whisper_free_state(state);
+                return -7;
+            }
+
+            state->kv_self_n_dec = n_decoders_cur;
+        }
+
         whisper_kv_cache_clear(state->kv_self);
 
         whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
 
         if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
             WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
-            return -7;
+            return -8;
         }
 
         {
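This block only reallocates when a decoding pass requests more decoders than the cache was ever built for, e.g. when temperature fallback switches from greedy decoding to beam search, and kv_self_n_dec only grows, so the reallocation happens at most a handful of times per state. A toy standalone sketch of the policy; toy_state and ensure_kv_cache are hypothetical names, and the cache is reduced to a cell count for illustration.

    // grow.cpp — toy model of the grow-on-demand policy (hypothetical names)
    #include <cstdio>

    #define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

    struct toy_state {
        int kv_self_n_dec = 1;                  // decoders the cache was built for
        int kv_size       = GGML_PAD(448, 256); // cells in the current cache
    };

    // reallocate only when more decoders are requested than ever before
    void ensure_kv_cache(toy_state & s, int n_decoders_cur) {
        if (s.kv_self_n_dec < n_decoders_cur) {
            // overallocate to work around KV cache fragmentation issues
            const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;

            s.kv_size       = GGML_PAD(448, 256)*factor;
            s.kv_self_n_dec = n_decoders_cur;

            printf("recreated: %d decoders -> %d cells\n", n_decoders_cur, s.kv_size);
        }
    }

    int main() {
        toy_state s;
        ensure_kv_cache(s, 1); // greedy: cache kept at 512 cells
        ensure_kv_cache(s, 5); // beam size 5: factor 7 -> 3584 cells
        ensure_kv_cache(s, 5); // no-op: cache already large enough
        return 0;
    }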
@@ -6081,7 +6088,7 @@ int whisper_full_with_state(
 
             if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
                 WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
-                return -8;
+                return -9;
             }
 
             const int64_t t_start_sample_us = ggml_time_us();