Try to improve the token sampling strategy (#193)
* whisper : try to improve the token sampling strategy
  - Add the "max_initial_timestamp" token logic from OpenAI
  - Disallow sampling timestamps that are in the past
* whisper : fix the max initial timestamp logic + fallback decoding
- whisper.cpp +45 -52
- whisper.h +1 -1
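The core of the first commit is OpenAI's max_initial_timestamp rule: before sampling the very first token of a 30-second window, every timestamp token above a cutoff index is masked out, so a segment cannot begin deep inside the window. A minimal standalone sketch of that masking, assuming whisper.cpp's vocabulary layout (ids >= token_beg are timestamp tokens in 20 ms steps); the function name and the plain vector are illustrative, while the cutoff of 100 comes from the diff below:

#include <cmath>  // INFINITY
#include <vector>

// Mask all timestamp tokens above token_beg + max_initial_ts by forcing their
// probability to -INFINITY, mirroring the is_initial branch added to
// whisper_sample_best() below. Token id token_beg + k encodes timestamp
// k * 0.02 s, so max_initial_ts = 100 caps the initial timestamp at 2.0 s.
static void mask_initial_timestamps(std::vector<double> & probs,
                                    int token_beg,
                                    int max_initial_ts /* 100 in the diff */) {
    for (size_t i = token_beg + max_initial_ts + 1; i < probs.size(); ++i) {
        probs[i] = -INFINITY;
    }
}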
whisper.cpp
CHANGED
@@ -1846,7 +1846,9 @@ static bool whisper_decode(
 // the most basic sampling scheme - select the top token
 static whisper_token_data whisper_sample_best(
         const whisper_vocab & vocab,
-        const float * probs) {
+        const float * probs,
+        bool force_timestamp,
+        bool is_initial) {
     whisper_token_data result = {
         0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
     };

@@ -1869,7 +1871,18 @@ static whisper_token_data whisper_sample_best(
         max_tx = std::max(max_tx, probs_id[i].first);
     }

-    for (int i = vocab.token_beg; i < n_logits; i++) {
+    const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg;
+    const auto i1 = is_initial ? vocab.token_beg + 101 : n_logits;
+
+    // the initial timestamp cannot be larger than 100
+    // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
+    if (is_initial) {
+        for (int i = i0; i < n_logits; ++ i) {
+            probs_id[i].first = -INFINITY;
+        }
+    }
+
+    for (int i = vocab.token_beg; i < i1; i++) {
         sum_ts += probs_id[i].first;
         if (probs_id[i].first > max_ts) {
             max_ts = probs_id[i].first;

@@ -1879,7 +1892,7 @@ static whisper_token_data whisper_sample_best(

     // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
     // timestamp token
-    if (sum_ts > max_tx) {
+    if (sum_ts > max_tx || force_timestamp) {
         // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
         for (int i = 0; i < vocab.token_beg; i++) {
             probs_id[i].first = -INFINITY;

@@ -1921,39 +1934,6 @@ static whisper_token_data whisper_sample_best(
     return result;
 }

-// samples only from the timestamps tokens
-static whisper_vocab::id whisper_sample_timestamp(
-        const whisper_vocab & vocab,
-        const float * probs) {
-    int n_logits = vocab.id_to_token.size();
-
-    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
-    probs_id.reserve(n_logits);
-
-    for (int i = vocab.token_beg + 1; i < n_logits; i++) {
-        probs_id.push_back(std::make_pair(probs[i], i));
-    }
-
-    const int top_k = 10;
-
-    // find the top K tokens
-    std::partial_sort(
-            probs_id.begin(),
-            probs_id.begin() + top_k, probs_id.end(),
-            [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
-        return a.first > b.first;
-    });
-
-    probs_id.resize(top_k);
-
-    //printf("\n");
-    //for (int i = 0; i < (int) probs_id.size(); i++) {
-    //    printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
-    //}
-
-    return probs_id[0].second;
-}
-
 // 500 -> 00:05.000
 // 6000 -> 01:00.000
 static std::string to_timestamp(int64_t t, bool comma = false) {

@@ -2284,19 +2264,17 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
 struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
     const int64_t t_start_sample_us = ggml_time_us();

-    // TODO: simplify
-    auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
+    const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false);

     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;

     return res;
 }

-whisper_token whisper_sample_timestamp(struct whisper_context * ctx) {
+struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
     const int64_t t_start_sample_us = ggml_time_us();

-    // TODO: simplify
-    auto res = whisper_sample_timestamp(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
+    const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);

     ctx->t_sample_us += ggml_time_us() - t_start_sample_us;

@@ -2694,7 +2672,6 @@ int whisper_full(

     prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());

-    bool done = false;
     int seek_delta = 100*WHISPER_CHUNK_SIZE;

     // print the prompt

@@ -2708,7 +2685,9 @@ int whisper_full(
         int result_len = 0;
         tokens_cur.clear();

-        for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) {
+        bool failed = false;
+
+        for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
             if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
                 fprintf(stderr, "%s: failed to decode\n", __func__);
                 return 8;

@@ -2725,15 +2704,19 @@ int whisper_full(
             // feel free to experiment!
             //
             {
-                auto token = whisper_sample_best(ctx);
-
-                if (i == 0) {
-                    token.tid = whisper_token_beg(ctx);
-                }
+                const auto token = (i == 0) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx);

                 // timestamp token - update sliding window
                 if (token.id > whisper_token_beg(ctx)) {
-                    seek_delta = 2*(token.id - whisper_token_beg(ctx));
+                    const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
+
+                    // do not allow to go back in time
+                    if (seek_delta != 100*WHISPER_CHUNK_SIZE &&
+                        seek_delta > seek_delta_new && result_len < i) {
+                        break;
+                    }
+
+                    seek_delta = seek_delta_new;
                     result_len = i + 1;
                 }

@@ -2752,8 +2735,8 @@ int whisper_full(
                 if (seek + seek_delta + 100 >= seek_end) {
                     result_len = i + 1;
                 } else {
-                    // TODO: figure out how to resolve this
-                    fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__);
+                    failed = true;
+                    break;
                 }
             }

@@ -2772,11 +2755,21 @@ int whisper_full(
                 }
             }

-            if (done) {
+            // sometimes, the decoding can get stuck in a repetition loop
+            // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance
+            // the sliding window by 1 second
+            if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
+                failed = true;
                 break;
             }
         }

+        if (failed) {
+            fprintf(stderr, "\n%s: failed to generate timestamp token - using fallback strategy\n\n", __func__);
+            seek += 100;
+            continue;
+        }
+
         // shrink down to result_len
         tokens_cur.resize(result_len);
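Two guards drive the reworked decoding loop above: a monotonicity check ("do not allow to go back in time") that stops the segment rather than accept a timestamp which would rewind the sliding window, and an end-of-budget check that flags the window as failed, after which the window advances by 1 second instead of a full chunk. A condensed sketch of that control flow, with token sampling reduced to a caller-supplied callback; decode_window, sample_advance, and CHUNK are illustrative names, while the sentinel and guard conditions follow the diff:

// All units are centiseconds (100 == 1 s), matching the diff.
constexpr int CHUNK = 100 * 30; // 30-second audio window

// Condensed per-window loop. sample_advance() stands in for sampling one token:
// it returns the proposed window advance for a timestamp token, or -1 for a
// text token. Returns the committed advance, or -1 when the window failed and
// the caller should do `seek += 100; continue;` as in the diff.
int decode_window(int n_max, int (*sample_advance)()) {
    int seek_delta = CHUNK; // sentinel: no timestamp committed yet
    int result_len = 0;

    for (int i = 0; i < n_max; ++i) {
        const int seek_delta_new = sample_advance();

        if (seek_delta_new >= 0) {
            // do not allow to go back in time
            if (seek_delta != CHUNK && seek_delta > seek_delta_new && result_len < i) {
                break;
            }
            seek_delta = seek_delta_new;
            result_len = i + 1;
        }

        // repetition-loop escape: token budget exhausted without a usable result
        if (i == n_max - 1 && (result_len == 0 || seek_delta < CHUNK/2)) {
            return -1;
        }
    }

    return seek_delta;
}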
whisper.h
CHANGED
@@ -137,7 +137,7 @@ extern "C" {
     // whisper_sample_best() returns the token with the highest probability
     // whisper_sample_timestamp() returns the most probable timestamp token
     WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
-    WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
+    WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);

     // Return the id of the specified language, returns -1 if not found
     WHISPER_API int whisper_lang_id(const char * lang);
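With the header change, whisper_sample_timestamp() now returns the full whisper_token_data (instead of a bare token id) and takes an is_initial flag. A hedged usage sketch against the new signature; the wrapper function and the assumption that whisper_decode() has already produced logits for the current position are illustrative, not part of the diff:

#include <cstdio>
#include "whisper.h"

// Force-sample a timestamp token from the last decoded logits.
// is_initial = true additionally applies the max-initial-timestamp masking.
void print_initial_timestamp(struct whisper_context * ctx) {
    const whisper_token_data token = whisper_sample_timestamp(ctx, true);

    // ids at or above whisper_token_beg(ctx) are timestamps in 20 ms steps
    const float t_sec = 0.02f * (token.id - whisper_token_beg(ctx));
    printf("initial timestamp: %.2f s (token id %d)\n", t_sec, token.id);
}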