bert hubert committed on
Commit
4c1552f
·
1 Parent(s): 36c5e16

Fix a potential bug when reading model data into a small-string-optimized (SSO) std::string, which could lead to memory corruption. With an SSO string, you can't write data through &str[0] and expect it to work reliably.

Browse files
Files changed (1) hide show
  1. whisper.cpp +33 -24
whisper.cpp CHANGED
@@ -429,6 +429,12 @@ struct whisper_context {
429
  int32_t exp_n_audio_ctx; // 0 - use default
430
  };
431
 
 
 
 
 
 
 
432
  // load the model from a ggml file
433
  //
434
  // file format:
@@ -455,7 +461,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
455
  // verify magic
456
  {
457
  uint32_t magic;
458
- fin.read((char *) &magic, sizeof(magic));
459
  if (magic != 0x67676d6c) {
460
  fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
461
  return false;
@@ -466,17 +472,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
466
  {
467
  auto & hparams = model.hparams;
468
 
469
- fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
470
- fin.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
471
- fin.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
472
- fin.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
473
- fin.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
474
- fin.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
475
- fin.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
476
- fin.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
477
- fin.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
478
- fin.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
479
- fin.read((char *) &hparams.f16, sizeof(hparams.f16));
480
 
481
  assert(hparams.n_text_state == hparams.n_audio_state);
482
 
@@ -524,8 +530,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
524
  {
525
  auto & filters = wctx.model.filters;
526
 
527
- fin.read((char *) &filters.n_mel, sizeof(filters.n_mel));
528
- fin.read((char *) &filters.n_fft, sizeof(filters.n_fft));
529
 
530
  filters.data.resize(filters.n_mel * filters.n_fft);
531
  fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
@@ -534,7 +540,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
534
  // load vocab
535
  {
536
  int32_t n_vocab = 0;
537
- fin.read((char *) &n_vocab, sizeof(n_vocab));
538
 
539
  //if (n_vocab != model.hparams.n_vocab) {
540
  // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
@@ -545,10 +551,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
545
  std::string word;
546
  for (int i = 0; i < n_vocab; i++) {
547
  uint32_t len;
548
- fin.read((char *) &len, sizeof(len));
549
 
550
- word.resize(len);
551
- fin.read((char *) word.data(), len);
 
552
 
553
  vocab.token_to_id[word] = i;
554
  vocab.id_to_token[i] = word;
@@ -998,9 +1005,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
998
  int32_t length;
999
  int32_t ftype;
1000
 
1001
- fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
1002
- fin.read(reinterpret_cast<char *>(&length), sizeof(length));
1003
- fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
1004
 
1005
  if (fin.eof()) {
1006
  break;
@@ -1009,12 +1016,14 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
1009
  int32_t nelements = 1;
1010
  int32_t ne[3] = { 1, 1, 1 };
1011
  for (int i = 0; i < n_dims; ++i) {
1012
- fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
1013
  nelements *= ne[i];
1014
  }
1015
 
1016
- std::string name(length, 0);
1017
- fin.read(&name[0], length);
 
 
1018
 
1019
  if (model.tensors.find(name.data()) == model.tensors.end()) {
1020
  fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
 
429
  int32_t exp_n_audio_ctx; // 0 - use default
430
  };
431
 
432
// Read exactly sizeof(T) raw bytes from `fin` into `dest`.
//
// Intended for plain fixed-size values such as the int32_t / uint32_t fields
// read by the model loader; do not use it for std::string or other class types
// that manage their own storage.
//
// If the stream cannot supply all sizeof(T) bytes (truncated file, or the
// stream was already in a failed state), `dest` is reset to T{} so the caller
// never inspects an indeterminate or half-written value. The stream's own
// eof/fail flags are left set, so callers can still test `fin` afterwards.
template<typename T>
static void read_safe(std::ifstream& fin, T& dest)
{
    fin.read(reinterpret_cast<char*>(&dest), sizeof(T));
    if (!fin) {
        // Short or failed read: discard any partial bytes that were stored.
        dest = T{};
    }
}
437
+
438
  // load the model from a ggml file
439
  //
440
  // file format:
 
461
  // verify magic
462
  {
463
  uint32_t magic;
464
+ read_safe(fin, magic);
465
  if (magic != 0x67676d6c) {
466
  fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
467
  return false;
 
472
  {
473
  auto & hparams = model.hparams;
474
 
475
+ read_safe(fin, hparams.n_vocab);
476
+ read_safe(fin, hparams.n_audio_ctx);
477
+ read_safe(fin, hparams.n_audio_state);
478
+ read_safe(fin, hparams.n_audio_head);
479
+ read_safe(fin, hparams.n_audio_layer);
480
+ read_safe(fin, hparams.n_text_ctx);
481
+ read_safe(fin, hparams.n_text_state);
482
+ read_safe(fin, hparams.n_text_head);
483
+ read_safe(fin, hparams.n_text_layer);
484
+ read_safe(fin, hparams.n_mels);
485
+ read_safe(fin, hparams.f16);
486
 
487
  assert(hparams.n_text_state == hparams.n_audio_state);
488
 
 
530
  {
531
  auto & filters = wctx.model.filters;
532
 
533
+ read_safe(fin, filters.n_mel);
534
+ read_safe(fin, filters.n_fft);
535
 
536
  filters.data.resize(filters.n_mel * filters.n_fft);
537
  fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
 
540
  // load vocab
541
  {
542
  int32_t n_vocab = 0;
543
+ read_safe(fin, n_vocab);
544
 
545
  //if (n_vocab != model.hparams.n_vocab) {
546
  // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
 
551
  std::string word;
552
  for (int i = 0; i < n_vocab; i++) {
553
  uint32_t len;
554
+ read_safe(fin, len);
555
 
556
+ std::vector<char> tmp(len); // create a buffer
557
+ fin.read( &tmp[0], tmp.size() ); // read to buffer
558
+ word.assign(&tmp[0], tmp.size());
559
 
560
  vocab.token_to_id[word] = i;
561
  vocab.id_to_token[i] = word;
 
1005
  int32_t length;
1006
  int32_t ftype;
1007
 
1008
+ read_safe(fin, n_dims);
1009
+ read_safe(fin, length);
1010
+ read_safe(fin, ftype);
1011
 
1012
  if (fin.eof()) {
1013
  break;
 
1016
  int32_t nelements = 1;
1017
  int32_t ne[3] = { 1, 1, 1 };
1018
  for (int i = 0; i < n_dims; ++i) {
1019
+ read_safe(fin, ne[i]);
1020
  nelements *= ne[i];
1021
  }
1022
 
1023
+ std::string name;
1024
+ std::vector<char> tmp(length); // create a buffer
1025
+ fin.read( &tmp[0], tmp.size() ); // read to buffer
1026
+ name.assign(&tmp[0], tmp.size());
1027
 
1028
  if (model.tensors.find(name.data()) == model.tensors.end()) {
1029
  fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());