ggerganov commited on
Commit
82e39d3
·
1 Parent(s): 974e0d1

whisper : add whisper_tokenize()

Browse files

Tokenizes a string into a list of vocabulary tokens

Files changed (2) hide show
  1. whisper.cpp +81 -0
  2. whisper.h +11 -0
whisper.cpp CHANGED
@@ -14,6 +14,7 @@
14
  #include <string>
15
  #include <thread>
16
  #include <vector>
 
17
 
18
  #define USE_FLASH_ATTN
19
  //#define USE_FLASH_FF
@@ -2161,6 +2162,71 @@ static bool log_mel_spectrogram(
2161
  return true;
2162
  }
2163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2164
  //
2165
  // interface implementation
2166
  //
@@ -2291,6 +2357,21 @@ struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx,
2291
  return res;
2292
  }
2293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2294
  int whisper_lang_id(const char * lang) {
2295
  if (!g_lang.count(lang)) {
2296
  fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
 
14
  #include <string>
15
  #include <thread>
16
  #include <vector>
17
+ #include <regex>
18
 
19
  #define USE_FLASH_ATTN
20
  //#define USE_FLASH_FF
 
2162
  return true;
2163
  }
2164
 
2165
+ // split text into tokens
2166
+ //
2167
+ // ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
2168
+ //
2169
+ // Regex (Python):
2170
+ // r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
2171
+ //
2172
+ // Regex (C++):
2173
+ // R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
2174
+ //
2175
+ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
2176
+ std::vector<std::string> words;
2177
+
2178
+ // first split the text into words
2179
+ {
2180
+ std::string str = text;
2181
+ std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
2182
+
2183
+ std::regex re(pat);
2184
+ std::smatch m;
2185
+
2186
+ while (std::regex_search(str, m, re)) {
2187
+ for (auto x : m) {
2188
+ words.push_back(x);
2189
+ }
2190
+ str = m.suffix();
2191
+ }
2192
+ }
2193
+
2194
+ // find the longest tokens that form the words:
2195
+ std::vector<whisper_vocab::id> tokens;
2196
+ for (const auto & word : words) {
2197
+ if (word.size() == 0) continue;
2198
+
2199
+ int i = 0;
2200
+ int n = word.size();
2201
+ while (i < n) {
2202
+ int j = n;
2203
+ while (j > i) {
2204
+ auto it = vocab.token_to_id.find(word.substr(i, j-i));
2205
+ if (it != vocab.token_to_id.end()) {
2206
+ tokens.push_back(it->second);
2207
+ i = j;
2208
+ break;
2209
+ }
2210
+ --j;
2211
+ }
2212
+ if (i == n) {
2213
+ break;
2214
+ }
2215
+ if (j == i) {
2216
+ auto sub = word.substr(i, 1);
2217
+ if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
2218
+ tokens.push_back(vocab.token_to_id.at(sub));
2219
+ } else {
2220
+ fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
2221
+ }
2222
+ ++i;
2223
+ }
2224
+ }
2225
+ }
2226
+
2227
+ return tokens;
2228
+ }
2229
+
2230
  //
2231
  // interface implementation
2232
  //
 
2357
  return res;
2358
  }
2359
 
2360
+ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
2361
+ const auto res = tokenize(ctx->vocab, text);
2362
+
2363
+ if (res.size() > n_max_tokens) {
2364
+ fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
2365
+ return -1;
2366
+ }
2367
+
2368
+ for (int i = 0; i < res.size(); i++) {
2369
+ tokens[i] = res[i];
2370
+ }
2371
+
2372
+ return res.size();
2373
+ }
2374
+
2375
  int whisper_lang_id(const char * lang) {
2376
  if (!g_lang.count(lang)) {
2377
  fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
whisper.h CHANGED
@@ -139,6 +139,17 @@ extern "C" {
139
  WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
140
  WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
141
 
 
 
 
 
 
 
 
 
 
 
 
142
  // Return the id of the specified language, returns -1 if not found
143
  WHISPER_API int whisper_lang_id(const char * lang);
144
 
 
139
  WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
140
  WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
141
 
142
+ // Convert the provided text into tokens.
143
+ // The tokens pointer must be large enough to hold the resulting tokens.
144
+ // Returns the number of tokens on success, no more than n_max_tokens
145
+ // Returns -1 on failure
146
+ // TODO: not sure if correct
147
+ WHISPER_API int whisper_tokenize(
148
+ struct whisper_context * ctx,
149
+ const char * text,
150
+ whisper_token * tokens,
151
+ int n_max_tokens);
152
+
153
  // Return the id of the specified language, returns -1 if not found
154
  WHISPER_API int whisper_lang_id(const char * lang);
155