Spaces:
Running
Running
ref #52 : improve greedy sampling strategy
Browse filesForce timestamp token to be sampled if the probability sum over all
timestamp tokens is above the probability of any other token
- whisper.cpp +25 -14
- whisper.h +1 -1
whisper.cpp
CHANGED
|
@@ -1784,7 +1784,7 @@ bool whisper_decode(
|
|
| 1784 |
// the most basic sampling scheme - select the top token
|
| 1785 |
whisper_vocab::id whisper_sample_best(
|
| 1786 |
const whisper_vocab & vocab,
|
| 1787 |
-
const float * probs
|
| 1788 |
int n_logits = vocab.id_to_token.size();
|
| 1789 |
|
| 1790 |
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
|
|
@@ -1794,9 +1794,29 @@ whisper_vocab::id whisper_sample_best(
|
|
| 1794 |
probs_id.push_back(std::make_pair(probs[i], i));
|
| 1795 |
}
|
| 1796 |
|
| 1797 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1798 |
|
| 1799 |
// find the top K tokens
|
|
|
|
|
|
|
| 1800 |
std::partial_sort(
|
| 1801 |
probs_id.begin(),
|
| 1802 |
probs_id.begin() + top_k, probs_id.end(),
|
|
@@ -1811,15 +1831,6 @@ whisper_vocab::id whisper_sample_best(
|
|
| 1811 |
// printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
|
| 1812 |
//}
|
| 1813 |
|
| 1814 |
-
if (need_timestamp) {
|
| 1815 |
-
// at the end of the 30-second audio segment, we start giving preference to time tokens
|
| 1816 |
-
for (int i = 0; i < top_k; i++) {
|
| 1817 |
-
if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > 0.01*probs_id[0].first) {
|
| 1818 |
-
return probs_id[i].second;
|
| 1819 |
-
}
|
| 1820 |
-
}
|
| 1821 |
-
}
|
| 1822 |
-
|
| 1823 |
int res = 0;
|
| 1824 |
while ((probs_id[res].second == vocab.token_sot ||
|
| 1825 |
probs_id[res].second == vocab.token_solm ||
|
|
@@ -2155,11 +2166,11 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
|
|
| 2155 |
return 0;
|
| 2156 |
}
|
| 2157 |
|
| 2158 |
-
whisper_token whisper_sample_best(struct whisper_context * ctx
|
| 2159 |
const int64_t t_start_sample_us = ggml_time_us();
|
| 2160 |
|
| 2161 |
// TODO: simplify
|
| 2162 |
-
auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab)
|
| 2163 |
|
| 2164 |
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
| 2165 |
|
|
@@ -2437,7 +2448,7 @@ int whisper_full(
|
|
| 2437 |
whisper_token id = 0;
|
| 2438 |
whisper_token tid = whisper_token_beg(ctx);
|
| 2439 |
|
| 2440 |
-
id = whisper_sample_best(ctx
|
| 2441 |
if (i > 0) {
|
| 2442 |
tid = whisper_sample_timestamp(ctx);
|
| 2443 |
}
|
|
|
|
| 1784 |
// the most basic sampling scheme - select the top token
|
| 1785 |
whisper_vocab::id whisper_sample_best(
|
| 1786 |
const whisper_vocab & vocab,
|
| 1787 |
+
const float * probs) {
|
| 1788 |
int n_logits = vocab.id_to_token.size();
|
| 1789 |
|
| 1790 |
std::vector<std::pair<double, whisper_vocab::id>> probs_id;
|
|
|
|
| 1794 |
probs_id.push_back(std::make_pair(probs[i], i));
|
| 1795 |
}
|
| 1796 |
|
| 1797 |
+
double sum_ts = 0.0;
|
| 1798 |
+
double max_tx = 0.0;
|
| 1799 |
+
|
| 1800 |
+
for (int i = 0; i < vocab.token_beg; i++) {
|
| 1801 |
+
max_tx = std::max(max_tx, probs_id[i].first);
|
| 1802 |
+
}
|
| 1803 |
+
|
| 1804 |
+
for (int i = vocab.token_beg; i < n_logits; i++) {
|
| 1805 |
+
sum_ts += probs_id[i].first;
|
| 1806 |
+
}
|
| 1807 |
+
|
| 1808 |
+
// if the probability sum of all timestamp tokesn is higher than the max probability of the text tokens - sample a
|
| 1809 |
+
// timestamp token
|
| 1810 |
+
if (sum_ts > max_tx) {
|
| 1811 |
+
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
|
| 1812 |
+
for (int i = 0; i < vocab.token_beg; i++) {
|
| 1813 |
+
probs_id[i].first = -INFINITY;
|
| 1814 |
+
}
|
| 1815 |
+
}
|
| 1816 |
|
| 1817 |
// find the top K tokens
|
| 1818 |
+
const int top_k = 4;
|
| 1819 |
+
|
| 1820 |
std::partial_sort(
|
| 1821 |
probs_id.begin(),
|
| 1822 |
probs_id.begin() + top_k, probs_id.end(),
|
|
|
|
| 1831 |
// printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
|
| 1832 |
//}
|
| 1833 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1834 |
int res = 0;
|
| 1835 |
while ((probs_id[res].second == vocab.token_sot ||
|
| 1836 |
probs_id[res].second == vocab.token_solm ||
|
|
|
|
| 2166 |
return 0;
|
| 2167 |
}
|
| 2168 |
|
| 2169 |
+
whisper_token whisper_sample_best(struct whisper_context * ctx) {
|
| 2170 |
const int64_t t_start_sample_us = ggml_time_us();
|
| 2171 |
|
| 2172 |
// TODO: simplify
|
| 2173 |
+
auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab));
|
| 2174 |
|
| 2175 |
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
| 2176 |
|
|
|
|
| 2448 |
whisper_token id = 0;
|
| 2449 |
whisper_token tid = whisper_token_beg(ctx);
|
| 2450 |
|
| 2451 |
+
id = whisper_sample_best(ctx);
|
| 2452 |
if (i > 0) {
|
| 2453 |
tid = whisper_sample_timestamp(ctx);
|
| 2454 |
}
|
whisper.h
CHANGED
|
@@ -120,7 +120,7 @@ extern "C" {
|
|
| 120 |
// You can also implement your own sampling method using the whisper_get_probs() function.
|
| 121 |
// whisper_sample_best() returns the token with the highest probability
|
| 122 |
// whisper_sample_timestamp() returns the most probable timestamp token
|
| 123 |
-
WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx
|
| 124 |
WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
|
| 125 |
|
| 126 |
// Return the id of the specified language, returns -1 if not found
|
|
|
|
| 120 |
// You can also implement your own sampling method using the whisper_get_probs() function.
|
| 121 |
// whisper_sample_best() returns the token with the highest probability
|
| 122 |
// whisper_sample_timestamp() returns the most probable timestamp token
|
| 123 |
+
WHISPER_API whisper_token whisper_sample_best(struct whisper_context * ctx);
|
| 124 |
WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
|
| 125 |
|
| 126 |
// Return the id of the specified language, returns -1 if not found
|