ggerganov committed
Commit da4acca · unverified · 1 parent: ca0dc29

whisper : add full CUDA and Metal offloading (#1472)


* whisper : migrate to ggml-backend

* whisper : fix logit reading

* whisper : fix tensor allocation during load

* whisper : fix beam-search with CUDA

* whisper : free backends + fix compile warning

* whisper : print when CUDA is enabled

* whisper : fix CoreML

* make : clean-up

* talk : fix compile warning

* whisper : support ggml_conv with CUDA and Metal (#1473)

* ggml : add CUDA support for ggml_conv

* whisper : remove ggml_repeat for conv bias + single backend

* cuda : fix im2col kernel

* metal : add im2col support + mul mat-vec f16 x f16

* bench-all : add q4 models

* whisper : clean-up

* quantize-all : fix

* ggml : im2col opts

* whisper : avoid whisper_model_data wrapper

* whisper : add note that ggml_mul_mat_pad does not work with CUDA

* whisper : factor out graph compute in common function

* whisper : fixes

* whisper : fix UB with measure buffers

* whisper : try to fix the parallel whisper_state functionality (#1479)

* whisper : try to fix the parallel whisper_state functionality

* whisper : fix multi-state Metal

* whisper : free backend instances in whisper_state
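The bullets above describe moving whisper.cpp onto the ggml-backend interface and factoring graph computation into one common function. As a hedged sketch only (not the exact function added in this commit), a backend-agnostic compute helper built on the ggml-backend API of this period typically looks like the following; the helper name and parameters are illustrative:

// Minimal sketch: run an already-built graph on whichever backend
// (CPU, CUDA or Metal) the context was created with.
#include "ggml.h"
#include "ggml-backend.h"

static void graph_compute_helper(ggml_backend_t backend, struct ggml_cgraph * graph, int n_threads) {
    if (ggml_backend_is_cpu(backend)) {
        // only the CPU backend takes a thread count
        ggml_backend_cpu_set_n_threads(backend, n_threads);
    }
    ggml_backend_graph_compute(backend, graph);
}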

Files changed (14)
  1. .gitignore +1 -0
  2. Makefile +21 -21
  3. examples/common.h +1 -1
  4. examples/talk/gpt-2.cpp +4 -4
  5. extra/bench-all.sh +9 -5
  6. extra/quantize-all.sh +7 -27
  7. ggml-cuda.cu +95 -1
  8. ggml-metal.h +1 -1
  9. ggml-metal.m +74 -6
  10. ggml-metal.metal +107 -1
  11. ggml.c +201 -1085
  12. ggml.h +13 -6
  13. whisper.cpp +486 -541
  14. whisper.h +9 -8
.gitignore CHANGED
@@ -8,6 +8,7 @@
 .DS_Store
 
 build/
+build-coreml/
 build-em/
 build-debug/
 build-release/
Makefile CHANGED
@@ -307,7 +307,7 @@ ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h
 ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
 	$(CC) $(CFLAGS) -c $< -o $@
 
-WHISPER_OBJ += ggml-alloc.o ggml-backend.o ggml-quants.o
+WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o
 
 whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
@@ -331,11 +331,11 @@ ggml-metal.o: ggml-metal.m ggml-metal.h
 WHISPER_OBJ += ggml-metal.o
 endif
 
-libwhisper.a: ggml.o $(WHISPER_OBJ)
-	$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
+libwhisper.a: $(WHISPER_OBJ)
+	$(AR) rcs libwhisper.a $(WHISPER_OBJ)
 
-libwhisper.so: ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
+libwhisper.so: $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) -shared -o libwhisper.so $(WHISPER_OBJ) $(LDFLAGS)
 
 clean:
 	rm -f *.o main stream command talk talk-llama bench quantize lsp libwhisper.a libwhisper.so
@@ -349,30 +349,30 @@ CC_SDL=`sdl2-config --cflags --libs`
 SRC_COMMON = examples/common.cpp examples/common-ggml.cpp
 SRC_COMMON_SDL = examples/common-sdl.cpp
 
-main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
+main: examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o main $(LDFLAGS)
 	./main -h
 
-bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS)
+bench: examples/bench/bench.cpp $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/bench/bench.cpp $(WHISPER_OBJ) -o bench $(LDFLAGS)
 
-quantize: examples/quantize/quantize.cpp ggml.o $(WHISPER_OBJ) $(SRC_COMMON)
-	$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o quantize $(LDFLAGS)
+quantize: examples/quantize/quantize.cpp $(WHISPER_OBJ) $(SRC_COMMON)
+	$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o quantize $(LDFLAGS)
 
-stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
+stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
 
-command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
+command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
 
-lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
+lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
 
-talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
+talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
 
-talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
+talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
 
 #
 # Audio samples
examples/common.h CHANGED
@@ -181,7 +181,7 @@ private:
     // It is assumed that PCM data is normalized to a range from -1 to 1
     bool write_audio(const float * data, size_t length) {
         for (size_t i = 0; i < length; ++i) {
-            const auto intSample = static_cast<const int16_t>(data[i] * 32767);
+            const int16_t intSample = data[i] * 32767;
             file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
             dataSize += sizeof(int16_t);
         }
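The replaced line drops an ill-formed cast (static_cast to a const-qualified int16_t prvalue, which triggers compiler warnings) in favour of a plain implicit conversion. For illustration only, the same float-to-PCM16 conversion as a self-contained helper, with explicit clamping added as an extra safety net — the clamp is not part of the diff, which assumes the input is already normalized:

#include <algorithm>
#include <cstdint>
#include <vector>

// Convert normalized float samples in [-1, 1] to 16-bit PCM.
std::vector<int16_t> to_pcm16(const float * data, size_t length) {
    std::vector<int16_t> out(length);
    for (size_t i = 0; i < length; ++i) {
        const float v = std::max(-1.0f, std::min(1.0f, data[i])); // precautionary clamp
        out[i] = static_cast<int16_t>(v * 32767.0f);
    }
    return out;
}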
examples/talk/gpt-2.cpp CHANGED
@@ -121,13 +121,13 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
         return false;
     }
 
-    std::string word;
+    char word[129];
+
     for (int i = 0; i < n_vocab; i++) {
         uint32_t len;
         fin.read((char *) &len, sizeof(len));
-
-        word.resize(len);
-        fin.read((char *) word.data(), len);
+        word[len] = '\0';
+        fin.read((char *) word, len);
 
         vocab.token_to_id[word] = i;
         vocab.id_to_token[i] = word;
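The new code reads each length-prefixed vocabulary entry into a fixed char word[129] buffer, which assumes every token is at most 128 bytes (true for the GPT-2 vocabulary this example ships with). A hedged sketch of the same read with an explicit length check, purely illustrative and not part of the diff:

#include <cstdint>
#include <fstream>
#include <string>

// Read one length-prefixed token; fails on oversized or short reads.
bool read_token(std::ifstream & fin, std::string & word) {
    uint32_t len = 0;
    if (!fin.read((char *) &len, sizeof(len)) || len > 128) {
        return false;
    }
    word.resize(len);
    return static_cast<bool>(fin.read(&word[0], len));
}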
extra/bench-all.sh CHANGED
@@ -18,11 +18,11 @@ else
 fi
 
 models=( \
-    "tiny"   "tiny-q5_0"   "tiny-q5_1"   "tiny-q8_0" \
-    "base"   "base-q5_0"   "base-q5_1"   "base-q8_0" \
-    "small"  "small-q5_0"  "small-q5_1"  "small-q8_0" \
-    "medium" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
-    "large"  "large-q5_0"  "large-q5_1"  "large-q8_0" \
+    "tiny"   "tiny-q4_0"   "tiny-q4_1"   "tiny-q5_0"   "tiny-q5_1"   "tiny-q8_0" \
+    "base"   "base-q4_0"   "base-q4_1"   "base-q5_0"   "base-q5_1"   "base-q8_0" \
+    "small"  "small-q4_0"  "small-q4_1"  "small-q5_0"  "small-q5_1"  "small-q8_0" \
+    "medium" "medium-q4_0" "medium-q4_1" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
+    "large"  "large-q4_0"  "large-q4_1"  "large-q5_0"  "large-q5_1"  "large-q8_0" \
 )
 
 if [ "$encoder_only" -eq 0 ]; then
@@ -83,6 +83,10 @@ for model in "${models[@]}"; do
        config="$config COREML"
    fi
 
+   if [[ $system_info == *"CUDA = 1"* ]]; then
+       config="$config CUDA"
+   fi
+
    if [[ $system_info == *"METAL = 1"* ]]; then
        config="$config METAL"
    fi
extra/quantize-all.sh CHANGED
@@ -15,33 +15,13 @@ declare -a filedex
 cd `dirname $0`
 cd ../
 
-# Let's loop across all the objects in the 'models' dir:
-for i in ./models/*; do
-    # Check to see if it's a file or directory
-    if [ -d "$i" ]; then
-        # It's a directory! We should make sure it's not empty first:
-        if [ "$(ls -A $i)" ]; then
-            # Passed! Let's go searching for bin files (shouldn't need to go more than a layer deep here)
-            for f in "$i"/*.bin; do
-                # [Neuron Activation]
-                newfile=`echo "${f##*/}" | cut -d _ -f 1`;
-                if [ "$newfile" != "q5" ]; then
-                    ./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" ${qtype1};
-                    ./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" ${qtype0};
-                    filedex+=( "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" )
-                fi
-            done
-        fi
-    else
-        # It's a file! Let's make sure it's the right type:
-        if [ "${i##*.}" == "bin" ]; then
-            # And we probably want to skip the testing files
-            if [ "${i:9:8}" != "for-test" ]; then
-                # [Neuron Activation]
-                ./quantize "${i}" "${i:-4}-${qtype1}.bin" ${qtype1};
-                ./quantize "${i}" "${i:-4}-${qtype0}.bin" ${qtype0};
-                filedex+=( "${i:-4}-${qtype1}.bin" "${i:-4}-${qtype0}.bin" )
-            fi
+for i in `ls ./models | grep ^ggml-.*.bin | grep -v "\-q"`; do
+    m="models/$i"
+    if [ -f "$m" ]; then
+        if [ "${m##*.}" == "bin" ]; then
+            ./quantize "${m}" "${m::${#m}-4}-${qtype1}.bin" ${qtype1};
+            ./quantize "${m}" "${m::${#m}-4}-${qtype0}.bin" ${qtype0};
+            filedex+=( "${m::${#m}-4}-${qtype1}.bin" "${m::${#m}-4}-${qtype0}.bin" )
        fi
    fi
 done
ggml-cuda.cu CHANGED
@@ -4476,6 +4476,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
     *dsti = __float2half(*xi);
 }
 
+static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+    const half * xi = (const half *) cxi;
+    half * dsti = (half *) cdsti;
+
+    *dsti = *xi;
+}
+
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -4729,6 +4736,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }
 
+static __global__ void im2col_f32_f16(
+    const float * x, half * dst,
+    int ofs0, int ofs1, int IW, int IH, int CHW,
+    int s0, int s1, int p0, int p1, int d0, int d1) {
+    const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
+    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+
+    const int offset_dst =
+        (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
+        (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = __float2half(0.0f);
+    } else {
+        const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
+        dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
+    }
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
@@ -5618,6 +5644,16 @@ static void ggml_cpy_f32_f16_cuda(
         (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
 }
 
+static void ggml_cpy_f16_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
 static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
@@ -5701,6 +5737,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }
 
+static void im2col_f32_f16_cuda(const float * x, half * dst,
+    int OH, int IW, int IH, int OW, int IC,
+    int KH, int KW, int N, int ofs0, int ofs1,
+    int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
+    dim3 block_nums(IC, OH, OW);
+    dim3 block_dims(N, KH, KW);
+    im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+}
+
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 256
 
@@ -6483,7 +6528,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
         src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
         to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
     }
-    const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
+    const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
     size_t dst_f16_as = 0;
     half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
 
@@ -6659,6 +6704,45 @@ inline void ggml_cuda_op_alibi(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_im2col(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+    const int64_t N  = src1->ne[is_2D ? 3 : 2];
+    const int64_t IC = src1->ne[is_2D ? 2 : 1];
+    const int64_t IH = is_2D ? src1->ne[1] : 1;
+    const int64_t IW = src1->ne[0];
+
+    const int64_t KH = is_2D ? src0->ne[1] : 1;
+    const int64_t KW = src0->ne[0];
+
+    const int64_t OH = is_2D ? dst->ne[2] : 1;
+    const int64_t OW = dst->ne[1];
+
+    const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
+    const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+
+    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
+        OH, IW, IH, OW, IC, KH, KW, N,
+        ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
+
+    (void) src0;
+    (void) src0_dd;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7549,6 +7633,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
                               ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+                              ne10, ne11, nb10, nb11, nb12, main_stream);
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7580,6 +7667,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }
 
+void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
+}
+
 static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
@@ -7943,6 +8034,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_ALIBI:
             func = ggml_cuda_alibi;
             break;
+        case GGML_OP_IM2COL:
+            func = ggml_cuda_im2col;
+            break;
        default:
            return false;
    }
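The im2col_f32_f16 kernel above lays the input patches out so that a convolution becomes a single matrix multiplication: with grid (IC, OH, OW) and block (N, KH, KW), each output position (n, oh, ow) gets a contiguous row of IC*KH*KW values. A plain CPU reference of the same 2D layout may make the index arithmetic easier to follow; it is an illustration in float (the kernel writes half) and is not code from the diff:

// Reference im2col for a contiguous, row-major [N, IC, IH, IW] input.
// dst has shape [N*OH*OW, IC*KH*KW]; out-of-bounds taps are written as 0,
// matching the zero-padding branch of the CUDA kernel.
void im2col_ref(const float * x, float * dst,
                int N, int IC, int IH, int IW,
                int KH, int KW, int OH, int OW,
                int s0, int s1, int p0, int p1, int d0, int d1) {
    const int CHW = IC * KH * KW;
    for (int n = 0; n < N; ++n)
    for (int oh = 0; oh < OH; ++oh)
    for (int ow = 0; ow < OW; ++ow)
    for (int ic = 0; ic < IC; ++ic)
    for (int kh = 0; kh < KH; ++kh)
    for (int kw = 0; kw < KW; ++kw) {
        const int iih = oh * s1 + kh * d1 - p1;           // input row for this tap
        const int iiw = ow * s0 + kw * d0 - p0;           // input column for this tap
        const int row = (n * OH + oh) * OW + ow;          // one row per output position
        const int col = (ic * KH + kh) * KW + kw;         // one column per kernel tap
        float v = 0.0f;
        if (iih >= 0 && iih < IH && iiw >= 0 && iiw < IW) {
            v = x[((n * IC + ic) * IH + iih) * IW + iiw];
        }
        dst[row * CHW + col] = v;
    }
}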
ggml-metal.h CHANGED
@@ -26,7 +26,7 @@
 #include <stdbool.h>
 
 // max memory buffers that can be mapped to the device
-#define GGML_METAL_MAX_BUFFERS 16
+#define GGML_METAL_MAX_BUFFERS 64
 #define GGML_METAL_MAX_COMMAND_BUFFERS 32
 
 struct ggml_tensor;
ggml-metal.m CHANGED
@@ -86,6 +86,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
@@ -114,6 +115,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rope_f32);
     GGML_METAL_DECL_KERNEL(rope_f16);
    GGML_METAL_DECL_KERNEL(alibi_f32);
+    GGML_METAL_DECL_KERNEL(im2col_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
@@ -287,6 +289,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
@@ -317,6 +320,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rope_f32);
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
+        GGML_METAL_ADD_KERNEL(im2col_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
@@ -386,6 +390,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
@@ -416,6 +421,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rope_f32);
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
+    GGML_METAL_DEL_KERNEL(im2col_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
@@ -473,6 +479,10 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
 
     const int64_t tsize = ggml_nbytes(t);
 
+    if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
+        ctx = t->buffer->backend->context;
+    }
+
     // find the view that contains the tensor fully
     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
@@ -1139,6 +1149,7 @@ void ggml_metal_graph_compute(
                     switch (src0t) {
                         case GGML_TYPE_F32:
                             {
+                                GGML_ASSERT(src1t == GGML_TYPE_F32);
                                 [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
                                 nrows = 4;
                             } break;
@@ -1146,13 +1157,18 @@
                            {
                                nth0 = 32;
                                nth1 = 1;
-                                if (ne11 * ne12 < 4) {
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
-                                } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
-                                    nrows = ne11;
+                                if (src1t == GGML_TYPE_F32) {
+                                    if (ne11 * ne12 < 4) {
+                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
+                                    } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
+                                        nrows = ne11;
+                                    } else {
+                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                        nrows = 4;
+                                    }
                                } else {
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
                                    nrows = 4;
                                }
                            } break;
@@ -1464,6 +1480,58 @@
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
+            case GGML_OP_IM2COL:
+                {
+                    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+                    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+                    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+                    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+                    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+                    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+                    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+                    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+                    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+                    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+                    const int32_t N  = src1->ne[is_2D ? 3 : 2];
+                    const int32_t IC = src1->ne[is_2D ? 2 : 1];
+                    const int32_t IH = is_2D ? src1->ne[1] : 1;
+                    const int32_t IW = src1->ne[0];
+
+                    const int32_t KH = is_2D ? src0->ne[1] : 1;
+                    const int32_t KW = src0->ne[0];
+
+                    const int32_t OH = is_2D ? dst->ne[2] : 1;
+                    const int32_t OW = dst->ne[1];
+
+                    const int32_t CHW = IC * KH * KW;
+
+                    const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+                    const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
+                        case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
+                        default: GGML_ASSERT(false);
+                    };
+
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
+                    [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
+                    [encoder setBytes:&IW   length:sizeof( int32_t) atIndex:4];
+                    [encoder setBytes:&IH   length:sizeof( int32_t) atIndex:5];
+                    [encoder setBytes:&CHW  length:sizeof( int32_t) atIndex:6];
+                    [encoder setBytes:&s0   length:sizeof( int32_t) atIndex:7];
+                    [encoder setBytes:&s1   length:sizeof( int32_t) atIndex:8];
+                    [encoder setBytes:&p0   length:sizeof( int32_t) atIndex:9];
+                    [encoder setBytes:&p1   length:sizeof( int32_t) atIndex:10];
+                    [encoder setBytes:&d0   length:sizeof( int32_t) atIndex:11];
+                    [encoder setBytes:&d1   length:sizeof( int32_t) atIndex:12];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+                } break;
             case GGML_OP_DUP:
            case GGML_OP_CPY:
            case GGML_OP_CONT:
ggml-metal.metal CHANGED
@@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32(
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t rb = tgpig.y*N_F32_F32;
@@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32(
     }
 }
 
+#define N_F16_F16 4
+
+kernel void kernel_mul_mv_f16_f16(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F16_F16;
+    const int64_t im = tgpig.z;
+
+    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (half) x[i] * (half) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const half4 * x4 = (device const half4 *)x;
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half  * y  = (device const half  *) (src1 + r1*nb11 + im*nb12);
+            device const half4 * y4 = (device const half4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
 kernel void kernel_mul_mv_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,
@@ -1229,6 +1302,39 @@ kernel void kernel_rope(
 template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
 template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
 
+kernel void kernel_im2col_f16(
+        device const float * x,
+        device        half * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
+    const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
+
+    const int32_t offset_dst =
+        (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+        (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = 0.0f;
+    } else {
+        const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
+        dst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device       half * dst,
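The ggml.c changes below remove the dedicated CONV_1D/CONV_2D stage ops: with an im2col op in the graph, a convolution reduces to the im2col transform followed by an ordinary matrix multiplication, which is exactly the mul_mat path the CUDA and Metal backends already accelerate. Schematically, reusing the reference im2col sketch from the ggml-cuda.cu section above (illustrative C++, not code from this commit):

// conv2d as a GEMM over im2col rows:
//   A: kernel w viewed as [OC, IC*KH*KW]
//   B: im2col(x)         [N*OH*OW, IC*KH*KW]
//   dst[n, oc, oh*OW + ow] = sum_k A[oc][k] * B[(n*OH + oh)*OW + ow][k]
void conv2d_via_im2col(const float * w, const float * B, float * dst,
                       int N, int OC, int OH, int OW, int K /* = IC*KH*KW */) {
    for (int n = 0; n < N; ++n)
    for (int oc = 0; oc < OC; ++oc)
    for (int o = 0; o < OH * OW; ++o) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k) {
            acc += w[oc * K + k] * B[(n * OH * OW + o) * K + k];
        }
        dst[(n * OC + oc) * OH * OW + o] = acc;
    }
}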
ggml.c CHANGED
@@ -1634,13 +1634,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ROPE_BACK",
     "ALIBI",
     "CLAMP",
-    "CONV_1D",
-    "CONV_1D_STAGE_0",
-    "CONV_1D_STAGE_1",
     "CONV_TRANSPOSE_1D",
-    "CONV_2D",
-    "CONV_2D_STAGE_0",
-    "CONV_2D_STAGE_1",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
     "POOL_2D",
@@ -1671,7 +1666,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1721,13 +1716,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rope_back(x)",
     "alibi(x)",
     "clamp(x)",
-    "conv_1d(x)",
-    "conv_1d_stage_0(x)",
-    "conv_1d_stage_1(x)",
     "conv_transpose_1d(x)",
-    "conv_2d(x)",
-    "conv_2d_stage_0(x)",
-    "conv_2d_stage_1(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
     "pool_2d(x)",
@@ -1758,7 +1748,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1786,13 +1776,7 @@ static void ggml_setup_op_has_task_pass(void) {
         p[GGML_OP_GET_ROWS_BACK        ] = true;
         p[GGML_OP_DIAG_MASK_INF        ] = true;
         p[GGML_OP_DIAG_MASK_ZERO       ] = true;
-        p[GGML_OP_CONV_1D              ] = true;
-        p[GGML_OP_CONV_1D_STAGE_0      ] = true;
-        p[GGML_OP_CONV_1D_STAGE_1      ] = true;
         p[GGML_OP_CONV_TRANSPOSE_1D    ] = true;
-        p[GGML_OP_CONV_2D              ] = true;
-        p[GGML_OP_CONV_2D_STAGE_0      ] = true;
-        p[GGML_OP_CONV_2D_STAGE_1      ] = true;
         p[GGML_OP_CONV_TRANSPOSE_2D    ] = true;
         p[GGML_OP_FLASH_ATTN_BACK      ] = true;
         p[GGML_OP_CROSS_ENTROPY_LOSS   ] = true;
@@ -5137,82 +5121,6 @@ static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p,
     return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
 }
 
-// im2col: [N, IC, IL] => [N, OL, IC*K]
-// a: [OC,IC, K]
-// b: [N, IC, IL]
-// result: [N, OL, IC*K]
-static struct ggml_tensor * ggml_conv_1d_stage_0(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b,
-        int s0,
-        int p0,
-        int d0) {
-    GGML_ASSERT(a->ne[1] == b->ne[1]);
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
-    }
-
-    const int64_t OL = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
-
-    const int64_t ne[4] = {
-        a->ne[1] * a->ne[0],
-        OL,
-        b->ne[2],
-        1,
-    };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
-
-    int32_t params[] = { s0, p0, d0 };
-    ggml_set_op_params(result, params, sizeof(params));
-
-    result->op = GGML_OP_CONV_1D_STAGE_0;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-// ggml_conv_1d_stage_1
-
-// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
-// a: [OC, IC, K]
-// b: [N, OL, IC * K]
-// result: [N, OC, OL]
-static struct ggml_tensor * ggml_conv_1d_stage_1(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b) {
-
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
-    }
-
-    const int64_t ne[4] = {
-        b->ne[1],
-        a->ne[2],
-        b->ne[2],
-        1,
-    };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-
-    result->op = GGML_OP_CONV_1D_STAGE_1;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-// ggml_conv_1d
-
 GGML_API struct ggml_tensor * ggml_conv_1d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -5220,43 +5128,17 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
         int s0,
         int p0,
         int d0) {
-    struct ggml_tensor * result = ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0);
-    result = ggml_conv_1d_stage_1(ctx, a, result);
-    return result;
-}
 
-// GGML_API struct ggml_tensor * ggml_conv_1d(
-//         struct ggml_context * ctx,
-//         struct ggml_tensor  * a,
-//         struct ggml_tensor  * b,
-//         int s0,
-//         int p0,
-//         int d0) {
-//     GGML_ASSERT(ggml_is_matrix(b));
-//     GGML_ASSERT(a->ne[1] == b->ne[1]);
-//     bool is_node = false;
 
-//     if (a->grad || b->grad) {
-//         GGML_ASSERT(false); // TODO: implement backward
-//         is_node = true;
-//     }
 
-//     const int64_t ne[4] = {
-//         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
-//         a->ne[2], 1, 1,
-//     };
-//     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
-
-//     int32_t params[] = { s0, p0, d0 };
-//     ggml_set_op_params(result, params, sizeof(params));
-
-//     result->op = GGML_OP_CONV_1D;
-//     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-//     result->src[0] = a;
-//     result->src[1] = b;
-
-//     return result;
-// }
 
 // ggml_conv_1d_ph
 
@@ -5319,7 +5201,7 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
 // a: [OC,IC, KH, KW]
 // b: [N, IC, IH, IW]
 // result: [N, OH, OW, IC*KH*KW]
-static struct ggml_tensor * ggml_conv_2d_stage_0(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
@@ -5328,9 +5210,14 @@ static struct ggml_tensor * ggml_conv_2d_stage_0(
         int p0,
         int p1,
         int d0,
-        int d1) {
 
-    GGML_ASSERT(a->ne[2] == b->ne[2]);
     bool is_node = false;
 
     if (a->grad || b->grad) {
@@ -5338,81 +5225,51 @@ static struct ggml_tensor * ggml_conv_2d_stage_0(
5338
  is_node = true;
5339
  }
5340
 
5341
- const int64_t OH = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
5342
- const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
5343
 
5344
  const int64_t ne[4] = {
5345
- a->ne[2] * a->ne[1] * a->ne[0],
5346
  OW,
5347
- OH,
5348
- b->ne[3],
5349
  };
5350
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
5351
 
5352
- int32_t params[] = { s0, s1, p0, p1, d0, d1 };
 
5353
  ggml_set_op_params(result, params, sizeof(params));
5354
 
5355
- result->op = GGML_OP_CONV_2D_STAGE_0;
5356
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5357
- result->src[0] = a;
5358
- result->src[1] = b;
5359
-
5360
- return result;
5361
-
5362
- }
5363
-
5364
- // gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
5365
- // a: [OC, IC, KH, KW]
5366
- // b: [N, OH, OW, IC * KH * KW]
5367
- // result: [N, OC, OH, OW]
5368
- static struct ggml_tensor * ggml_conv_2d_stage_1(
5369
- struct ggml_context * ctx,
5370
- struct ggml_tensor * a,
5371
- struct ggml_tensor * b) {
5372
-
5373
- bool is_node = false;
5374
-
5375
- if (a->grad || b->grad) {
5376
- GGML_ASSERT(false); // TODO: implement backward
5377
- is_node = true;
5378
- }
5379
-
5380
- const int64_t ne[4] = {
5381
- b->ne[1],
5382
- b->ne[2],
5383
- a->ne[3],
5384
- b->ne[3],
5385
- };
5386
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
5387
-
5388
- result->op = GGML_OP_CONV_2D_STAGE_1;
5389
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5390
  result->src[0] = a;
5391
  result->src[1] = b;
5392
 
5393
  return result;
5394
-
5395
  }
5396
 
5397
  // a: [OC,IC, KH, KW]
5398
  // b: [N, IC, IH, IW]
5399
  // result: [N, OC, OH, OW]
5400
  struct ggml_tensor * ggml_conv_2d(
5401
- struct ggml_context * ctx,
5402
- struct ggml_tensor * a,
5403
- struct ggml_tensor * b,
5404
- int s0,
5405
- int s1,
5406
- int p0,
5407
- int p1,
5408
- int d0,
5409
- int d1) {
 
5410
 
5411
- struct ggml_tensor * result = ggml_conv_2d_stage_0(ctx, a, b, s0, s1, p0, p1, d0, d1); // [N, OH, OW, IC * KH * KW]
5412
- result = ggml_conv_2d_stage_1(ctx, a, result);
 
 
5413
 
5414
- return result;
5415
 
 
5416
  }
5417
 
5418
  // ggml_conv_2d_sk_p0
@@ -9507,6 +9364,8 @@ static bool ggml_compute_forward_mul_mat_use_blas(
9507
  // TODO: find the optimal values for these
9508
  if (ggml_is_contiguous(src0) &&
9509
  ggml_is_contiguous(src1) &&
 
 
9510
  (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
9511
 
9512
  /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
@@ -9517,6 +9376,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
9517
  }
9518
  #endif
9519
 
 
9520
  static void ggml_compute_forward_mul_mat(
9521
  const struct ggml_compute_params * params,
9522
  const struct ggml_tensor * src0,
@@ -9545,7 +9405,7 @@ static void ggml_compute_forward_mul_mat(
9545
 
9546
  // we don't support permuted src0 or src1
9547
  GGML_ASSERT(nb00 == ggml_type_size(type));
9548
- GGML_ASSERT(nb10 == sizeof(float));
9549
 
9550
  // dst cannot be transposed or permuted
9551
  GGML_ASSERT(nb0 == sizeof(float));
@@ -11637,9 +11497,9 @@ static void ggml_compute_forward_rope_back(
11637
  }
11638
  }
11639
 
11640
- // ggml_compute_forward_conv_1d
11641
 
11642
- static void ggml_compute_forward_conv_1d_f16_f32(
11643
  const struct ggml_compute_params * params,
11644
  const struct ggml_tensor * src0,
11645
  const struct ggml_tensor * src1,
@@ -11656,14 +11516,7 @@ static void ggml_compute_forward_conv_1d_f16_f32(
11656
  const int ith = params->ith;
11657
  const int nth = params->nth;
11658
 
11659
- const int nk = ne00;
11660
-
11661
- // size of the convolution row - the kernel size unrolled across all input channels
11662
- const int ew0 = nk*ne01;
11663
-
11664
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
11665
- const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
11666
- const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
11667
 
11668
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
11669
  GGML_ASSERT(nb10 == sizeof(float));
@@ -11671,23 +11524,37 @@ static void ggml_compute_forward_conv_1d_f16_f32(
11671
  if (params->type == GGML_TASK_INIT) {
11672
  memset(params->wdata, 0, params->wsize);
11673
 
11674
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 
 
11675
 
11676
- for (int64_t i11 = 0; i11 < ne11; i11++) {
11677
- const float * const src = (float *)((char *) src1->data + i11*nb11);
11678
- ggml_fp16_t * dst_data = wdata;
 
 
 
 
 
 
 
11679
 
11680
- for (int64_t i0 = 0; i0 < ne0; i0++) {
11681
- for (int64_t ik = 0; ik < nk; ik++) {
11682
- const int idx0 = i0*s0 + ik*d0 - p0;
 
11683
 
11684
- if(!(idx0 < 0 || idx0 >= ne10)) {
11685
- dst_data[i0*ew0 + i11*nk + ik] = GGML_FP32_TO_FP16(src[idx0]);
11686
- }
 
11687
  }
11688
  }
11689
  }
11690
 
 
 
 
11691
  return;
11692
  }
11693
 
@@ -11695,8 +11562,10 @@ static void ggml_compute_forward_conv_1d_f16_f32(
11695
  return;
11696
  }
11697
 
 
 
11698
  // total rows in dst
11699
- const int nr = ne2;
11700
 
11701
  // rows per thread
11702
  const int dr = (nr + nth - 1)/nth;
@@ -11705,22 +11574,26 @@ static void ggml_compute_forward_conv_1d_f16_f32(
11705
  const int ir0 = dr*ith;
11706
  const int ir1 = MIN(ir0 + dr, nr);
11707
 
11708
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
11709
-
11710
- for (int i2 = 0; i2 < ne2; i2++) {
11711
- for (int i1 = ir0; i1 < ir1; i1++) {
11712
- float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
11713
 
11714
- for (int i0 = 0; i0 < ne0; i0++) {
11715
- ggml_vec_dot_f16(ew0, dst_data + i0,
11716
- (ggml_fp16_t *) ((char *) src0->data + i1*nb02),
11717
- (ggml_fp16_t *) wdata + i2*nb2 + i0*ew0);
 
 
 
 
 
 
 
11718
  }
11719
  }
11720
  }
11721
  }
11722
 
11723
- static void ggml_compute_forward_conv_1d_f32(
11724
  const struct ggml_compute_params * params,
11725
  const struct ggml_tensor * src0,
11726
  const struct ggml_tensor * src1,
@@ -11737,13 +11610,7 @@ static void ggml_compute_forward_conv_1d_f32(
11737
  const int ith = params->ith;
11738
  const int nth = params->nth;
11739
 
11740
- const int nk = ne00;
11741
-
11742
- const int ew0 = nk*ne01;
11743
-
11744
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
11745
- const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
11746
- const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
11747
 
11748
  GGML_ASSERT(nb00 == sizeof(float));
11749
  GGML_ASSERT(nb10 == sizeof(float));
@@ -11751,23 +11618,37 @@ static void ggml_compute_forward_conv_1d_f32(
11751
  if (params->type == GGML_TASK_INIT) {
11752
  memset(params->wdata, 0, params->wsize);
11753
 
11754
- float * const wdata = (float *) params->wdata + 0;
 
 
11755
 
11756
- for (int64_t i11 = 0; i11 < ne11; i11++) {
11757
- const float * const src = (float *)((char *) src1->data + i11*nb11);
11758
- float * dst_data = wdata;
 
 
 
 
 
 
 
11759
 
11760
- for (int64_t i0 = 0; i0 < ne0; i0++) {
11761
- for (int64_t ik = 0; ik < nk; ik++) {
11762
- const int idx0 = i0*s0 + ik*d0 - p0;
 
11763
 
11764
- if(!(idx0 < 0 || idx0 >= ne10)) {
11765
- dst_data[i0*ew0 + i11*nk + ik] = src[idx0];
11766
- }
 
11767
  }
11768
  }
11769
  }
11770
 
 
 
 
11771
  return;
11772
  }
11773
 
@@ -11775,8 +11656,10 @@ static void ggml_compute_forward_conv_1d_f32(
11775
  return;
11776
  }
11777
 
 
 
11778
  // total rows in dst
11779
- const int nr = ne02;
11780
 
11781
  // rows per thread
11782
  const int dr = (nr + nth - 1)/nth;
@@ -11785,94 +11668,50 @@ static void ggml_compute_forward_conv_1d_f32(
11785
  const int ir0 = dr*ith;
11786
  const int ir1 = MIN(ir0 + dr, nr);
11787
 
11788
- float * const wdata = (float *) params->wdata + 0;
11789
-
11790
- for (int i2 = 0; i2 < ne2; i2++) {
11791
- for (int i1 = ir0; i1 < ir1; i1++) {
11792
- float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
11793
 
11794
- for (int i0 = 0; i0 < ne0; i0++) {
11795
- ggml_vec_dot_f32(ew0, dst_data + i0,
11796
- (float *) ((char *) src0->data + i1*nb02),
11797
- (float *) wdata + i2*nb2 + i0*ew0);
 
 
 
 
 
 
 
11798
  }
11799
  }
11800
  }
11801
  }
11802
 
11803
- // TODO: reuse ggml_mul_mat or implement ggml_im2col and remove stage_0 and stage_1
11804
- static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k,
11805
- ggml_fp16_t * A,
11806
- ggml_fp16_t * B,
11807
- float * C,
11808
- const int ith, const int nth) {
11809
- // does not seem to make a difference
11810
- int64_t m0, m1, n0, n1;
11811
- // patches per thread
11812
- if (m > n) {
11813
- n0 = 0;
11814
- n1 = n;
11815
-
11816
- // total patches in dst
11817
- const int np = m;
11818
-
11819
- // patches per thread
11820
- const int dp = (np + nth - 1)/nth;
11821
-
11822
- // patch range for this thread
11823
- m0 = dp*ith;
11824
- m1 = MIN(m0 + dp, np);
11825
- } else {
11826
- m0 = 0;
11827
- m1 = m;
11828
-
11829
- // total patches in dst
11830
- const int np = n;
11831
-
11832
- // patches per thread
11833
- const int dp = (np + nth - 1)/nth;
11834
-
11835
- // patch range for this thread
11836
- n0 = dp*ith;
11837
- n1 = MIN(n0 + dp, np);
11838
- }
11839
-
11840
- // block-tiling attempt
11841
- int64_t blck_n = 16;
11842
- int64_t blck_m = 16;
11843
-
11844
- // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB
11845
- // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(ggml_fp16_t) * K);
11846
- // if (blck_size > 0) {
11847
- // blck_0 = 4;
11848
- // blck_1 = blck_size / blck_0;
11849
- // if (blck_1 < 0) {
11850
- // blck_1 = 1;
11851
- // }
11852
- // // blck_0 = (int64_t)sqrt(blck_size);
11853
- // // blck_1 = blck_0;
11854
- // }
11855
- // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1);
11856
-
11857
- for (int j = n0; j < n1; j+=blck_n) {
11858
- for (int i = m0; i < m1; i+=blck_m) {
11859
- // printf("i j k => %d %d %d\n", i, j, K);
11860
- for (int ii = i; ii < i + blck_m && ii < m1; ii++) {
11861
- for (int jj = j; jj < j + blck_n && jj < n1; jj++) {
11862
- ggml_vec_dot_f16(k,
11863
- C + ii*n + jj,
11864
- A + ii * k,
11865
- B + jj * k);
11866
- }
11867
- }
11868
- }
11869
  }
11870
  }
11871
 
11872
- // src0: kernel [OC, IC, K]
11873
- // src1: signal [N, IC, IL]
11874
- // dst: result [N, OL, IC*K]
11875
- static void ggml_compute_forward_conv_1d_stage_0_f32(
11876
  const struct ggml_compute_params * params,
11877
  const struct ggml_tensor * src0,
11878
  const struct ggml_tensor * src1,
@@ -11886,26 +11725,35 @@ static void ggml_compute_forward_conv_1d_stage_0_f32(
11886
 
11887
  GGML_TENSOR_BINARY_OP_LOCALS;
11888
 
11889
- const int64_t N = ne12;
11890
- const int64_t IC = ne11;
11891
- const int64_t IL = ne10;
11892
-
11893
- const int64_t K = ne00;
11894
-
11895
- const int64_t OL = ne1;
11896
 
11897
  const int ith = params->ith;
11898
  const int nth = params->nth;
11899
 
11900
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
11901
- const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
11902
- const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
 
 
 
 
 
 
 
 
 
 
11903
 
11904
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
11905
  GGML_ASSERT(nb10 == sizeof(float));
11906
 
11907
  if (params->type == GGML_TASK_INIT) {
11908
- memset(dst->data, 0, ggml_nbytes(dst));
11909
  return;
11910
  }
11911
 
@@ -11913,23 +11761,30 @@ static void ggml_compute_forward_conv_1d_stage_0_f32(
11913
  return;
11914
  }
11915
 
11916
- // im2col: [N, IC, IL] => [N, OL, IC*K]
11917
  {
11918
  ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
11919
 
11920
  for (int64_t in = 0; in < N; in++) {
11921
- for (int64_t iol = 0; iol < OL; iol++) {
11922
- for (int64_t iic = ith; iic < IC; iic+=nth) {
 
11923
 
11924
- // micro kernel
11925
- ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K]
11926
- const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL]
11927
 
11928
- for (int64_t ik = 0; ik < K; ik++) {
11929
- const int64_t iil = iol*s0 + ik*d0 - p0;
 
 
11930
 
11931
- if (!(iil < 0 || iil >= IL)) {
11932
- dst_data[iic*K + ik] = GGML_FP32_TO_FP16(src_data[iil]);
 
 
 
 
11933
  }
11934
  }
11935
  }
@@ -11938,627 +11793,7 @@ static void ggml_compute_forward_conv_1d_stage_0_f32(
11938
  }
11939
  }
11940
 
11941
- // gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
11942
- // src0: [OC, IC, K]
11943
- // src1: [N, OL, IC * K]
11944
- // result: [N, OC, OL]
11945
- static void ggml_compute_forward_conv_1d_stage_1_f16(
11946
- const struct ggml_compute_params * params,
11947
- const struct ggml_tensor * src0,
11948
- const struct ggml_tensor * src1,
11949
- struct ggml_tensor * dst) {
11950
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
11951
- GGML_ASSERT(src1->type == GGML_TYPE_F16);
11952
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
11953
-
11954
- int64_t t0 = ggml_perf_time_us();
11955
- UNUSED(t0);
11956
-
11957
- if (params->type == GGML_TASK_INIT) {
11958
- return;
11959
- }
11960
-
11961
- if (params->type == GGML_TASK_FINALIZE) {
11962
- return;
11963
- }
11964
-
11965
- GGML_TENSOR_BINARY_OP_LOCALS;
11966
-
11967
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
11968
- GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
11969
- GGML_ASSERT(nb0 == sizeof(float));
11970
-
11971
- const int N = ne12;
11972
- const int OL = ne11;
11973
-
11974
- const int OC = ne02;
11975
- const int IC = ne01;
11976
- const int K = ne00;
11977
-
11978
- const int ith = params->ith;
11979
- const int nth = params->nth;
11980
-
11981
- int64_t m = OC;
11982
- int64_t n = OL;
11983
- int64_t k = IC * K;
11984
-
11985
- // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
11986
- for (int i = 0; i < N; i++) {
11987
- ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
11988
- ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
11989
- float * C = (float *)dst->data + i * m * n; // [m, n]
11990
-
11991
- gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
11992
- }
11993
- }
11994
-
11995
- static void ggml_compute_forward_conv_1d(
11996
- const struct ggml_compute_params * params,
11997
- const struct ggml_tensor * src0,
11998
- const struct ggml_tensor * src1,
11999
- struct ggml_tensor * dst) {
12000
- switch(src0->type) {
12001
- case GGML_TYPE_F16:
12002
- {
12003
- ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst);
12004
- } break;
12005
- case GGML_TYPE_F32:
12006
- {
12007
- ggml_compute_forward_conv_1d_f32(params, src0, src1, dst);
12008
- } break;
12009
- default:
12010
- {
12011
- GGML_ASSERT(false);
12012
- } break;
12013
- }
12014
- }
12015
-
12016
- static void ggml_compute_forward_conv_1d_stage_0(
12017
- const struct ggml_compute_params * params,
12018
- const struct ggml_tensor * src0,
12019
- const struct ggml_tensor * src1,
12020
- struct ggml_tensor * dst) {
12021
- switch(src0->type) {
12022
- case GGML_TYPE_F16:
12023
- {
12024
- ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst);
12025
- } break;
12026
- default:
12027
- {
12028
- GGML_ASSERT(false);
12029
- } break;
12030
- }
12031
- }
12032
-
12033
- static void ggml_compute_forward_conv_1d_stage_1(
12034
- const struct ggml_compute_params * params,
12035
- const struct ggml_tensor * src0,
12036
- const struct ggml_tensor * src1,
12037
- struct ggml_tensor * dst) {
12038
- switch(src0->type) {
12039
- case GGML_TYPE_F16:
12040
- {
12041
- ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst);
12042
- } break;
12043
- default:
12044
- {
12045
- GGML_ASSERT(false);
12046
- } break;
12047
- }
12048
- }
12049
-
12050
- // ggml_compute_forward_conv_transpose_1d
12051
-
12052
- static void ggml_compute_forward_conv_transpose_1d_f16_f32(
12053
- const struct ggml_compute_params * params,
12054
- const struct ggml_tensor * src0,
12055
- const struct ggml_tensor * src1,
12056
- struct ggml_tensor * dst) {
12057
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
12058
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
12059
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
12060
-
12061
- int64_t t0 = ggml_perf_time_us();
12062
- UNUSED(t0);
12063
-
12064
- GGML_TENSOR_BINARY_OP_LOCALS
12065
-
12066
- const int ith = params->ith;
12067
- const int nth = params->nth;
12068
-
12069
- const int nk = ne00*ne01*ne02;
12070
-
12071
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
12072
- GGML_ASSERT(nb10 == sizeof(float));
12073
-
12074
- if (params->type == GGML_TASK_INIT) {
12075
- memset(params->wdata, 0, params->wsize);
12076
-
12077
- // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
12078
- {
12079
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
12080
-
12081
- for (int64_t i02 = 0; i02 < ne02; i02++) {
12082
- for (int64_t i01 = 0; i01 < ne01; i01++) {
12083
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
12084
- ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
12085
- for (int64_t i00 = 0; i00 < ne00; i00++) {
12086
- dst_data[i00*ne02 + i02] = src[i00];
12087
- }
12088
- }
12089
- }
12090
- }
12091
-
12092
- // permute source data (src1) from (L x Cin) to (Cin x L)
12093
- {
12094
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
12095
- ggml_fp16_t * dst_data = wdata;
12096
-
12097
- for (int64_t i11 = 0; i11 < ne11; i11++) {
12098
- const float * const src = (float *)((char *) src1->data + i11*nb11);
12099
- for (int64_t i10 = 0; i10 < ne10; i10++) {
12100
- dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
12101
- }
12102
- }
12103
- }
12104
-
12105
- // need to zero dst since we are accumulating into it
12106
- memset(dst->data, 0, ggml_nbytes(dst));
12107
-
12108
- return;
12109
- }
12110
-
12111
- if (params->type == GGML_TASK_FINALIZE) {
12112
- return;
12113
- }
12114
-
12115
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
12116
-
12117
- // total rows in dst
12118
- const int nr = ne1;
12119
-
12120
- // rows per thread
12121
- const int dr = (nr + nth - 1)/nth;
12122
-
12123
- // row range for this thread
12124
- const int ir0 = dr*ith;
12125
- const int ir1 = MIN(ir0 + dr, nr);
12126
-
12127
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
12128
- ggml_fp16_t * const wdata_src = wdata + nk;
12129
-
12130
- for (int i1 = ir0; i1 < ir1; i1++) {
12131
- float * dst_data = (float *)((char *) dst->data + i1*nb1);
12132
- ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
12133
- for (int i10 = 0; i10 < ne10; i10++) {
12134
- const int i1n = i10*ne11;
12135
- for (int i00 = 0; i00 < ne00; i00++) {
12136
- float v = 0;
12137
- ggml_vec_dot_f16(ne02, &v,
12138
- (ggml_fp16_t *) wdata_src + i1n,
12139
- (ggml_fp16_t *) wdata_kernel + i00*ne02);
12140
- dst_data[i10*s0 + i00] += v;
12141
- }
12142
- }
12143
- }
12144
- }
12145
-
12146
- static void ggml_compute_forward_conv_transpose_1d_f32(
12147
- const struct ggml_compute_params * params,
12148
- const struct ggml_tensor * src0,
12149
- const struct ggml_tensor * src1,
12150
- struct ggml_tensor * dst) {
12151
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
12152
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
12153
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
12154
-
12155
- int64_t t0 = ggml_perf_time_us();
12156
- UNUSED(t0);
12157
-
12158
- GGML_TENSOR_BINARY_OP_LOCALS
12159
-
12160
- const int ith = params->ith;
12161
- const int nth = params->nth;
12162
-
12163
- const int nk = ne00*ne01*ne02;
12164
-
12165
- GGML_ASSERT(nb00 == sizeof(float));
12166
- GGML_ASSERT(nb10 == sizeof(float));
12167
-
12168
- if (params->type == GGML_TASK_INIT) {
12169
- memset(params->wdata, 0, params->wsize);
12170
-
12171
- // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
12172
- {
12173
- float * const wdata = (float *) params->wdata + 0;
12174
-
12175
- for (int64_t i02 = 0; i02 < ne02; i02++) {
12176
- for (int64_t i01 = 0; i01 < ne01; i01++) {
12177
- const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
12178
- float * dst_data = wdata + i01*ne00*ne02;
12179
- for (int64_t i00 = 0; i00 < ne00; i00++) {
12180
- dst_data[i00*ne02 + i02] = src[i00];
12181
- }
12182
- }
12183
- }
12184
- }
12185
-
12186
- // prepare source data (src1)
12187
- {
12188
- float * const wdata = (float *) params->wdata + nk;
12189
- float * dst_data = wdata;
12190
-
12191
- for (int64_t i11 = 0; i11 < ne11; i11++) {
12192
- const float * const src = (float *)((char *) src1->data + i11*nb11);
12193
- for (int64_t i10 = 0; i10 < ne10; i10++) {
12194
- dst_data[i10*ne11 + i11] = src[i10];
12195
- }
12196
- }
12197
- }
12198
-
12199
- // need to zero dst since we are accumulating into it
12200
- memset(dst->data, 0, ggml_nbytes(dst));
12201
-
12202
- return;
12203
- }
12204
-
12205
- if (params->type == GGML_TASK_FINALIZE) {
12206
- return;
12207
- }
12208
-
12209
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
12210
-
12211
- // total rows in dst
12212
- const int nr = ne1;
12213
-
12214
- // rows per thread
12215
- const int dr = (nr + nth - 1)/nth;
12216
-
12217
- // row range for this thread
12218
- const int ir0 = dr*ith;
12219
- const int ir1 = MIN(ir0 + dr, nr);
12220
-
12221
- float * const wdata = (float *) params->wdata + 0;
12222
- float * const wdata_src = wdata + nk;
12223
-
12224
- for (int i1 = ir0; i1 < ir1; i1++) {
12225
- float * dst_data = (float *)((char *) dst->data + i1*nb1);
12226
- float * wdata_kernel = wdata + i1*ne02*ne00;
12227
- for (int i10 = 0; i10 < ne10; i10++) {
12228
- const int i1n = i10*ne11;
12229
- for (int i00 = 0; i00 < ne00; i00++) {
12230
- float v = 0;
12231
- ggml_vec_dot_f32(ne02, &v,
12232
- wdata_src + i1n,
12233
- wdata_kernel + i00*ne02);
12234
- dst_data[i10*s0 + i00] += v;
12235
- }
12236
- }
12237
- }
12238
- }
12239
-
12240
- static void ggml_compute_forward_conv_transpose_1d(
12241
- const struct ggml_compute_params * params,
12242
- const struct ggml_tensor * src0,
12243
- const struct ggml_tensor * src1,
12244
- struct ggml_tensor * dst) {
12245
- switch (src0->type) {
12246
- case GGML_TYPE_F16:
12247
- {
12248
- ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst);
12249
- } break;
12250
- case GGML_TYPE_F32:
12251
- {
12252
- ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst);
12253
- } break;
12254
- default:
12255
- {
12256
- GGML_ASSERT(false);
12257
- } break;
12258
- }
12259
- }
12260
-
12261
- // ggml_compute_forward_conv_2d
12262
-
12263
- // src0: kernel [OC, IC, KH, KW]
12264
- // src1: image [N, IC, IH, IW]
12265
- // dst: result [N, OH, OW, IC*KH*KW]
12266
- static void ggml_compute_forward_conv_2d_stage_0_f32(
12267
- const struct ggml_compute_params * params,
12268
- const struct ggml_tensor * src0,
12269
- const struct ggml_tensor * src1,
12270
- struct ggml_tensor * dst) {
12271
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
12272
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
12273
- GGML_ASSERT( dst->type == GGML_TYPE_F16);
12274
-
12275
- int64_t t0 = ggml_perf_time_us();
12276
- UNUSED(t0);
12277
-
12278
- GGML_TENSOR_BINARY_OP_LOCALS;
12279
-
12280
- const int64_t N = ne13;
12281
- const int64_t IC = ne12;
12282
- const int64_t IH = ne11;
12283
- const int64_t IW = ne10;
12284
-
12285
- // const int64_t OC = ne03;
12286
- // const int64_t IC = ne02;
12287
- const int64_t KH = ne01;
12288
- const int64_t KW = ne00;
12289
-
12290
- const int64_t OH = ne2;
12291
- const int64_t OW = ne1;
12292
-
12293
- const int ith = params->ith;
12294
- const int nth = params->nth;
12295
-
12296
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
12297
- const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
12298
- const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
12299
- const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
12300
- const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
12301
- const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
12302
-
12303
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
12304
- GGML_ASSERT(nb10 == sizeof(float));
12305
-
12306
- if (params->type == GGML_TASK_INIT) {
12307
- memset(dst->data, 0, ggml_nbytes(dst));
12308
- return;
12309
- }
12310
-
12311
- if (params->type == GGML_TASK_FINALIZE) {
12312
- return;
12313
- }
12314
-
12315
- // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
12316
- {
12317
- ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
12318
-
12319
- for (int64_t in = 0; in < N; in++) {
12320
- for (int64_t ioh = 0; ioh < OH; ioh++) {
12321
- for (int64_t iow = 0; iow < OW; iow++) {
12322
- for (int64_t iic = ith; iic < IC; iic+=nth) {
12323
-
12324
- // micro kernel
12325
- ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
12326
- const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
12327
-
12328
- for (int64_t ikh = 0; ikh < KH; ikh++) {
12329
- for (int64_t ikw = 0; ikw < KW; ikw++) {
12330
- const int64_t iiw = iow*s0 + ikw*d0 - p0;
12331
- const int64_t iih = ioh*s1 + ikh*d1 - p1;
12332
-
12333
- if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
12334
- dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
12335
- }
12336
- }
12337
- }
12338
- }
12339
- }
12340
- }
12341
- }
12342
- }
12343
- }
12344
-
12345
- // gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
12346
- // src0: [OC, IC, KH, KW]
12347
- // src1: [N, OH, OW, IC * KH * KW]
12348
- // result: [N, OC, OH, OW]
12349
- static void ggml_compute_forward_conv_2d_stage_1_f16(
12350
- const struct ggml_compute_params * params,
12351
- const struct ggml_tensor * src0,
12352
- const struct ggml_tensor * src1,
12353
- struct ggml_tensor * dst) {
12354
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
12355
- GGML_ASSERT(src1->type == GGML_TYPE_F16);
12356
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
12357
-
12358
- int64_t t0 = ggml_perf_time_us();
12359
- UNUSED(t0);
12360
-
12361
- if (params->type == GGML_TASK_INIT) {
12362
- return;
12363
- }
12364
-
12365
- if (params->type == GGML_TASK_FINALIZE) {
12366
- return;
12367
- }
12368
-
12369
- GGML_TENSOR_BINARY_OP_LOCALS;
12370
-
12371
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
12372
- GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
12373
- GGML_ASSERT(nb0 == sizeof(float));
12374
-
12375
- const int N = ne13;
12376
- const int OH = ne12;
12377
- const int OW = ne11;
12378
-
12379
- const int OC = ne03;
12380
- const int IC = ne02;
12381
- const int KH = ne01;
12382
- const int KW = ne00;
12383
-
12384
- const int ith = params->ith;
12385
- const int nth = params->nth;
12386
-
12387
- int64_t m = OC;
12388
- int64_t n = OH * OW;
12389
- int64_t k = IC * KH * KW;
12390
-
12391
- // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
12392
- for (int i = 0; i < N; i++) {
12393
- ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
12394
- ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
12395
- float * C = (float *)dst->data + i * m * n; // [m, n]
12396
-
12397
- gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
12398
- }
12399
- }
12400
-
12401
- static void ggml_compute_forward_conv_2d_f16_f32(
12402
- const struct ggml_compute_params * params,
12403
- const struct ggml_tensor * src0,
12404
- const struct ggml_tensor * src1,
12405
- struct ggml_tensor * dst) {
12406
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
12407
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
12408
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
12409
-
12410
- int64_t t0 = ggml_perf_time_us();
12411
- UNUSED(t0);
12412
-
12413
- GGML_TENSOR_BINARY_OP_LOCALS
12414
-
12415
- // src1: image [N, IC, IH, IW]
12416
- // src0: kernel [OC, IC, KH, KW]
12417
- // dst: result [N, OC, OH, OW]
12418
- // ne12: IC
12419
- // ne0: OW
12420
- // ne1: OH
12421
- // nk0: KW
12422
- // nk1: KH
12423
- // ne13: N
12424
-
12425
- const int N = ne13;
12426
- const int IC = ne12;
12427
- const int IH = ne11;
12428
- const int IW = ne10;
12429
-
12430
- const int OC = ne03;
12431
- // const int IC = ne02;
12432
- const int KH = ne01;
12433
- const int KW = ne00;
12434
-
12435
- const int OH = ne1;
12436
- const int OW = ne0;
12437
-
12438
- const int ith = params->ith;
12439
- const int nth = params->nth;
12440
-
12441
- // const int nk0 = ne00;
12442
- // const int nk1 = ne01;
12443
-
12444
- // size of the convolution row - the kernel size unrolled across all channels
12445
- // const int ew0 = nk0*nk1*ne02;
12446
- // ew0: IC*KH*KW
12447
-
12448
- const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
12449
- const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
12450
- const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
12451
- const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
12452
- const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
12453
- const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
12454
-
12455
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
12456
- GGML_ASSERT(nb10 == sizeof(float));
12457
-
12458
- if (params->type == GGML_TASK_INIT) {
12459
- memset(params->wdata, 0, params->wsize);
12460
-
12461
- // prepare source data (src1)
12462
- // im2col: [N, IC, IH, IW] => [N*OH*OW, IC*KH*KW]
12463
-
12464
- {
12465
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
12466
-
12467
- for (int in = 0; in < N; in++) {
12468
- for (int iic = 0; iic < IC; iic++) {
12469
- for (int ioh = 0; ioh < OH; ioh++) {
12470
- for (int iow = 0; iow < OW; iow++) {
12471
-
12472
- // micro kernel
12473
- ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
12474
- const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
12475
-
12476
- for (int ikh = 0; ikh < KH; ikh++) {
12477
- for (int ikw = 0; ikw < KW; ikw++) {
12478
- const int iiw = iow*s0 + ikw*d0 - p0;
12479
- const int iih = ioh*s1 + ikh*d1 - p1;
12480
-
12481
- if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
12482
- dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
12483
- }
12484
- }
12485
- }
12486
- }
12487
- }
12488
- }
12489
- }
12490
- }
12491
-
12492
- return;
12493
- }
12494
-
12495
- if (params->type == GGML_TASK_FINALIZE) {
12496
- return;
12497
- }
12498
-
12499
- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
12500
- // wdata: [N*OH*OW, IC*KH*KW]
12501
- // dst: result [N, OC, OH, OW]
12502
- // src0: kernel [OC, IC, KH, KW]
12503
-
12504
- int64_t m = OC;
12505
- int64_t n = OH * OW;
12506
- int64_t k = IC * KH * KW;
12507
-
12508
- // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
12509
- for (int i = 0; i < N; i++) {
12510
- ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
12511
- ggml_fp16_t * B = (ggml_fp16_t *)wdata + i * m * k; // [n, k]
12512
- float * C = (float *)dst->data + i * m * n; // [m * k]
12513
-
12514
- gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
12515
- }
12516
- }
12517
-
12518
- static void ggml_compute_forward_conv_2d(
12519
- const struct ggml_compute_params * params,
12520
- const struct ggml_tensor * src0,
12521
- const struct ggml_tensor * src1,
12522
- struct ggml_tensor * dst) {
12523
- switch (src0->type) {
12524
- case GGML_TYPE_F16:
12525
- {
12526
- ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, dst);
12527
- } break;
12528
- case GGML_TYPE_F32:
12529
- {
12530
- //ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
12531
- GGML_ASSERT(false);
12532
- } break;
12533
- default:
12534
- {
12535
- GGML_ASSERT(false);
12536
- } break;
12537
- }
12538
- }
12539
-
12540
- static void ggml_compute_forward_conv_2d_stage_0(
12541
- const struct ggml_compute_params * params,
12542
- const struct ggml_tensor * src0,
12543
- const struct ggml_tensor * src1,
12544
- struct ggml_tensor * dst) {
12545
- switch (src0->type) {
12546
- case GGML_TYPE_F16:
12547
- {
12548
- ggml_compute_forward_conv_2d_stage_0_f32(params, src0, src1, dst);
12549
- } break;
12550
- case GGML_TYPE_F32:
12551
- {
12552
- GGML_ASSERT(false);
12553
- } break;
12554
- default:
12555
- {
12556
- GGML_ASSERT(false);
12557
- } break;
12558
- }
12559
- }
12560
-
12561
- static void ggml_compute_forward_conv_2d_stage_1(
12562
  const struct ggml_compute_params * params,
12563
  const struct ggml_tensor * src0,
12564
  const struct ggml_tensor * src1,
@@ -12566,7 +11801,7 @@ static void ggml_compute_forward_conv_2d_stage_1(
12566
  switch (src0->type) {
12567
  case GGML_TYPE_F16:
12568
  {
12569
- ggml_compute_forward_conv_2d_stage_1_f16(params, src0, src1, dst);
12570
  } break;
12571
  case GGML_TYPE_F32:
12572
  {
@@ -14783,33 +14018,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
14783
  {
14784
  ggml_compute_forward_clamp(params, tensor->src[0], tensor);
14785
  } break;
14786
- case GGML_OP_CONV_1D:
14787
- {
14788
- ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor);
14789
- } break;
14790
- case GGML_OP_CONV_1D_STAGE_0:
14791
- {
14792
- ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
14793
- } break;
14794
- case GGML_OP_CONV_1D_STAGE_1:
14795
- {
14796
- ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
14797
- } break;
14798
  case GGML_OP_CONV_TRANSPOSE_1D:
14799
  {
14800
  ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor);
14801
  } break;
14802
- case GGML_OP_CONV_2D:
14803
- {
14804
- ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
14805
- } break;
14806
- case GGML_OP_CONV_2D_STAGE_0:
14807
- {
14808
- ggml_compute_forward_conv_2d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
14809
- } break;
14810
- case GGML_OP_CONV_2D_STAGE_1:
14811
  {
14812
- ggml_compute_forward_conv_2d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
14813
  } break;
14814
  case GGML_OP_CONV_TRANSPOSE_2D:
14815
  {
@@ -15780,31 +14995,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15780
  {
15781
  GGML_ASSERT(false); // TODO: not implemented
15782
  } break;
15783
- case GGML_OP_CONV_1D:
15784
- {
15785
- GGML_ASSERT(false); // TODO: not implemented
15786
- } break;
15787
- case GGML_OP_CONV_1D_STAGE_0:
15788
- {
15789
- GGML_ASSERT(false); // TODO: not implemented
15790
- } break;
15791
- case GGML_OP_CONV_1D_STAGE_1:
15792
- {
15793
- GGML_ASSERT(false); // TODO: not implemented
15794
- } break;
15795
  case GGML_OP_CONV_TRANSPOSE_1D:
15796
  {
15797
  GGML_ASSERT(false); // TODO: not implemented
15798
  } break;
15799
- case GGML_OP_CONV_2D:
15800
- {
15801
- GGML_ASSERT(false); // TODO: not implemented
15802
- } break;
15803
- case GGML_OP_CONV_2D_STAGE_0:
15804
- {
15805
- GGML_ASSERT(false); // TODO: not implemented
15806
- } break;
15807
- case GGML_OP_CONV_2D_STAGE_1:
15808
  {
15809
  GGML_ASSERT(false); // TODO: not implemented
15810
  } break;
@@ -16533,31 +15728,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
16533
  {
16534
  n_tasks = 1; //TODO
16535
  } break;
16536
- case GGML_OP_CONV_1D:
16537
- {
16538
- n_tasks = n_threads;
16539
- } break;
16540
- case GGML_OP_CONV_1D_STAGE_0:
16541
- {
16542
- n_tasks = n_threads;
16543
- } break;
16544
- case GGML_OP_CONV_1D_STAGE_1:
16545
- {
16546
- n_tasks = n_threads;
16547
- } break;
16548
  case GGML_OP_CONV_TRANSPOSE_1D:
16549
  {
16550
  n_tasks = n_threads;
16551
  } break;
16552
- case GGML_OP_CONV_2D:
16553
- {
16554
- n_tasks = n_threads;
16555
- } break;
16556
- case GGML_OP_CONV_2D_STAGE_0:
16557
- {
16558
- n_tasks = n_threads;
16559
- } break;
16560
- case GGML_OP_CONV_2D_STAGE_1:
16561
  {
16562
  n_tasks = n_threads;
16563
  } break;
@@ -16642,6 +15817,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
16642
  } break;
16643
  default:
16644
  {
 
16645
  GGML_ASSERT(false);
16646
  } break;
16647
  }
@@ -16844,38 +16020,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
16844
  cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
16845
  }
16846
  } break;
16847
- case GGML_OP_CONV_1D:
16848
- {
16849
- GGML_ASSERT(node->src[0]->ne[3] == 1);
16850
- GGML_ASSERT(node->src[1]->ne[2] == 1);
16851
- GGML_ASSERT(node->src[1]->ne[3] == 1);
16852
-
16853
- const int64_t ne00 = node->src[0]->ne[0];
16854
- const int64_t ne01 = node->src[0]->ne[1];
16855
- const int64_t ne02 = node->src[0]->ne[2];
16856
-
16857
- const int64_t ne10 = node->src[1]->ne[0];
16858
- const int64_t ne11 = node->src[1]->ne[1];
16859
-
16860
- const int64_t ne0 = node->ne[0];
16861
- const int64_t ne1 = node->ne[1];
16862
- const int64_t nk = ne00;
16863
- const int64_t ew0 = nk * ne01;
16864
-
16865
- UNUSED(ne02);
16866
- UNUSED(ne10);
16867
- UNUSED(ne11);
16868
-
16869
- if (node->src[0]->type == GGML_TYPE_F16 &&
16870
- node->src[1]->type == GGML_TYPE_F32) {
16871
- cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
16872
- } else if (node->src[0]->type == GGML_TYPE_F32 &&
16873
- node->src[1]->type == GGML_TYPE_F32) {
16874
- cur = sizeof(float)*(ne0*ne1*ew0);
16875
- } else {
16876
- GGML_ASSERT(false);
16877
- }
16878
- } break;
16879
  case GGML_OP_CONV_TRANSPOSE_1D:
16880
  {
16881
  GGML_ASSERT(node->src[0]->ne[3] == 1);
@@ -16901,37 +16045,9 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
16901
  GGML_ASSERT(false);
16902
  }
16903
  } break;
16904
- case GGML_OP_CONV_2D:
16905
  {
16906
- const int64_t ne00 = node->src[0]->ne[0]; // W
16907
- const int64_t ne01 = node->src[0]->ne[1]; // H
16908
- const int64_t ne02 = node->src[0]->ne[2]; // C
16909
- const int64_t ne03 = node->src[0]->ne[3]; // N
16910
-
16911
- const int64_t ne10 = node->src[1]->ne[0]; // W
16912
- const int64_t ne11 = node->src[1]->ne[1]; // H
16913
- const int64_t ne12 = node->src[1]->ne[2]; // C
16914
-
16915
- const int64_t ne0 = node->ne[0];
16916
- const int64_t ne1 = node->ne[1];
16917
- const int64_t ne2 = node->ne[2];
16918
- const int64_t ne3 = node->ne[3];
16919
- const int64_t nk = ne00*ne01;
16920
- const int64_t ew0 = nk * ne02;
16921
-
16922
- UNUSED(ne03);
16923
- UNUSED(ne2);
16924
-
16925
- if (node->src[0]->type == GGML_TYPE_F16 &&
16926
- node->src[1]->type == GGML_TYPE_F32) {
16927
- // im2col: [N*OH*OW, IC*KH*KW]
16928
- cur = sizeof(ggml_fp16_t)*(ne3*ne0*ne1*ew0);
16929
- } else if (node->src[0]->type == GGML_TYPE_F32 &&
16930
- node->src[1]->type == GGML_TYPE_F32) {
16931
- cur = sizeof(float)* (ne10*ne11*ne12);
16932
- } else {
16933
- GGML_ASSERT(false);
16934
- }
16935
  } break;
16936
  case GGML_OP_CONV_TRANSPOSE_2D:
16937
  {
 
1634
  "ROPE_BACK",
1635
  "ALIBI",
1636
  "CLAMP",
 
 
 
1637
  "CONV_TRANSPOSE_1D",
1638
+ "IM2COL",
 
 
1639
  "CONV_TRANSPOSE_2D",
1640
  "POOL_1D",
1641
  "POOL_2D",
 
1666
  "CROSS_ENTROPY_LOSS_BACK",
1667
  };
1668
 
1669
+ static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
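Both op string tables (GGML_OP_NAME above and GGML_OP_SYMBOL below) have to stay in sync with the op enum in ggml.h; the asserts now check for 68 entries because the six CONV_* ops are dropped and only IM2COL takes their place, a net change of minus five.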
1670
 
1671
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1672
  "none",
 
1716
  "rope_back(x)",
1717
  "alibi(x)",
1718
  "clamp(x)",
 
 
 
1719
  "conv_transpose_1d(x)",
1720
+ "im2col(x)",
 
 
1721
  "conv_transpose_2d(x)",
1722
  "pool_1d(x)",
1723
  "pool_2d(x)",
 
1748
  "cross_entropy_loss_back(x,y)",
1749
  };
1750
 
1751
+ static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
1752
 
1753
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
1754
 
 
1776
  p[GGML_OP_GET_ROWS_BACK ] = true;
1777
  p[GGML_OP_DIAG_MASK_INF ] = true;
1778
  p[GGML_OP_DIAG_MASK_ZERO ] = true;
 
 
 
1779
  p[GGML_OP_CONV_TRANSPOSE_1D ] = true;
 
 
 
1780
  p[GGML_OP_CONV_TRANSPOSE_2D ] = true;
1781
  p[GGML_OP_FLASH_ATTN_BACK ] = true;
1782
  p[GGML_OP_CROSS_ENTROPY_LOSS ] = true;
 
5121
  return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
5122
  }
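A quick sanity check of the output-size helper above. The sizes are illustrative only (they roughly correspond to the two kernel-3 encoder convolutions used by whisper.cpp, one with stride 1 and one with stride 2); this is a minimal sketch, not code from this diff:

    #include <assert.h>

    // same formula as ggml_calc_conv_output_size() above
    static int conv_out_size(int ins, int ks, int s, int p, int d) {
        return (ins + 2*p - d*(ks - 1) - 1)/s + 1;
    }

    int main(void) {
        assert(conv_out_size(3000, 3, /*s*/1, /*p*/1, /*d*/1) == 3000); // stride 1, pad 1: length preserved
        assert(conv_out_size(3000, 3, /*s*/2, /*p*/1, /*d*/1) == 1500); // stride 2, pad 1: length halved
        return 0;
    }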
5123
 
 
 
 
5124
  GGML_API struct ggml_tensor * ggml_conv_1d(
5125
  struct ggml_context * ctx,
5126
  struct ggml_tensor * a,
 
5128
  int s0,
5129
  int p0,
5130
  int d0) {
5131
+ struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
 
 
 
5132
 
5133
+ struct ggml_tensor * result =
5134
+ ggml_mul_mat(ctx,
5135
+ ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
5136
+ ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K]
 
 
 
 
 
 
5137
 
5138
+ result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
 
 
 
5139
 
5140
+ return result;
5141
+ }
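To make the decomposition above concrete, a shape walk-through with illustrative sizes (chosen to resemble the first whisper encoder convolution referenced later in this diff: K = 3, IC = n_mels = 80, OC = n_audio_state = 384, plus a made-up signal length of 3000; only the shapes matter):

    // a (kernel): [OC=384, IC=80, K=3] (F16)    b (signal): [N=1, IC=80, IL=3000] (F32)
    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, /*s0*/1, 0, /*p0*/1, 0, /*d0*/1, 0, false);
    // im2col: [N=1, OL=3000, IC*K=240], always allocated as F16

    struct ggml_tensor * cur =
        ggml_mul_mat(ctx,
            ggml_reshape_2d(ctx, im2col, 240, 1*3000),   // [N*OL, IC*K]
            ggml_reshape_2d(ctx, a,      240, 384));     // [OC,   IC*K]
    // cur: ne = [3000, 384], i.e. [N*OL, OC]

    cur = ggml_reshape_3d(ctx, cur, 3000, 384, 1);       // [N=1, OC=384, OL=3000]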
 
 
 
 
5142
 
5143
  // ggml_conv_1d_ph
5144
 
 
5201
  // a: [OC,IC, KH, KW]
5202
  // b: [N, IC, IH, IW]
5203
  // result: [N, OH, OW, IC*KH*KW]
5204
+ struct ggml_tensor * ggml_im2col(
5205
  struct ggml_context * ctx,
5206
  struct ggml_tensor * a,
5207
  struct ggml_tensor * b,
 
5210
  int p0,
5211
  int p1,
5212
  int d0,
5213
+ int d1,
5214
+ bool is_2D) {
5215
 
5216
+ if(is_2D) {
5217
+ GGML_ASSERT(a->ne[2] == b->ne[2]);
5218
+ } else {
5219
+ GGML_ASSERT(a->ne[1] == b->ne[1]);
5220
+ }
5221
  bool is_node = false;
5222
 
5223
  if (a->grad || b->grad) {
 
5225
  is_node = true;
5226
  }
5227
 
5228
+ const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
5229
+ const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
5230
 
5231
  const int64_t ne[4] = {
5232
+ is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
5233
  OW,
5234
+ is_2D ? OH : b->ne[2],
5235
+ is_2D ? b->ne[3] : 1,
5236
  };
 
5237
 
5238
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
5239
+ int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
5240
  ggml_set_op_params(result, params, sizeof(params));
5241
 
5242
+ result->op = GGML_OP_IM2COL;
 
 
 
 
 
5243
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5244
  result->src[0] = a;
5245
  result->src[1] = b;
5246
 
5247
  return result;
 
5248
  }
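Note that the im2col result is always created as GGML_TYPE_F16, and the seven int32 values packed into op_params are what the compute side unpacks again, with the same indices, in ggml_compute_forward_im2col_f16() further down:

    const int32_t s0    = ((const int32_t *) dst->op_params)[0];
    const int32_t s1    = ((const int32_t *) dst->op_params)[1];
    const int32_t p0    = ((const int32_t *) dst->op_params)[2];
    const int32_t p1    = ((const int32_t *) dst->op_params)[3];
    const int32_t d0    = ((const int32_t *) dst->op_params)[4];
    const int32_t d1    = ((const int32_t *) dst->op_params)[5];
    const bool    is_2D = ((const int32_t *) dst->op_params)[6] == 1;

Out-of-range taps are written as explicit zeros by that kernel, so the result no longer needs the INIT-stage memset that the old stage_0 implementation relied on.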
5249
 
5250
  // a: [OC,IC, KH, KW]
5251
  // b: [N, IC, IH, IW]
5252
  // result: [N, OC, OH, OW]
5253
  struct ggml_tensor * ggml_conv_2d(
5254
+ struct ggml_context * ctx,
5255
+ struct ggml_tensor * a,
5256
+ struct ggml_tensor * b,
5257
+ int s0,
5258
+ int s1,
5259
+ int p0,
5260
+ int p1,
5261
+ int d0,
5262
+ int d1) {
5263
+ struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
5264
 
5265
+ struct ggml_tensor * result =
5266
+ ggml_mul_mat(ctx,
5267
+ ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
5268
+ ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW]
5269
 
5270
+ result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW]
5271
 
5272
+ return result;
5273
  }
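The 2D wrapper follows the same pattern as the 1D case, with one GEMM per image of m = OC, n = OH*OW, k = IC*KH*KW. A shape walk-through with made-up sizes, as a sketch only:

    // a (kernel): [OC=16, IC=3, KH=3, KW=3]    b (image): [N=2, IC=3, IH=64, IW=64]
    struct ggml_tensor * out = ggml_conv_2d(ctx, a, b, /*s*/1, 1, /*p*/1, 1, /*d*/1, 1);
    // im2col inside: [N=2, OH=64, OW=64, IC*KH*KW=27], F16
    // mul_mat:       [OC=16, IC*KH*KW=27] x [N*OH*OW=8192, IC*KH*KW=27] -> ne = [8192, 16]
    // reshape_4d:    out = [N=2, OC=16, OH=64, OW=64]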
5274
 
5275
  // ggml_conv_2d_sk_p0
 
9364
  // TODO: find the optimal values for these
9365
  if (ggml_is_contiguous(src0) &&
9366
  ggml_is_contiguous(src1) &&
9367
+ src0->type == GGML_TYPE_F32 &&
9368
+ src1->type == GGML_TYPE_F32 &&
9369
  (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
9370
 
9371
  /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
 
9376
  }
9377
  #endif
9378
 
9379
+
9380
  static void ggml_compute_forward_mul_mat(
9381
  const struct ggml_compute_params * params,
9382
  const struct ggml_tensor * src0,
 
9405
 
9406
  // we don't support permuted src0 or src1
9407
  GGML_ASSERT(nb00 == ggml_type_size(type));
9408
+ GGML_ASSERT(nb10 == ggml_type_size(src1->type));
9409
 
9410
  // dst cannot be transposed or permuted
9411
  GGML_ASSERT(nb0 == sizeof(float));
 
11497
  }
11498
  }
11499
 
11500
+ // ggml_compute_forward_conv_transpose_1d
11501
 
11502
+ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
11503
  const struct ggml_compute_params * params,
11504
  const struct ggml_tensor * src0,
11505
  const struct ggml_tensor * src1,
 
11516
  const int ith = params->ith;
11517
  const int nth = params->nth;
11518
 
11519
+ const int nk = ne00*ne01*ne02;
 
 
 
 
 
 
 
11520
 
11521
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
11522
  GGML_ASSERT(nb10 == sizeof(float));
 
11524
  if (params->type == GGML_TASK_INIT) {
11525
  memset(params->wdata, 0, params->wsize);
11526
 
11527
+ // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
11528
+ {
11529
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
11530
 
11531
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
11532
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
11533
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
11534
+ ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
11535
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
11536
+ dst_data[i00*ne02 + i02] = src[i00];
11537
+ }
11538
+ }
11539
+ }
11540
+ }
11541
 
11542
+ // permute source data (src1) from (L x Cin) to (Cin x L)
11543
+ {
11544
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
11545
+ ggml_fp16_t * dst_data = wdata;
11546
 
11547
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
11548
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
11549
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
11550
+ dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
11551
  }
11552
  }
11553
  }
11554
 
11555
+ // need to zero dst since we are accumulating into it
11556
+ memset(dst->data, 0, ggml_nbytes(dst));
11557
+
11558
  return;
11559
  }
11560
 
 
11562
  return;
11563
  }
11564
 
11565
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
11566
+
11567
  // total rows in dst
11568
+ const int nr = ne1;
11569
 
11570
  // rows per thread
11571
  const int dr = (nr + nth - 1)/nth;
 
11574
  const int ir0 = dr*ith;
11575
  const int ir1 = MIN(ir0 + dr, nr);
11576
 
11577
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
11578
+ ggml_fp16_t * const wdata_src = wdata + nk;
 
 
 
11579
 
11580
+ for (int i1 = ir0; i1 < ir1; i1++) {
11581
+ float * dst_data = (float *)((char *) dst->data + i1*nb1);
11582
+ ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
11583
+ for (int i10 = 0; i10 < ne10; i10++) {
11584
+ const int i1n = i10*ne11;
11585
+ for (int i00 = 0; i00 < ne00; i00++) {
11586
+ float v = 0;
11587
+ ggml_vec_dot_f16(ne02, &v,
11588
+ (ggml_fp16_t *) wdata_src + i1n,
11589
+ (ggml_fp16_t *) wdata_kernel + i00*ne02);
11590
+ dst_data[i10*s0 + i00] += v;
11591
  }
11592
  }
11593
  }
11594
  }
11595
 
11596
+ static void ggml_compute_forward_conv_transpose_1d_f32(
11597
  const struct ggml_compute_params * params,
11598
  const struct ggml_tensor * src0,
11599
  const struct ggml_tensor * src1,
 
11610
  const int ith = params->ith;
11611
  const int nth = params->nth;
11612
 
11613
+ const int nk = ne00*ne01*ne02;
 
 
 
 
 
 
11614
 
11615
  GGML_ASSERT(nb00 == sizeof(float));
11616
  GGML_ASSERT(nb10 == sizeof(float));
 
11618
  if (params->type == GGML_TASK_INIT) {
11619
  memset(params->wdata, 0, params->wsize);
11620
 
11621
+ // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
11622
+ {
11623
+ float * const wdata = (float *) params->wdata + 0;
11624
 
11625
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
11626
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
11627
+ const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
11628
+ float * dst_data = wdata + i01*ne00*ne02;
11629
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
11630
+ dst_data[i00*ne02 + i02] = src[i00];
11631
+ }
11632
+ }
11633
+ }
11634
+ }
11635
 
11636
+ // prepare source data (src1)
11637
+ {
11638
+ float * const wdata = (float *) params->wdata + nk;
11639
+ float * dst_data = wdata;
11640
 
11641
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
11642
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
11643
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
11644
+ dst_data[i10*ne11 + i11] = src[i10];
11645
  }
11646
  }
11647
  }
11648
 
11649
+ // need to zero dst since we are accumulating into it
11650
+ memset(dst->data, 0, ggml_nbytes(dst));
11651
+
11652
  return;
11653
  }
11654
 
 
11656
  return;
11657
  }
11658
 
11659
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
11660
+
11661
  // total rows in dst
11662
+ const int nr = ne1;
11663
 
11664
  // rows per thread
11665
  const int dr = (nr + nth - 1)/nth;
 
11668
  const int ir0 = dr*ith;
11669
  const int ir1 = MIN(ir0 + dr, nr);
11670
 
11671
+ float * const wdata = (float *) params->wdata + 0;
11672
+ float * const wdata_src = wdata + nk;
 
 
 
11673
 
11674
+ for (int i1 = ir0; i1 < ir1; i1++) {
11675
+ float * dst_data = (float *)((char *) dst->data + i1*nb1);
11676
+ float * wdata_kernel = wdata + i1*ne02*ne00;
11677
+ for (int i10 = 0; i10 < ne10; i10++) {
11678
+ const int i1n = i10*ne11;
11679
+ for (int i00 = 0; i00 < ne00; i00++) {
11680
+ float v = 0;
11681
+ ggml_vec_dot_f32(ne02, &v,
11682
+ wdata_src + i1n,
11683
+ wdata_kernel + i00*ne02);
11684
+ dst_data[i10*s0 + i00] += v;
11685
  }
11686
  }
11687
  }
11688
  }
11689
 
11690
+ static void ggml_compute_forward_conv_transpose_1d(
11691
+ const struct ggml_compute_params * params,
11692
+ const struct ggml_tensor * src0,
11693
+ const struct ggml_tensor * src1,
11694
+ struct ggml_tensor * dst) {
11695
+ switch (src0->type) {
11696
+ case GGML_TYPE_F16:
11697
+ {
11698
+ ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst);
11699
+ } break;
11700
+ case GGML_TYPE_F32:
11701
+ {
11702
+ ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst);
11703
+ } break;
11704
+ default:
11705
+ {
11706
+ GGML_ASSERT(false);
11707
+ } break;
 
 
 
 
11708
  }
11709
  }
11710
 
11711
+ // src0: kernel [OC, IC, KH, KW]
11712
+ // src1: image [N, IC, IH, IW]
11713
+ // dst: result [N, OH, OW, IC*KH*KW]
11714
+ static void ggml_compute_forward_im2col_f16(
11715
  const struct ggml_compute_params * params,
11716
  const struct ggml_tensor * src0,
11717
  const struct ggml_tensor * src1,
 
11725
 
11726
  GGML_TENSOR_BINARY_OP_LOCALS;
11727
 
11728
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
11729
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
11730
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
11731
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
11732
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
11733
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
11734
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
11735
 
11736
  const int ith = params->ith;
11737
  const int nth = params->nth;
11738
 
11739
+ const int64_t N = is_2D ? ne13 : ne12;
11740
+ const int64_t IC = is_2D ? ne12 : ne11;
11741
+ const int64_t IH = is_2D ? ne11 : 1;
11742
+ const int64_t IW = ne10;
11743
+
11744
+ const int64_t KH = is_2D ? ne01 : 1;
11745
+ const int64_t KW = ne00;
11746
+
11747
+ const int64_t OH = is_2D ? ne2 : 1;
11748
+ const int64_t OW = ne1;
11749
+
11750
+ int ofs0 = is_2D ? nb13 : nb12;
11751
+ int ofs1 = is_2D ? nb12 : nb11;
11752
 
11753
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
11754
  GGML_ASSERT(nb10 == sizeof(float));
11755
 
11756
  if (params->type == GGML_TASK_INIT) {
 
11757
  return;
11758
  }
11759
 
 
11761
  return;
11762
  }
11763
 
11764
+ // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
11765
  {
11766
  ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
11767
 
11768
  for (int64_t in = 0; in < N; in++) {
11769
+ for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
11770
+ for (int64_t iow = 0; iow < OW; iow++) {
11771
+ for (int64_t iic = ith; iic < IC; iic += nth) {
11772
 
11773
+ // micro kernel
11774
+ ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
11775
+ const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
11776
 
11777
+ for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
11778
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
11779
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
11780
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
11781
 
11782
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
11783
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
11784
+ } else {
11785
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
11786
+ }
11787
+ }
11788
  }
11789
  }
11790
  }
 
11793
  }
11794
  }
11795
 
11796
+ static void ggml_compute_forward_im2col(
 
 
 
 
 
11797
  const struct ggml_compute_params * params,
11798
  const struct ggml_tensor * src0,
11799
  const struct ggml_tensor * src1,
 
11801
  switch (src0->type) {
11802
  case GGML_TYPE_F16:
11803
  {
11804
+ ggml_compute_forward_im2col_f16(params, src0, src1, dst);
11805
  } break;
11806
  case GGML_TYPE_F32:
11807
  {
 
14018
  {
14019
  ggml_compute_forward_clamp(params, tensor->src[0], tensor);
14020
  } break;
 
 
 
 
14021
  case GGML_OP_CONV_TRANSPOSE_1D:
14022
  {
14023
  ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor);
14024
  } break;
14025
+ case GGML_OP_IM2COL:
 
 
 
 
 
 
 
 
14026
  {
14027
+ ggml_compute_forward_im2col(params, tensor->src[0], tensor->src[1], tensor);
14028
  } break;
14029
  case GGML_OP_CONV_TRANSPOSE_2D:
14030
  {
 
14995
  {
14996
  GGML_ASSERT(false); // TODO: not implemented
14997
  } break;
 
 
 
 
14998
  case GGML_OP_CONV_TRANSPOSE_1D:
14999
  {
15000
  GGML_ASSERT(false); // TODO: not implemented
15001
  } break;
15002
+ case GGML_OP_IM2COL:
 
 
 
 
 
 
 
 
15003
  {
15004
  GGML_ASSERT(false); // TODO: not implemented
15005
  } break;
 
15728
  {
15729
  n_tasks = 1; //TODO
15730
  } break;
 
 
 
 
15731
  case GGML_OP_CONV_TRANSPOSE_1D:
15732
  {
15733
  n_tasks = n_threads;
15734
  } break;
15735
+ case GGML_OP_IM2COL:
 
 
 
 
 
 
 
 
15736
  {
15737
  n_tasks = n_threads;
15738
  } break;
 
15817
  } break;
15818
  default:
15819
  {
15820
+ printf("%s: op %s not implemented\n", __func__, ggml_op_name(node->op));
15821
  GGML_ASSERT(false);
15822
  } break;
15823
  }
 
16020
  cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
16021
  }
16022
  } break;
 
 
 
 
16023
  case GGML_OP_CONV_TRANSPOSE_1D:
16024
  {
16025
  GGML_ASSERT(node->src[0]->ne[3] == 1);
 
16045
  GGML_ASSERT(false);
16046
  }
16047
  } break;
16048
+ case GGML_OP_IM2COL:
16049
  {
16050
+ n_tasks = n_threads;
 
 
 
 
16051
  } break;
16052
  case GGML_OP_CONV_TRANSPOSE_2D:
16053
  {
ggml.h CHANGED
@@ -403,13 +403,8 @@ extern "C" {
403
  GGML_OP_ROPE_BACK,
404
  GGML_OP_ALIBI,
405
  GGML_OP_CLAMP,
406
- GGML_OP_CONV_1D,
407
- GGML_OP_CONV_1D_STAGE_0, // internal
408
- GGML_OP_CONV_1D_STAGE_1, // internal
409
  GGML_OP_CONV_TRANSPOSE_1D,
410
- GGML_OP_CONV_2D,
411
- GGML_OP_CONV_2D_STAGE_0, // internal
412
- GGML_OP_CONV_2D_STAGE_1, // internal
413
  GGML_OP_CONV_TRANSPOSE_2D,
414
  GGML_OP_POOL_1D,
415
  GGML_OP_POOL_2D,
@@ -1398,6 +1393,18 @@ extern "C" {
1398
  float min,
1399
  float max);
1400
 
 
 
 
1401
  GGML_API struct ggml_tensor * ggml_conv_1d(
1402
  struct ggml_context * ctx,
1403
  struct ggml_tensor * a,
 
403
  GGML_OP_ROPE_BACK,
404
  GGML_OP_ALIBI,
405
  GGML_OP_CLAMP,
 
 
 
406
  GGML_OP_CONV_TRANSPOSE_1D,
407
+ GGML_OP_IM2COL,
 
 
408
  GGML_OP_CONV_TRANSPOSE_2D,
409
  GGML_OP_POOL_1D,
410
  GGML_OP_POOL_2D,
 
1393
  float min,
1394
  float max);
1395
 
1396
+ GGML_API struct ggml_tensor * ggml_im2col(
1397
+ struct ggml_context * ctx,
1398
+ struct ggml_tensor * a,
1399
+ struct ggml_tensor * b,
1400
+ int s0,
1401
+ int s1,
1402
+ int p0,
1403
+ int p1,
1404
+ int d0,
1405
+ int d1,
1406
+ bool is_2D);
1407
+
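A hedged usage sketch for the new public op (not taken from this diff; the tensor names are made up). With is_2D == false the 1D shapes from the ggml.c comments apply and s1/p1/d1 are ignored (ggml_conv_1d passes 0 for them); with is_2D == true the 4-D kernel/image shapes documented above ggml_im2col() in ggml.c apply:

    // k: [OC, IC, K] F16 kernel, x: [N, IC, IL] F32 signal (1D case)
    struct ggml_tensor * cols = ggml_im2col(ctx, k, x,
                                            /*s0*/ 2, /*s1*/ 0,
                                            /*p0*/ 1, /*p1*/ 0,
                                            /*d0*/ 1, /*d1*/ 0,
                                            /*is_2D*/ false);
    // cols: [N, OL, IC*K], type GGML_TYPE_F16; feed it to ggml_mul_mat as ggml_conv_1d() does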
1408
  GGML_API struct ggml_tensor * ggml_conv_1d(
1409
  struct ggml_context * ctx,
1410
  struct ggml_tensor * a,
whisper.cpp CHANGED
@@ -1,10 +1,15 @@
1
  #include "whisper.h"
 
2
  #ifdef WHISPER_USE_COREML
3
  #include "coreml/whisper-encoder.h"
4
  #endif
5
 
6
  #ifdef GGML_USE_METAL
7
- # include "ggml-metal.h"
 
 
 
 
8
  #endif
9
 
10
  #ifdef WHISPER_USE_OPENVINO
@@ -13,6 +18,7 @@
13
 
14
  #include "ggml.h"
15
  #include "ggml-alloc.h"
 
16
 
17
  #include <algorithm>
18
  #include <cassert>
@@ -97,10 +103,32 @@ static void byteswap_tensor(ggml_tensor * tensor) {
97
  #define BYTESWAP_TENSOR(t) do {} while (0)
98
  #endif
99
 
 
 
 
 
 
100
  #define WHISPER_ASSERT(x) \
101
  do { \
102
  if (!(x)) { \
103
- log("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
104
  abort(); \
105
  } \
106
  } while (0)
@@ -127,8 +155,8 @@ static void byteswap_tensor(ggml_tensor * tensor) {
127
  //
128
 
129
  static void ggml_graph_compute_helper(
 
130
  std::vector<uint8_t> & buf,
131
- ggml_cgraph * graph,
132
  int n_threads,
133
  whisper_abort_callback abort_callback,
134
  void * abort_callback_data) {
@@ -145,6 +173,21 @@ static void ggml_graph_compute_helper(
145
  ggml_graph_compute(graph, &plan);
146
  }
147
 
 
 
 
 
148
  // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
149
  // the idea is to represent the original matrix multiplication:
150
  //
@@ -179,6 +222,7 @@ static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct g
179
  }
180
 
181
  // TODO: check if other platforms can benefit from this optimization
 
182
  #if defined(GGML_USE_METAL)
183
  #define ggml_mul_mat ggml_mul_mat_pad
184
  #endif
@@ -305,75 +349,6 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
305
  { "yue", { 99, "cantonese", } },
306
  };
307
 
308
- static const size_t MB = 1ull*1024*1024;
309
-
310
- // TODO: avoid using GGUF
311
- static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
312
- { GGML_TYPE_F32,
313
- {
314
- { MODEL_TINY, 74ull*MB },
315
- { MODEL_BASE, 142ull*MB },
316
- { MODEL_SMALL, 466ull*MB },
317
- { MODEL_MEDIUM, 1464ull*MB },
318
- { MODEL_LARGE, 2952ull*MB },
319
- },
320
- },
321
- { GGML_TYPE_F16,
322
- {
323
- { MODEL_TINY, 74ull*MB },
324
- { MODEL_BASE, 142ull*MB },
325
- { MODEL_SMALL, 466ull*MB },
326
- { MODEL_MEDIUM, 1464ull*MB },
327
- { MODEL_LARGE, 2952ull*MB },
328
- },
329
- },
330
- { GGML_TYPE_Q4_0,
331
- {
332
- { MODEL_TINY, 26ull*MB },
333
- { MODEL_BASE, 50ull*MB },
334
- { MODEL_SMALL, 154ull*MB },
335
- { MODEL_MEDIUM, 470ull*MB },
336
- { MODEL_LARGE, 940ull*MB },
337
- },
338
- },
339
- { GGML_TYPE_Q4_1,
340
- {
341
- { MODEL_TINY, 32ull*MB },
342
- { MODEL_BASE, 58ull*MB },
343
- { MODEL_SMALL, 182ull*MB },
344
- { MODEL_MEDIUM, 562ull*MB },
345
- { MODEL_LARGE, 1124ull*MB },
346
- },
347
- },
348
- { GGML_TYPE_Q5_0,
349
- {
350
- { MODEL_TINY, 30ull*MB },
351
- { MODEL_BASE, 54ull*MB },
352
- { MODEL_SMALL, 170ull*MB },
353
- { MODEL_MEDIUM, 516ull*MB },
354
- { MODEL_LARGE, 1034ull*MB },
355
- },
356
- },
357
- { GGML_TYPE_Q5_1,
358
- {
359
- { MODEL_TINY, 32ull*MB },
360
- { MODEL_BASE, 58ull*MB },
361
- { MODEL_SMALL, 182ull*MB },
362
- { MODEL_MEDIUM, 562ull*MB },
363
- { MODEL_LARGE, 1124ull*MB },
364
- },
365
- },
366
- { GGML_TYPE_Q8_0,
367
- {
368
- { MODEL_TINY, 45ull*MB },
369
- { MODEL_BASE, 84ull*MB },
370
- { MODEL_SMALL, 268ull*MB },
371
- { MODEL_MEDIUM, 834ull*MB },
372
- { MODEL_LARGE, 1674ull*MB },
373
- },
374
- },
375
- };
376
-
377
  struct whisper_mel {
378
  int n_len;
379
  int n_len_org;
@@ -554,8 +529,7 @@ struct whisper_kv_cache {
554
 
555
  struct ggml_context * ctx;
556
 
557
- // buf points to the memory allocated for both ggml_tensor 'k' and 'v' (see kv_cache_init)
558
- std::vector<uint8_t> buf;
559
 
560
  int n; // number of tokens currently in the cache
561
  };
@@ -594,11 +568,11 @@ struct whisper_model {
594
  std::vector<whisper_layer_encoder> layers_encoder;
595
  std::vector<whisper_layer_decoder> layers_decoder;
596
 
597
- // context
598
  struct ggml_context * ctx;
599
 
600
- // the model memory buffer is read-only and can be shared between processors
601
- std::vector<uint8_t> * buf;
602
 
603
  // tensors
604
  int n_loaded;
@@ -663,37 +637,47 @@ struct whisper_allocr {
663
  ggml_allocr * alloc = nullptr;
664
 
665
  std::vector<uint8_t> meta;
666
- std::vector<uint8_t> data;
 
667
  };
668
 
669
  static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
670
- return allocr.meta.size() + allocr.data.size();
671
  }
672
 
673
  // measure the memory usage of a graph and prepare the allocr's internal data buffer
674
- static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct ggml_cgraph *()> && get_graph) {
675
- const int tensor_alignment = 32;
 
676
 
677
- auto & alloc = allocr.alloc;
678
- auto & meta = allocr.meta;
679
- auto & data = allocr.data;
680
 
681
  meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
682
 
683
- alloc = ggml_allocr_new_measure(tensor_alignment);
 
684
 
685
- const size_t alloc_size = ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
 
 
 
 
686
 
687
- ggml_allocr_free(alloc);
 
688
 
689
- data.resize(alloc_size);
690
 
691
- alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
 
 
 
692
  }
693
 
694
  static void whisper_allocr_free(struct whisper_allocr & allocr) {
695
  if (allocr.alloc) {
696
  ggml_allocr_free(allocr.alloc);
 
697
  allocr.alloc = nullptr;
698
  }
699
  }
@@ -722,8 +706,7 @@ struct whisper_state {
722
  // buffer for swapping KV caches between decoders during beam-search
723
  std::vector<kv_buf> kv_swap_bufs;
724
 
725
- // reusable buffer for `struct ggml_graph_plan.work_data`
726
- std::vector<uint8_t> work_buffer;
727
 
728
  // ggml-alloc:
729
  // - stores meta info about the intermediate tensors into the `meta` buffers
@@ -737,6 +720,9 @@ struct whisper_state {
737
  struct ggml_tensor * embd_conv = nullptr;
738
  struct ggml_tensor * embd_enc = nullptr;
739
 
 
 
 
740
  // decode output (2-dimensional array: [n_tokens][n_vocab])
741
  std::vector<float> logits;
742
 
@@ -751,22 +737,21 @@ struct whisper_state {
751
  int lang_id = 0; // english by default
752
 
753
  std::string path_model; // populated by whisper_init_from_file_with_params()
 
754
  #ifdef WHISPER_USE_COREML
755
  whisper_coreml_context * ctx_coreml = nullptr;
756
  #endif
757
 
758
- #ifdef GGML_USE_METAL
759
- ggml_metal_context * ctx_metal = nullptr;
760
- #endif
761
-
762
  #ifdef WHISPER_USE_OPENVINO
763
  whisper_openvino_context * ctx_openvino = nullptr;
764
  #endif
765
 
766
  // [EXPERIMENTAL] token-level timestamps data
767
- int64_t t_beg = 0;
768
  int64_t t_last = 0;
 
769
  whisper_token tid_last;
 
770
  std::vector<float> energy; // PCM signal energy
771
 
772
  // [EXPERIMENTAL] speed-up techniques
@@ -780,35 +765,25 @@ struct whisper_context {
780
  ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 / FP16 / QX)
781
  ggml_type itype = ggml_type::GGML_TYPE_F16; // intermediate type (FP32 or FP16)
782
 
 
 
783
  whisper_model model;
784
  whisper_vocab vocab;
 
785
  whisper_state * state = nullptr;
786
 
 
 
787
  std::string path_model; // populated by whisper_init_from_file_with_params()
788
- whisper_context_params params;
789
  };
790
 
791
- static void whisper_default_log(const char * text) {
792
- fprintf(stderr, "%s", text);
793
- }
 
 
794
 
795
- static whisper_log_callback whisper_log = whisper_default_log;
796
-
797
- #ifdef __GNUC__
798
- #ifdef __MINGW32__
799
- __attribute__((gnu_format(printf, 1, 2)))
800
- #else
801
- __attribute__((format(printf, 1, 2)))
802
- #endif
803
- #endif
804
- static void log(const char * fmt, ...) {
805
- if (!whisper_log) return;
806
- char buf[1024];
807
- va_list args;
808
- va_start(args, fmt);
809
- vsnprintf(buf, sizeof(buf), fmt, args);
810
- whisper_log(buf);
811
- }
812
 
813
  template<typename T>
814
  static void read_safe(whisper_model_loader * loader, T & dest) {
@@ -819,6 +794,7 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
819
  static bool kv_cache_init(
820
  const struct whisper_hparams & hparams,
821
  struct whisper_kv_cache & cache,
 
822
  ggml_type wtype,
823
  int n_ctx) {
824
  const int64_t n_text_state = hparams.n_text_state;
@@ -827,30 +803,41 @@ static bool kv_cache_init(
827
  const int64_t n_mem = n_text_layer*n_ctx;
828
  const int64_t n_elements = n_text_state*n_mem;
829
 
830
- const size_t mem_bytes = 2*(ggml_type_size(wtype)*n_elements + ggml_tensor_overhead());
831
-
832
- cache.buf.resize(mem_bytes);
833
-
834
  struct ggml_init_params params = {
835
- /*.mem_size =*/ cache.buf.size(),
836
- /*.mem_buffer =*/ cache.buf.data(),
837
- /*.no_alloc =*/ false,
838
  };
839
 
840
  cache.ctx = ggml_init(params);
841
 
842
  if (!cache.ctx) {
843
- log("%s: failed to allocate memory for kv cache\n", __func__);
844
  return false;
845
  }
846
 
847
  cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
848
  cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
849
 
 
 
 
 
 
850
  return true;
851
  }
852
 
853
- static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
 
854
  WHISPER_ASSERT(cache.ctx);
855
 
856
  const int n_elements = ggml_nelements(cache.k);
@@ -859,34 +846,78 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
859
  const ggml_type wtype = cache.k->type;
860
  WHISPER_ASSERT(wtype == cache.v->type);
861
 
862
- WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_sizef(wtype));
863
-
864
  struct ggml_init_params params = {
865
- /*.mem_size =*/ cache.buf.size(),
866
- /*.mem_buffer =*/ cache.buf.data(),
867
- /*.no_alloc =*/ false,
868
  };
869
 
870
  cache.ctx = ggml_init(params);
871
 
872
  if (!cache.ctx) {
873
- log("%s: failed to allocate memory for kv cache\n", __func__);
874
  return false;
875
  }
876
 
877
  cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
878
  cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
879
 
 
 
 
 
 
880
  return true;
881
  }
882
 
883
  static void kv_cache_free(struct whisper_kv_cache & cache) {
884
  if (cache.ctx) {
885
  ggml_free(cache.ctx);
 
886
  cache.ctx = nullptr;
887
  }
888
  }
889
 
 
 
 
 
 
 
890
  // load the model from a ggml file
891
  //
892
  // file format:
@@ -899,7 +930,7 @@ static void kv_cache_free(struct whisper_kv_cache & cache) {
899
  // see the convert-pt-to-ggml.py script for details
900
  //
901
  static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
902
- log("%s: loading model\n", __func__);
903
 
904
  const int64_t t_start_us = ggml_time_us();
905
 
@@ -913,7 +944,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
913
  uint32_t magic;
914
  read_safe(loader, magic);
915
  if (magic != GGML_FILE_MAGIC) {
916
- log("%s: invalid model data (bad magic)\n", __func__);
917
  return false;
918
  }
919
  }
@@ -970,41 +1001,23 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
970
  // in order to save memory and also to speed up the computation
971
  wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
972
  if (wctx.wtype == GGML_TYPE_COUNT) {
973
- log("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
974
  return false;
975
  }
976
 
977
- const size_t scale = model.hparams.ftype ? 1 : 2;
978
-
979
- log("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
980
- log("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
981
- log("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
982
- log("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
983
- log("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
984
- log("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
985
- log("%s: n_text_state = %d\n", __func__, hparams.n_text_state);
986
- log("%s: n_text_head = %d\n", __func__, hparams.n_text_head);
987
- log("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
988
- log("%s: n_mels = %d\n", __func__, hparams.n_mels);
989
- log("%s: ftype = %d\n", __func__, model.hparams.ftype);
990
- log("%s: qntvr = %d\n", __func__, qntvr);
991
- log("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str());
992
-
993
- // print memory requirements
994
- {
995
- // TODO
996
- //log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
997
- // mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
998
- }
999
-
1000
- // initialize all memory buffers
1001
- // always have at least one decoder
1002
-
1003
- wctx.model.buf = new std::vector<uint8_t>();
1004
- wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type));
1005
-
1006
- // we skip initialization of the state until it is needed
1007
- // because it might be that state will always be provided externally.
1008
  }
1009
 
1010
  // load mel filters
@@ -1025,7 +1038,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1025
  read_safe(loader, n_vocab);
1026
 
1027
  //if (n_vocab != model.hparams.n_vocab) {
1028
- // log("%s: invalid model file '%s' (bad vocab size %d != %d)\n",
1029
  // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
1030
  // return false;
1031
  //}
@@ -1045,7 +1058,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1045
  word.assign(&tmp[0], tmp.size());
1046
  } else {
1047
  // seems like we have an empty-string token in multi-language models (i = 50256)
1048
- //log("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
1049
  word = "";
1050
  }
1051
 
@@ -1073,7 +1086,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1073
  }
1074
 
1075
  if (n_vocab < model.hparams.n_vocab) {
1076
- log("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
1077
  for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
1078
  if (i > vocab.token_beg) {
1079
  word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
@@ -1099,140 +1112,35 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1099
  }
1100
  }
1101
 
1102
- log("%s: n_langs = %d\n", __func__, vocab.num_languages());
1103
  }
1104
 
1105
- size_t ctx_size = 0;
1106
-
1107
  const ggml_type wtype = wctx.wtype;
1108
  const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type
1109
 
 
1110
  {
1111
  const auto & hparams = model.hparams;
1112
 
1113
- const int n_vocab = hparams.n_vocab;
1114
-
1115
- const int n_audio_ctx = hparams.n_audio_ctx;
1116
- const int n_audio_state = hparams.n_audio_state;
1117
  const int n_audio_layer = hparams.n_audio_layer;
 
1118
 
1119
- const int n_text_ctx = hparams.n_text_ctx;
1120
- const int n_text_state = hparams.n_text_state;
1121
- const int n_text_layer = hparams.n_text_layer;
1122
-
1123
- const int n_mels = hparams.n_mels;
1124
-
1125
- // encoder
1126
- {
1127
- ctx_size += n_audio_ctx*n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_pe;
1128
-
1129
- ctx_size += 3*n_mels*n_audio_state*ggml_type_sizef(vtype); // e_conv_1_w
1130
- ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_1_b
1131
-
1132
- ctx_size += 3*n_audio_state*n_audio_state*ggml_type_sizef(vtype); // e_conv_2_w
1133
- ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_2_b
1134
-
1135
- ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_w;
1136
- ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_b;
1137
- }
1138
-
1139
- // decoder
1140
- {
1141
- ctx_size += n_text_ctx*n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_pe;
1142
-
1143
- ctx_size += n_vocab*n_text_state*ggml_type_sizef(wtype); // d_te;
1144
-
1145
- ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_w;
1146
- ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_b;
1147
- }
1148
-
1149
- // encoder layers
1150
- {
1151
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
1152
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
1153
-
1154
- ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_0_w
1155
- ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
1156
-
1157
- ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_1_w
1158
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
1159
-
1160
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
1161
- ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
1162
-
1163
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_q_w
1164
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
1165
-
1166
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_k_w
1167
-
1168
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_v_w
1169
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
1170
-
1171
- ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_ln_1_w
1172
- ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
1173
- }
1174
-
1175
- // decoder layers
1176
- {
1177
- ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
1178
- ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
1179
-
1180
- ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_0_w
1181
- ctx_size += n_text_layer*( 4*n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
1182
-
1183
- ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_1_w
1184
- ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
1185
-
1186
- ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
1187
- ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
1188
-
1189
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_q_w
1190
- ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
1191
-
1192
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_k_w
1193
-
1194
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_v_w
1195
- ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
1196
-
1197
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_ln_1_w
1198
- ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
1199
- //
1200
- ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_w
1201
- ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_b
1202
-
1203
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_q_w
1204
- ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_q_b
1205
-
1206
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_k_w
1207
-
1208
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_v_w
1209
- ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_v_b
1210
-
1211
- ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_ln_1_w
1212
- ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_1_b
1213
- }
1214
-
1215
- ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*512; // object overhead
1216
-
1217
- log("%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
1218
- }
1219
 
1220
- // create the ggml context
1221
- {
1222
  struct ggml_init_params params = {
1223
- /*.mem_size =*/ wctx.model.buf->size(),
1224
- /*.mem_buffer =*/ wctx.model.buf->data(),
1225
- /*.no_alloc =*/ false,
1226
  };
1227
 
1228
  model.ctx = ggml_init(params);
1229
  if (!model.ctx) {
1230
- log("%s: ggml_init() failed\n", __func__);
1231
  return false;
1232
  }
1233
  }
1234
 
1235
- // prepare memory for the weights
1236
  {
1237
  auto & ctx = model.ctx;
1238
 
@@ -1255,16 +1163,16 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1255
 
1256
  // encoder
1257
  {
1258
- model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
1259
 
1260
- model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
1261
- model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
1262
 
1263
- model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
1264
- model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
1265
 
1266
- model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1267
- model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1268
 
1269
  // map by name
1270
  model.tensors["encoder.positional_embedding"] = model.e_pe;
@@ -1428,12 +1336,37 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1428
  }
1429
  }
1430
 
1431
  // load weights
1432
  {
1433
  size_t total_size = 0;
1434
 
1435
  model.n_loaded = 0;
1436
 
 
 
1437
  while (true) {
1438
  int32_t n_dims;
1439
  int32_t length;
@@ -1460,50 +1393,92 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
1460
  name.assign(&tmp[0], tmp.size());
1461
 
1462
  if (model.tensors.find(name) == model.tensors.end()) {
1463
- log("%s: unknown tensor '%s' in model file\n", __func__, name.data());
1464
  return false;
1465
  }
1466
 
1467
  auto tensor = model.tensors[name.data()];
1468
- if (ggml_nelements(tensor) != nelements) {
1469
- log("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
1470
- log("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
1471
- __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
1472
- return false;
1473
- }
1474
 
1475
- if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
1476
- log("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
1477
- __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
1478
- return false;
1479
- }
1480
 
1481
- const size_t bpe = ggml_type_size(ggml_type(ttype));
 
 
 
 
 
 
1482
 
1483
- if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
1484
- log("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
1485
- __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
1486
- return false;
 
 
 
 
 
 
 
 
 
1487
  }
1488
 
1489
- loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
1490
- BYTESWAP_TENSOR(tensor);
1491
 
1492
  //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1024.0/1024.0);
1493
  total_size += ggml_nbytes(tensor);
1494
  model.n_loaded++;
1495
  }
1496
 
1497
- log("%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
1498
 
1499
  if (model.n_loaded == 0) {
1500
- log("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
1501
  } else if (model.n_loaded != (int) model.tensors.size()) {
1502
- log("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
1503
  return false;
1504
  }
1505
  }
1506
 
 
 
1507
  wctx.t_load_us = ggml_time_us() - t_start_us;
1508
 
1509
  return true;
@@ -1559,10 +1534,12 @@ static struct ggml_cgraph * whisper_build_graph_conv(
1559
  if (!ggml_allocr_is_measure(alloc)) {
1560
  assert(mel_inp.n_mel == n_mels);
1561
 
1562
- float * dst = (float *) mel->data;
 
 
1563
  memset(dst, 0, ggml_nbytes(mel));
1564
 
1565
- const int i0 = std::min(mel_offset, mel_inp.n_len);
1566
  const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
1567
 
1568
  for (int j = 0; j < mel_inp.n_mel; ++j) {
@@ -1570,6 +1547,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
1570
  dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
1571
  }
1572
  }
 
 
1573
  }
1574
 
1575
  struct ggml_tensor * cur = nullptr;
@@ -1578,24 +1557,27 @@ static struct ggml_cgraph * whisper_build_graph_conv(
1578
  // convolution + gelu
1579
  {
1580
  cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
1581
- cur = ggml_add(ctx0,
1582
- ggml_repeat(ctx0,
1583
- model.e_conv_1_b,
1584
- cur),
1585
- cur);
 
1586
 
1587
  cur = ggml_gelu(ctx0, cur);
1588
 
1589
  cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
1590
- cur = ggml_add(ctx0,
1591
- ggml_repeat(ctx0,
1592
- model.e_conv_2_b,
1593
- cur),
1594
- cur);
 
1595
 
1596
  cur = ggml_gelu(ctx0, cur);
1597
  }
1598
 
 
1599
  wstate.embd_conv = cur;
1600
  } else {
1601
  #ifdef WHISPER_USE_COREML
@@ -1615,6 +1597,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
1615
  }
1616
  #endif
1617
 
 
1618
  wstate.embd_enc = cur;
1619
  }
1620
 
@@ -1648,15 +1631,22 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
1648
 
1649
  ggml_allocr * alloc = wstate.alloc_encode.alloc;
1650
 
 
 
 
 
 
 
 
 
1651
  struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
1652
  ggml_allocr_alloc(alloc, KQscale);
1653
 
1654
  if (!ggml_allocr_is_measure(alloc)) {
1655
- ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head));
 
1656
  }
1657
 
1658
- struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
1659
-
1660
  // ===================================================================
1661
  // NOTE: experimenting with partial evaluation of the encoder (ignore)
1662
  //static int iter = -1;
@@ -1675,7 +1665,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
1675
  const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
1676
 
1677
  struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
1678
-
1679
  cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
1680
 
1681
  // ===================================================================
@@ -1897,13 +1886,20 @@ static struct ggml_cgraph * whisper_build_graph_cross(
1897
 
1898
  ggml_allocr * alloc = wstate.alloc_cross.alloc;
1899
 
 
 
 
 
 
 
1900
  struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
1901
 
1902
  struct ggml_tensor * Kscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
1903
  ggml_allocr_alloc(alloc, Kscale);
1904
 
1905
  if (!ggml_allocr_is_measure(alloc)) {
1906
- ggml_set_f32(Kscale, pow(float(n_state) / n_head, -0.25));
 
1907
  }
1908
 
1909
  for (int il = 0; il < model.hparams.n_text_layer; ++il) {
@@ -1974,7 +1970,7 @@ static bool whisper_encode_internal(
1974
  ggml_allocr_alloc_graph(alloc, gf);
1975
 
1976
  if (!whisper_encode_external(wstate)) {
1977
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1978
  }
1979
  }
1980
 
@@ -1988,16 +1984,7 @@ static bool whisper_encode_internal(
1988
 
1989
  ggml_allocr_alloc_graph(alloc, gf);
1990
 
1991
- #ifdef GGML_USE_METAL
1992
- if (wstate.ctx_metal) {
1993
- ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
1994
- ggml_metal_graph_compute(wstate.ctx_metal, gf);
1995
- } else {
1996
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
1997
- }
1998
- #else
1999
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2000
- #endif
2001
  }
2002
 
2003
  // cross
@@ -2010,20 +1997,9 @@ static bool whisper_encode_internal(
2010
 
2011
  ggml_allocr_alloc_graph(alloc, gf);
2012
 
2013
- #ifdef GGML_USE_METAL
2014
- if (wstate.ctx_metal) {
2015
- ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
2016
- ggml_metal_graph_compute(wstate.ctx_metal, gf);
2017
- } else {
2018
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2019
- }
2020
- #else
2021
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2022
- #endif
2023
  }
2024
 
2025
- // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
2026
-
2027
  wstate.t_encode_us += ggml_time_us() - t_start_us;
2028
  wstate.n_encode++;
2029
 
@@ -2070,7 +2046,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2070
  ggml_allocr_alloc(alloc, embd);
2071
 
2072
  if (!ggml_allocr_is_measure(alloc)) {
2073
- memcpy(embd->data, tokens, N*ggml_element_size(embd));
2074
  }
2075
 
2076
  struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
@@ -2078,7 +2054,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2078
 
2079
  if (!ggml_allocr_is_measure(alloc)) {
2080
  for (int i = 0; i < N; ++i) {
2081
- ((int32_t *) position->data)[i] = n_past + i;
 
2082
  }
2083
  }
2084
 
@@ -2086,7 +2063,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2086
  ggml_allocr_alloc(alloc, KQscale);
2087
 
2088
  if (!ggml_allocr_is_measure(alloc)) {
2089
- ggml_set_f32(KQscale, pow(float(n_state)/n_head, -0.25));
 
2090
  }
2091
 
2092
  // token encoding + position encoding
@@ -2410,25 +2388,18 @@ static bool whisper_decode_internal(
2410
 
2411
  logits = gf->nodes[gf->n_nodes - 1];
2412
 
2413
- #ifdef GGML_USE_METAL
2414
- if (wstate.ctx_metal) {
2415
- ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
2416
- ggml_metal_graph_compute(wstate.ctx_metal, gf);
2417
- } else {
2418
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2419
- }
2420
- #else
2421
- ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
2422
- #endif
2423
  }
2424
 
2425
  // extract logits for all N tokens
2426
  //logits_out.resize(n_tokens*n_vocab);
2427
  //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
 
2428
 
2429
  // extract logits only for the last token
2430
  logits_out.resize(n_vocab);
2431
- memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
 
2432
 
2433
  if (n_tokens > 1) {
2434
  //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
@@ -2794,7 +2765,7 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, cons
2794
  --j;
2795
  }
2796
  if (!found) {
2797
- log("unknown token\n");
2798
  ++i;
2799
  }
2800
  }
@@ -2857,45 +2828,48 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
2857
 
2858
  struct whisper_state * whisper_init_state(whisper_context * ctx) {
2859
  fill_sin_cos_table();
 
2860
  whisper_state * state = new whisper_state;
2861
 
2862
- if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
2863
- log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
 
 
2864
  delete state;
2865
  return nullptr;
2866
  }
2867
 
2868
  {
2869
  const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
2870
- log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2871
  }
2872
 
2873
- if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
2874
- log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
2875
  delete state;
2876
  return nullptr;
2877
  }
2878
 
2879
  {
2880
  const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v);
2881
- log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2882
  }
2883
 
2884
  #ifdef WHISPER_USE_COREML
2885
  const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
2886
 
2887
- log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
2888
- log("%s: first run on a device may take a while ...\n", __func__);
2889
 
2890
  state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
2891
  if (!state->ctx_coreml) {
2892
- log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
2893
  #ifndef WHISPER_COREML_ALLOW_FALLBACK
2894
  delete state;
2895
  return nullptr;
2896
  #endif
2897
  } else {
2898
- log("%s: Core ML model loaded\n", __func__);
2899
  }
2900
  #endif
2901
 
@@ -2912,37 +2886,37 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2912
 
2913
  // conv allocator
2914
  {
2915
- whisper_allocr_graph_init(state->alloc_conv,
2916
  [&]() {
2917
  return whisper_build_graph_conv(*ctx, *state, 0);
2918
  });
2919
 
2920
- log("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
2921
  }
2922
 
2923
  // encoder allocator
2924
  if (!whisper_encode_external(*state)) {
2925
- whisper_allocr_graph_init(state->alloc_encode,
2926
  [&]() {
2927
  return whisper_build_graph_encoder(*ctx, *state);
2928
  });
2929
 
2930
- log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
2931
  }
2932
 
2933
  // cross allocator
2934
  {
2935
- whisper_allocr_graph_init(state->alloc_cross,
2936
  [&]() {
2937
  return whisper_build_graph_cross(*ctx, *state);
2938
  });
2939
 
2940
- log("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
2941
  }
2942
 
2943
  // decoder allocator
2944
  {
2945
- whisper_allocr_graph_init(state->alloc_decode,
2946
  [&]() {
2947
  const auto & hparams = ctx->model.hparams;
2948
 
@@ -2953,69 +2927,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2953
  return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
2954
  });
2955
 
2956
- log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
2957
  }
2958
 
2959
- #ifdef GGML_USE_METAL
2960
- if (ctx->params.use_gpu) {
2961
- state->ctx_metal = ggml_metal_init(1);
2962
- if (!state->ctx_metal) {
2963
- log("%s: ggml_metal_init() failed\n", __func__);
2964
- delete state;
2965
- return nullptr;
2966
- }
2967
- }
2968
-
2969
- if (state->ctx_metal) {
2970
- log("%s: Metal context initialized\n", __func__);
2971
-
2972
- // this allocates all Metal resources and memory buffers
2973
-
2974
- void * data_ptr = NULL;
2975
- size_t data_size = 0;
2976
-
2977
- // TODO: add mmap support
2978
- //if (params.use_mmap) {
2979
- // data_ptr = ctx->model.mapping->addr;
2980
- // data_size = ctx->model.mapping->size;
2981
- //} else {
2982
- // data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
2983
- // data_size = ggml_get_mem_size (ctx->model.ctx);
2984
- //}
2985
-
2986
- data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
2987
- data_size = ggml_get_mem_size (ctx->model.ctx);
2988
-
2989
- const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
2990
-
2991
- log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
2992
-
2993
- #define WHISPER_METAL_CHECK_BUF(result) \
2994
- if (!(result)) { \
2995
- log("%s: failed to add metal buffer\n", __func__); \
2996
- delete state; \
2997
- return nullptr; \
2998
- }
2999
-
3000
- WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
3001
-
3002
- WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0));
3003
- WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
3004
- WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0));
3005
- WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
3006
-
3007
- WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0));
3008
- WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
3009
- WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0));
3010
- WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
3011
-
3012
- WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
3013
-
3014
- WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
3015
- #undef WHISPER_METAL_CHECK_BUF
3016
-
3017
- }
3018
- #endif
3019
 
3020
  state->rng = std::mt19937(0);
3021
 
@@ -3036,7 +2954,7 @@ int whisper_ctx_init_openvino_encoder(
3036
  return 1;
3037
  #else
3038
  if (!model_path && ctx->path_model.empty()) {
3039
- log("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
3040
  return 1;
3041
  }
3042
 
@@ -3056,15 +2974,15 @@ int whisper_ctx_init_openvino_encoder(
3056
  path_cache = cache_dir;
3057
  }
3058
 
3059
- log("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
3060
- log("%s: first run on a device may take a while ...\n", __func__);
3061
 
3062
  ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
3063
  if (!ctx->state->ctx_openvino) {
3064
- log("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
3065
  return 1;
3066
  } else {
3067
- log("%s: OpenVINO model loaded\n", __func__);
3068
  }
3069
 
3070
  return 0;
@@ -3079,11 +2997,11 @@ struct whisper_context_params whisper_context_default_params() {
3079
  }
3080
 
3081
  struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
3082
- log("%s: loading model from '%s'\n", __func__, path_model);
3083
 
3084
  auto fin = std::ifstream(path_model, std::ios::binary);
3085
  if (!fin) {
3086
- log("%s: failed to open '%s'\n", __func__, path_model);
3087
  return nullptr;
3088
  }
3089
 
@@ -3125,7 +3043,7 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
3125
 
3126
  buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
3127
 
3128
- log("%s: loading model from buffer\n", __func__);
3129
 
3130
  whisper_model_loader loader = {};
3131
 
@@ -3161,7 +3079,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
3161
 
3162
  if (!whisper_model_load(loader, *ctx)) {
3163
  loader->close(loader->context);
3164
- log("%s: failed to load model\n", __func__);
3165
  delete ctx;
3166
  return nullptr;
3167
  }
@@ -3256,13 +3174,6 @@ void whisper_free_state(struct whisper_state * state)
3256
  }
3257
  #endif
3258
 
3259
- #ifdef GGML_USE_METAL
3260
- if (state->ctx_metal) {
3261
- ggml_metal_free(state->ctx_metal);
3262
- state->ctx_metal = nullptr;
3263
- }
3264
- #endif
3265
-
3266
  #ifdef WHISPER_USE_OPENVINO
3267
  if (state->ctx_openvino != nullptr) {
3268
  whisper_openvino_free(state->ctx_openvino);
@@ -3271,9 +3182,11 @@ void whisper_free_state(struct whisper_state * state)
3271
  #endif
3272
 
3273
  whisper_allocr_free(state->alloc_conv);
3274
- whisper_allocr_free(state->alloc_decode);
3275
- whisper_allocr_free(state->alloc_cross);
3276
  whisper_allocr_free(state->alloc_encode);
 
 
 
 
3277
 
3278
  delete state;
3279
  }
@@ -3284,12 +3197,15 @@ void whisper_free(struct whisper_context * ctx) {
3284
  if (ctx->model.ctx) {
3285
  ggml_free(ctx->model.ctx);
3286
  }
3287
- if (ctx->model.buf) {
3288
- delete ctx->model.buf;
 
3289
  }
3290
 
3291
  whisper_free_state(ctx->state);
3292
 
 
 
3293
  delete ctx;
3294
  }
3295
  }
@@ -3308,7 +3224,7 @@ void whisper_free_params(struct whisper_full_params * params) {
3308
 
3309
  int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
3310
  if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
3311
- log("%s: failed to compute mel spectrogram\n", __func__);
3312
  return -1;
3313
  }
3314
 
@@ -3322,7 +3238,7 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
3322
  // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
3323
  int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
3324
  if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
3325
- log("%s: failed to compute mel spectrogram\n", __func__);
3326
  return -1;
3327
  }
3328
 
@@ -3350,7 +3266,7 @@ int whisper_set_mel_with_state(
3350
  int n_len,
3351
  int n_mel) {
3352
  if (n_mel != ctx->model.filters.n_mel) {
3353
- log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
3354
  return -1;
3355
  }
3356
 
@@ -3374,7 +3290,7 @@ int whisper_set_mel(
3374
 
3375
  int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
3376
  if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
3377
- log("%s: failed to eval\n", __func__);
3378
  return -1;
3379
  }
3380
 
@@ -3383,7 +3299,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
3383
 
3384
  int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
3385
  if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
3386
- log("%s: failed to eval\n", __func__);
3387
  return -1;
3388
  }
3389
 
@@ -3394,7 +3310,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
3394
  const int selected_decoder_id = 0;
3395
 
3396
  if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
3397
- log("%s: failed to eval\n", __func__);
3398
  return 1;
3399
  }
3400
 
@@ -3406,12 +3322,12 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
3406
  const int selected_decoder_id = 0;
3407
 
3408
  if (ctx->state == nullptr) {
3409
- log("%s: ERROR state was not loaded.\n", __func__);
3410
  return false;
3411
  }
3412
 
3413
  if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
3414
- log("%s: failed to eval\n", __func__);
3415
  return 1;
3416
  }
3417
 
@@ -3422,7 +3338,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
3422
  const auto res = tokenize(ctx->vocab, text);
3423
 
3424
  if (n_max_tokens < (int) res.size()) {
3425
- log("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
3426
  return -1;
3427
  }
3428
 
@@ -3450,7 +3366,7 @@ int whisper_lang_id(const char * lang) {
3450
  }
3451
  }
3452
 
3453
- log("%s: unknown language '%s'\n", __func__, lang);
3454
  return -1;
3455
  }
3456
  return g_lang.at(lang).first;
@@ -3463,7 +3379,7 @@ const char * whisper_lang_str(int id) {
3463
  }
3464
  }
3465
 
3466
- log("%s: unknown language id %d\n", __func__, id);
3467
  return nullptr;
3468
  }
3469
 
@@ -3476,25 +3392,25 @@ int whisper_lang_auto_detect_with_state(
3476
  const int seek = offset_ms/10;
3477
 
3478
  if (seek < 0) {
3479
- log("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
3480
  return -1;
3481
  }
3482
 
3483
  if (seek >= state->mel.n_len_org) {
3484
- log("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
3485
  return -2;
3486
  }
3487
 
3488
  // run the encoder
3489
  if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) {
3490
- log("%s: failed to encode\n", __func__);
3491
  return -6;
3492
  }
3493
 
3494
  const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
3495
 
3496
  if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
3497
- log("%s: failed to decode\n", __func__);
3498
  return -7;
3499
  }
3500
 
@@ -3694,8 +3610,8 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
3694
  void whisper_print_timings(struct whisper_context * ctx) {
3695
  const int64_t t_end_us = ggml_time_us();
3696
 
3697
- log("\n");
3698
- log("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
3699
  if (ctx->state != nullptr) {
3700
 
3701
  const int32_t n_sample = std::max(1, ctx->state->n_sample);
@@ -3703,14 +3619,14 @@ void whisper_print_timings(struct whisper_context * ctx) {
3703
  const int32_t n_decode = std::max(1, ctx->state->n_decode);
3704
  const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
3705
 
3706
- log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
3707
- log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
3708
- log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
3709
- log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
3710
- log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
3711
- log("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
3712
  }
3713
- log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
3714
  }
3715
 
3716
  void whisper_reset_timings(struct whisper_context * ctx) {
@@ -3762,6 +3678,7 @@ const char * whisper_print_system_info(void) {
3762
  s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
3763
  s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
3764
  s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
 
3765
  s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
3766
  s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
3767
 
@@ -4056,7 +3973,7 @@ static void whisper_process_logits(
4056
  const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
4057
  const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
4058
 
4059
- //log("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
4060
 
4061
  if (last_was_timestamp) {
4062
  if (penultimate_was_timestamp) {
@@ -4132,7 +4049,7 @@ static void whisper_process_logits(
4132
 
4133
  const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
4134
 
4135
- //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
4136
 
4137
  if (timestamp_logprob > max_text_token_logprob) {
4138
  for (int i = 0; i < vocab.token_beg; ++i) {
@@ -4427,8 +4344,10 @@ static bool whisper_kv_swap_fast(
4427
  for (auto & i : two_copy) {
4428
  // make a copy of KV caches
4429
  WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
4430
- memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
4431
- memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
 
 
4432
  }
4433
 
4434
  // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
@@ -4441,13 +4360,17 @@ static bool whisper_kv_swap_fast(
4441
  if (two_copy.find(view[i]) != two_copy.end()) {
4442
  // modify KV caches of decoder using data from kv_swap_bufs
4443
  WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
4444
- memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
4445
- memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
 
 
4446
  } else {
4447
  // modify KV caches of decoder using data from correspond decoder KV caches directly
4448
  WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
4449
- memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
4450
- memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
 
 
4451
  }
4452
  }
4453
 
@@ -4461,13 +4384,17 @@ static bool whisper_kv_swap_fast(
4461
  if (two_copy.find(view[i]) != two_copy.end()) {
4462
  // modify KV caches of decoder using data from kv_swap_bufs
4463
  WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
4464
- memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
4465
- memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
 
 
4466
  } else {
4467
  // modify KV caches of decoder using data from correspond decoder KV caches directly
4468
  WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
4469
- memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
4470
- memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
 
 
4471
  }
4472
  }
4473
 
@@ -4495,11 +4422,11 @@ int whisper_full_with_state(
4495
  // compute log mel spectrogram
4496
  if (params.speed_up) {
4497
  // TODO: Replace PV with more advanced algorithm
4498
- log("%s: failed to compute log mel spectrogram\n", __func__);
4499
  return -1;
4500
  } else {
4501
  if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
4502
- log("%s: failed to compute log mel spectrogram\n", __func__);
4503
  return -2;
4504
  }
4505
  }
@@ -4511,13 +4438,13 @@ int whisper_full_with_state(
4511
 
4512
  const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
4513
  if (lang_id < 0) {
4514
- log("%s: failed to auto-detect language\n", __func__);
4515
  return -3;
4516
  }
4517
  state->lang_id = lang_id;
4518
  params.language = whisper_lang_str(lang_id);
4519
 
4520
- log("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
4521
  if (params.detect_language) {
4522
  return 0;
4523
  }
@@ -4575,8 +4502,8 @@ int whisper_full_with_state(
4575
 
4576
  if (decoder.kv_self.ctx == nullptr) {
4577
  decoder.kv_self = state->decoders[0].kv_self;
4578
- if (!kv_cache_reinit(decoder.kv_self)) {
4579
- log("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
4580
  return -4;
4581
  }
4582
 
@@ -4587,23 +4514,6 @@ int whisper_full_with_state(
4587
  decoder.probs.resize (ctx->vocab.n_vocab);
4588
  decoder.logits.resize (ctx->vocab.n_vocab);
4589
  decoder.logprobs.resize(ctx->vocab.n_vocab);
4590
-
4591
- // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
4592
- #ifdef GGML_USE_METAL
4593
- if (state->ctx_metal) {
4594
- #define WHISPER_METAL_CHECK_BUF(result) \
4595
- if (!(result)) { \
4596
- log("%s: failed to add metal buffer\n", __func__); \
4597
- return 0; \
4598
- }
4599
-
4600
- const std::string kv_name = "kv_self_" + std::to_string(j);
4601
- auto & kv_self = decoder.kv_self;
4602
-
4603
- WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
4604
- #undef WHISPER_METAL_CHECK_BUF
4605
- }
4606
- #endif
4607
  }
4608
  }
4609
 
@@ -4637,7 +4547,7 @@ int whisper_full_with_state(
4637
 
4638
  // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
4639
  if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
4640
- log("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
4641
  return -5;
4642
  }
4643
  state->exp_n_audio_ctx = params.audio_ctx;
@@ -4662,7 +4572,7 @@ int whisper_full_with_state(
4662
  // distilled models require the "no_timestamps" token
4663
  // TODO: add input parameter (#1229)
4664
  if (is_distil) {
4665
- log("%s: using distilled model - forcing no_timestamps\n", __func__);
4666
  prompt_init.push_back(whisper_token_not(ctx));
4667
  }
4668
  }
@@ -4699,14 +4609,14 @@ int whisper_full_with_state(
4699
 
4700
  if (params.encoder_begin_callback) {
4701
  if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
4702
- log("%s: encoder_begin_callback returned false - aborting\n", __func__);
4703
  break;
4704
  }
4705
  }
4706
 
4707
  // encode audio features starting at offset seek
4708
  if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
4709
- log("%s: failed to encode\n", __func__);
4710
  return -6;
4711
  }
4712
 
@@ -4789,7 +4699,7 @@ int whisper_full_with_state(
4789
  WHISPER_PRINT_DEBUG("\n\n");
4790
 
4791
  if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
4792
- log("%s: failed to decode\n", __func__);
4793
  return -7;
4794
  }
4795
 
@@ -4803,8 +4713,11 @@ int whisper_full_with_state(
4803
  for (int j = 1; j < n_decoders_cur; ++j) {
4804
  auto & decoder = state->decoders[j];
4805
 
4806
- memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
4807
- memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
 
 
 
4808
 
4809
  decoder.kv_self.n += prompt.size();
4810
 
@@ -5013,7 +4926,7 @@ int whisper_full_with_state(
5013
  //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
5014
 
5015
  if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
5016
- log("%s: failed to decode\n", __func__);
5017
  return -8;
5018
  }
5019
 
@@ -5339,12 +5252,12 @@ int whisper_full_parallel(
5339
  ctx->state->t_decode_us /= n_processors;
5340
 
5341
  // print information about the audio boundaries
5342
- log("\n");
5343
- log("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
5344
  for (int i = 0; i < n_processors - 1; ++i) {
5345
- log("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
5346
  }
5347
- log("%s: the transcription quality may be degraded near these boundaries\n", __func__);
5348
 
5349
  return ret;
5350
  }
@@ -5586,12 +5499,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
5586
  double tsum = 0.0;
5587
 
5588
  // heat-up
5589
- ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
5590
 
5591
  for (int i = 0; i < n_max; ++i) {
5592
  const int64_t t0 = ggml_time_us();
5593
 
5594
- ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
5595
 
5596
  const int64_t t1 = ggml_time_us();
5597
 
@@ -5709,7 +5622,7 @@ static void whisper_exp_compute_token_level_timestamps(
5709
  const int n_samples = state.energy.size();
5710
 
5711
  if (n_samples == 0) {
5712
- log("%s: no signal data available\n", __func__);
5713
  return;
5714
  }
5715
 
@@ -5930,6 +5843,38 @@ static void whisper_exp_compute_token_level_timestamps(
5930
  //}
5931
  }
5932
 
5933
- void whisper_set_log_callback(whisper_log_callback callback) {
5934
- whisper_log = callback;
5935
  }
 
1
  #include "whisper.h"
2
+
3
  #ifdef WHISPER_USE_COREML
4
  #include "coreml/whisper-encoder.h"
5
  #endif
6
 
7
  #ifdef GGML_USE_METAL
8
+ #include "ggml-metal.h"
9
+ #endif
10
+
11
+ #ifdef GGML_USE_CUBLAS
12
+ #include "ggml-cuda.h"
13
  #endif
14
 
15
  #ifdef WHISPER_USE_OPENVINO
 
18
 
19
  #include "ggml.h"
20
  #include "ggml-alloc.h"
21
+ #include "ggml-backend.h"
22
 
23
  #include <algorithm>
24
  #include <cassert>
 
103
  #define BYTESWAP_TENSOR(t) do {} while (0)
104
  #endif
105
 
106
+ #ifdef __GNUC__
107
+ #ifdef __MINGW32__
108
+ #define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
109
+ #else
110
+ #define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
111
+ #endif
112
+ #else
113
+ #define WHISPER_ATTRIBUTE_FORMAT(...)
114
+ #endif
115
+
116
+ //
117
+ // logging
118
+ //
119
+
120
+ WHISPER_ATTRIBUTE_FORMAT(2, 3)
121
+ static void whisper_log_internal (ggml_log_level level, const char* format, ...);
122
+ static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data);
123
+
124
+ #define WHISPER_LOG_INFO(...) whisper_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
125
+ #define WHISPER_LOG_WARN(...) whisper_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
126
+ #define WHISPER_LOG_ERROR(...) whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
127
+
128
  #define WHISPER_ASSERT(x) \
129
  do { \
130
  if (!(x)) { \
131
+ WHISPER_LOG_ERROR("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
132
  abort(); \
133
  } \
134
  } while (0)
 
155
  //
156
 
157
  static void ggml_graph_compute_helper(
158
+ struct ggml_cgraph * graph,
159
  std::vector<uint8_t> & buf,
 
160
  int n_threads,
161
  whisper_abort_callback abort_callback,
162
  void * abort_callback_data) {
 
173
  ggml_graph_compute(graph, &plan);
174
  }
175
 
176
+ static void ggml_graph_compute_helper(
177
+ struct ggml_backend * backend,
178
+ struct ggml_cgraph * graph,
179
+ int n_threads) {
180
+ if (ggml_backend_is_cpu(backend)) {
181
+ ggml_backend_cpu_set_n_threads(backend, n_threads);
182
+ }
183
+ #ifdef GGML_USE_METAL
184
+ if (ggml_backend_is_metal(backend)) {
185
+ ggml_backend_metal_set_n_cb(backend, n_threads);
186
+ }
187
+ #endif
188
+ ggml_backend_graph_compute(backend, graph);
189
+ }
190
+
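The overload above hides the per-backend differences behind ggml-backend. A minimal sketch of a call site, assuming the backend was created elsewhere (e.g. by whisper_backend_init further below) and build_graph() stands in for one of the whisper_build_graph_* functions:

    // run an already-allocated graph on whatever backend is active (CPU, CUDA or Metal)
    ggml_backend_t backend = ggml_backend_cpu_init();            // placeholder: any ggml backend
    struct ggml_cgraph * gf = build_graph();                     // hypothetical graph builder
    ggml_graph_compute_helper(backend, gf, /*n_threads =*/ 4);   // n_threads only affects CPU/Metal
    ggml_backend_free(backend);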
191
  // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
192
  // the idea is to represent the original matrix multiplication:
193
  //
 
222
  }
223
 
224
  // TODO: check if other platforms can benefit from this optimization
225
+ // TODO: CUDA is currently broken - seems ggml_mul_mat does not handle views correctly
226
  #if defined(GGML_USE_METAL)
227
  #define ggml_mul_mat ggml_mul_mat_pad
228
  #endif
 
349
  { "yue", { 99, "cantonese", } },
350
  };
351
 
352
  struct whisper_mel {
353
  int n_len;
354
  int n_len_org;
 
529
 
530
  struct ggml_context * ctx;
531
 
532
+ ggml_backend_buffer_t buffer;
 
533
 
534
  int n; // number of tokens currently in the cache
535
  };
 
568
  std::vector<whisper_layer_encoder> layers_encoder;
569
  std::vector<whisper_layer_decoder> layers_decoder;
570
 
571
+ // ggml context that contains all the meta information about the model tensors
572
  struct ggml_context * ctx;
573
 
574
+ // the model backend data is read-only and can be shared between processors
575
+ struct ggml_backend_buffer * buffer;
576
 
577
  // tensors
578
  int n_loaded;
 
637
  ggml_allocr * alloc = nullptr;
638
 
639
  std::vector<uint8_t> meta;
640
+
641
+ ggml_backend_buffer_t buffer;
642
  };
643
 
644
  static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
645
+ return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
646
  }
647
 
648
  // measure the memory usage of a graph and prepare the allocr's internal data buffer
649
+ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
650
+ auto & alloc = allocr.alloc;
651
+ auto & meta = allocr.meta;
652
 
653
+ alloc = ggml_allocr_new_measure_from_backend(backend);
 
 
654
 
655
  meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
656
 
657
+ ggml_allocr_alloc_graph(alloc, get_graph());
658
+ }
659
 
660
+ static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
661
+ if (allocr.alloc == nullptr) {
662
+ // this can be null if we use an external encoder such as CoreML or OpenVINO
663
+ return;
664
+ }
665
 
666
+ auto & alloc = allocr.alloc;
667
+ auto & buffer = allocr.buffer;
668
 
669
+ size_t size = ggml_allocr_max_size(alloc);
670
 
671
+ ggml_allocr_free(alloc);
672
+
673
+ buffer = ggml_backend_alloc_buffer(backend, size);
674
+ alloc = ggml_allocr_new_from_buffer(buffer);
675
  }
676
 
677
  static void whisper_allocr_free(struct whisper_allocr & allocr) {
678
  if (allocr.alloc) {
679
  ggml_allocr_free(allocr.alloc);
680
+ ggml_backend_buffer_free(allocr.buffer);
681
  allocr.alloc = nullptr;
682
  }
683
  }
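Together these helpers implement a measure-then-commit allocation pattern: a graph is first built against a measuring allocator to record its worst-case size, and only then is a real backend buffer of that size allocated. A rough usage sketch (backend and build_graph() are placeholders; the helper names are the ones defined above):

    whisper_allocr allocr = {};

    // 1. measure: nothing is allocated for real, only the required size is recorded
    whisper_allocr_graph_init(allocr, backend, [&]() { return build_graph(); });

    // 2. commit: allocate a backend buffer of the measured size and point the
    //    allocator at it; subsequent ggml_allocr_alloc_graph() calls place
    //    tensors into real (possibly GPU) memory
    whisper_allocr_graph_realloc(allocr, backend);

    // ... build / allocate / compute graphs ...

    whisper_allocr_free(allocr);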
 
706
  // buffer for swapping KV caches between decoders during beam-search
707
  std::vector<kv_buf> kv_swap_bufs;
708
 
709
+ ggml_backend_t backend = nullptr;
 
710
 
711
  // ggml-alloc:
712
  // - stores meta info about the intermediate tensors into the `meta` buffers
 
720
  struct ggml_tensor * embd_conv = nullptr;
721
  struct ggml_tensor * embd_enc = nullptr;
722
 
723
+ // helper for GPU offloading
724
+ std::vector<float> inp_mel;
725
+
726
  // decode output (2-dimensional array: [n_tokens][n_vocab])
727
  std::vector<float> logits;
728
 
 
737
  int lang_id = 0; // english by default
738
 
739
  std::string path_model; // populated by whisper_init_from_file_with_params()
740
+
741
  #ifdef WHISPER_USE_COREML
742
  whisper_coreml_context * ctx_coreml = nullptr;
743
  #endif
744
 
 
 
 
 
745
  #ifdef WHISPER_USE_OPENVINO
746
  whisper_openvino_context * ctx_openvino = nullptr;
747
  #endif
748
 
749
  // [EXPERIMENTAL] token-level timestamps data
750
+ int64_t t_beg = 0;
751
  int64_t t_last = 0;
752
+
753
  whisper_token tid_last;
754
+
755
  std::vector<float> energy; // PCM signal energy
756
 
757
  // [EXPERIMENTAL] speed-up techniques
 
765
  ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 / FP16 / QX)
766
  ggml_type itype = ggml_type::GGML_TYPE_F16; // intermediate type (FP32 or FP16)
767
 
768
+ whisper_context_params params;
769
+
770
  whisper_model model;
771
  whisper_vocab vocab;
772
+
773
  whisper_state * state = nullptr;
774
 
775
+ ggml_backend_t backend = nullptr;
776
+
777
  std::string path_model; // populated by whisper_init_from_file_with_params()
 
778
  };
779
 
780
+ struct whisper_global {
781
+ // We save the log callback globally
782
+ ggml_log_callback log_callback = whisper_log_callback_default;
783
+ void * log_callback_user_data = nullptr;
784
+ };
785
 
786
+ static whisper_global g_state;
 
787
 
788
  template<typename T>
789
  static void read_safe(whisper_model_loader * loader, T & dest) {
 
794
  static bool kv_cache_init(
795
  const struct whisper_hparams & hparams,
796
  struct whisper_kv_cache & cache,
797
+ ggml_backend_t backend,
798
  ggml_type wtype,
799
  int n_ctx) {
800
  const int64_t n_text_state = hparams.n_text_state;
 
803
  const int64_t n_mem = n_text_layer*n_ctx;
804
  const int64_t n_elements = n_text_state*n_mem;
805
 
 
 
 
 
806
  struct ggml_init_params params = {
807
+ /*.mem_size =*/ 2*ggml_tensor_overhead(),
808
+ /*.mem_buffer =*/ nullptr,
809
+ /*.no_alloc =*/ true,
810
  };
811
 
812
  cache.ctx = ggml_init(params);
813
 
814
  if (!cache.ctx) {
815
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
816
  return false;
817
  }
818
 
819
  cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
820
  cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
821
 
822
+ const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v);
823
+
824
+ cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes);
825
+
826
+ // allocate the tensors into the backend buffer
827
+ {
828
+ ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer);
829
+
830
+ ggml_allocr_alloc(alloc, cache.k);
831
+ ggml_allocr_alloc(alloc, cache.v);
832
+
833
+ ggml_allocr_free(alloc);
834
+ }
835
+
836
  return true;
837
  }
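Since the K/V tensors now live in a backend buffer (possibly device memory), the rest of the file can no longer memcpy through tensor->data; host/backend transfers go through the ggml-backend copy API instead, e.g. (illustrative only):

    std::vector<uint8_t> tmp(ggml_nbytes(cache.k));
    ggml_backend_tensor_get(cache.k, tmp.data(), 0, tmp.size()); // backend -> host
    ggml_backend_tensor_set(cache.k, tmp.data(), 0, tmp.size()); // host -> backend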
838
 
839
+ // TODO: remove after batched decoding
840
+ static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t backend) {
841
  WHISPER_ASSERT(cache.ctx);
842
 
843
  const int n_elements = ggml_nelements(cache.k);
 
846
  const ggml_type wtype = cache.k->type;
847
  WHISPER_ASSERT(wtype == cache.v->type);
848
 
 
 
849
  struct ggml_init_params params = {
850
+ /*.mem_size =*/ 2*ggml_tensor_overhead(),
851
+ /*.mem_buffer =*/ nullptr,
852
+ /*.no_alloc =*/ true,
853
  };
854
 
855
  cache.ctx = ggml_init(params);
856
 
857
  if (!cache.ctx) {
858
+ WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
859
  return false;
860
  }
861
 
862
  cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
863
  cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
864
 
865
+ const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v);
866
+
867
+ cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes);
868
+
869
+ // allocate the tensors into the backend buffer
870
+ {
871
+ ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer);
872
+
873
+ ggml_allocr_alloc(alloc, cache.k);
874
+ ggml_allocr_alloc(alloc, cache.v);
875
+
876
+ ggml_allocr_free(alloc);
877
+ }
878
+
879
  return true;
880
  }
881
 
882
  static void kv_cache_free(struct whisper_kv_cache & cache) {
883
  if (cache.ctx) {
884
  ggml_free(cache.ctx);
885
+ ggml_backend_buffer_free(cache.buffer);
886
  cache.ctx = nullptr;
887
  }
888
  }
889
 
890
+ static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
891
+ ggml_backend_t backend_gpu = NULL;
892
+
893
+ // initialize the backends
894
+ #ifdef GGML_USE_CUBLAS
895
+ if (params.use_gpu) {
896
+ WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
897
+ backend_gpu = ggml_backend_cuda_init();
898
+ if (!backend_gpu) {
899
+ WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
900
+ }
901
+ }
902
+ #endif
903
+
904
+ #ifdef GGML_USE_METAL
905
+ if (params.use_gpu) {
906
+ WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
907
+ ggml_metal_log_set_callback(whisper_log_callback_default, nullptr);
908
+ backend_gpu = ggml_backend_metal_init();
909
+ if (!backend_gpu) {
910
+ WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
911
+ }
912
+ }
913
+ #endif
914
+
915
+ if (backend_gpu) {
916
+ return backend_gpu;
917
+ }
918
+ return ggml_backend_cpu_init();
919
+ }
920
+
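whisper_backend_init() prefers a GPU backend when one is compiled in and requested via use_gpu, and falls back to the CPU backend otherwise. A hedged sketch of how the returned backend is then used to hold the model weights (size_needed is a placeholder; the loading code below computes the actual size):

    whisper_context_params cparams = whisper_context_default_params();
    ggml_backend_t backend = whisper_backend_init(cparams);

    // weights are created in a no_alloc ggml context and then placed into a
    // single backend buffer
    ggml_backend_buffer_t buf = ggml_backend_alloc_buffer(backend, size_needed);
    ggml_allocr * alloc = ggml_allocr_new_from_buffer(buf);
    // ... ggml_allocr_alloc(alloc, tensor) for every model tensor ...
    ggml_allocr_free(alloc);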
921
  // load the model from a ggml file
922
  //
923
  // file format:
 
930
  // see the convert-pt-to-ggml.py script for details
931
  //
932
  static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
933
+ WHISPER_LOG_INFO("%s: loading model\n", __func__);
934
 
935
  const int64_t t_start_us = ggml_time_us();
936
 
 
944
  uint32_t magic;
945
  read_safe(loader, magic);
946
  if (magic != GGML_FILE_MAGIC) {
947
+ WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
948
  return false;
949
  }
950
  }
 
1001
  // in order to save memory and also to speed up the computation
1002
  wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
1003
  if (wctx.wtype == GGML_TYPE_COUNT) {
1004
+ WHISPER_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
1005
  return false;
1006
  }
1007
 
1008
+ WHISPER_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
1009
+ WHISPER_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
1010
+ WHISPER_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
1011
+ WHISPER_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
1012
+ WHISPER_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
1013
+ WHISPER_LOG_INFO("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
1014
+ WHISPER_LOG_INFO("%s: n_text_state = %d\n", __func__, hparams.n_text_state);
1015
+ WHISPER_LOG_INFO("%s: n_text_head = %d\n", __func__, hparams.n_text_head);
1016
+ WHISPER_LOG_INFO("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
1017
+ WHISPER_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels);
1018
+ WHISPER_LOG_INFO("%s: ftype = %d\n", __func__, model.hparams.ftype);
1019
+ WHISPER_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr);
1020
+ WHISPER_LOG_INFO("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str());
1021
  }
1022
 
1023
  // load mel filters
 
1038
  read_safe(loader, n_vocab);
1039
 
1040
  //if (n_vocab != model.hparams.n_vocab) {
1041
+ // WHISPER_LOG_ERROR("%s: invalid model file '%s' (bad vocab size %d != %d)\n",
1042
  // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
1043
  // return false;
1044
  //}
 
1058
  word.assign(&tmp[0], tmp.size());
1059
  } else {
1060
  // seems like we have an empty-string token in multi-language models (i = 50256)
1061
+ //WHISPER_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
1062
  word = "";
1063
  }
1064
 
 
1086
  }
1087
 
1088
  if (n_vocab < model.hparams.n_vocab) {
1089
+ WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
1090
  for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
1091
  if (i > vocab.token_beg) {
1092
  word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
 
1112
  }
1113
  }
1114
 
1115
+ WHISPER_LOG_INFO("%s: n_langs = %d\n", __func__, vocab.num_languages());
1116
  }
1117
 
 
 
1118
  const ggml_type wtype = wctx.wtype;
1119
  const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type
1120
 
1121
+ // create the ggml context
1122
  {
1123
  const auto & hparams = model.hparams;
1124
 
 
 
 
 
1125
  const int n_audio_layer = hparams.n_audio_layer;
1126
+ const int n_text_layer = hparams.n_text_layer;
1127
 
1128
+ const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
1129
 
 
 
1130
  struct ggml_init_params params = {
1131
+ /*.mem_size =*/ n_tensors*ggml_tensor_overhead(),
1132
+ /*.mem_buffer =*/ nullptr,
1133
+ /*.no_alloc =*/ true,
1134
  };
1135
 
1136
  model.ctx = ggml_init(params);
1137
  if (!model.ctx) {
1138
+ WHISPER_LOG_ERROR("%s: ggml_init() failed\n", __func__);
1139
  return false;
1140
  }
1141
  }
1142
 
1143
+ // prepare tensors for the weights
1144
  {
1145
  auto & ctx = model.ctx;
1146
 
 
1163
 
1164
  // encoder
1165
  {
1166
+ model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
1167
 
1168
+ model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
1169
+ model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_audio_ctx, n_audio_state);
1170
 
1171
+ model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
1172
+ model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_ctx, n_audio_state);
1173
 
1174
+ model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1175
+ model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
1176
 
1177
  // map by name
1178
  model.tensors["encoder.positional_embedding"] = model.e_pe;
 
1336
  }
1337
  }
1338
 
1339
+ wctx.backend = whisper_backend_init(wctx.params);
1340
+
1341
+ {
1342
+ size_t size_main = 0;
1343
+
1344
+ for (const auto & t : model.tensors) {
1345
+ size_main += ggml_nbytes(t.second) + ggml_tensor_overhead();
1346
+ }
1347
+
1348
+ model.buffer = ggml_backend_alloc_buffer(wctx.backend, size_main);
1349
+
1350
+ WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1024.0 / 1024.0);
1351
+ }
1352
+
1353
+ ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer);
1354
+
1355
+ // allocate tensors in the backend buffers
1356
+ {
1357
+ for (const auto & t : model.tensors) {
1358
+ ggml_allocr_alloc(alloc, t.second);
1359
+ }
1360
+ }
1361
+
1362
  // load weights
1363
  {
1364
  size_t total_size = 0;
1365
 
1366
  model.n_loaded = 0;
1367
 
1368
+ std::vector<char> read_buf;
1369
+
1370
  while (true) {
1371
  int32_t n_dims;
1372
  int32_t length;
 
1393
  name.assign(&tmp[0], tmp.size());
1394
 
1395
  if (model.tensors.find(name) == model.tensors.end()) {
1396
+ WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
1397
  return false;
1398
  }
1399
 
1400
  auto tensor = model.tensors[name.data()];
 
 
 
 
 
 
1401
 
1402
+ const bool is_conv_bias = (name == "encoder.conv1.bias" || name == "encoder.conv2.bias");
 
 
 
 
1403
 
1404
+ if (!is_conv_bias) {
1405
+ if (ggml_nelements(tensor) != nelements) {
1406
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
1407
+ WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
1408
+ __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
1409
+ return false;
1410
+ }
1411
 
1412
+ if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
1413
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
1414
+ __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
1415
+ return false;
1416
+ }
1417
+
1418
+ const size_t bpe = ggml_type_size(ggml_type(ttype));
1419
+
1420
+ if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
1421
+ WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
1422
+ __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
1423
+ return false;
1424
+ }
1425
  }
1426
 
1427
+ ggml_backend_t backend = wctx.backend;
1428
+
1429
+ //printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
1430
+
1431
+ if ((ggml_backend_is_cpu(backend)
1432
+ #ifdef GGML_USE_METAL
1433
+ || ggml_backend_is_metal(backend)
1434
+ #endif
1435
+ ) && !is_conv_bias) {
1436
+ // for the CPU and Metal backend, we can read directly into the tensor
1437
+ loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
1438
+ BYTESWAP_TENSOR(tensor);
1439
+ } else {
1440
+ // read into a temporary buffer first, then copy to device memory
1441
+ read_buf.resize(ggml_nbytes(tensor));
1442
+
1443
+ // we repeat the 2 bias tensors along dim 0:
1444
+ // [1, 512] -> [3000, 512] (conv1.bias)
1445
+ // [1, 512] -> [1500, 512] (conv2.bias)
1446
+ if (is_conv_bias) {
1447
+ loader->read(loader->context, read_buf.data(), read_buf.size() / tensor->ne[0]);
1448
+
1449
+ float * data_f32 = (float *) read_buf.data();
1450
+ for (int64_t y = 0; y < tensor->ne[1]; ++y) {
1451
+ const int64_t yy = tensor->ne[1] - y - 1;
1452
+ const float val = data_f32[yy];
1453
+
1454
+ for (int64_t x = 0; x < tensor->ne[0]; ++x) {
1455
+ data_f32[yy*tensor->ne[0] + x] = val;
1456
+ }
1457
+ }
1458
+ } else {
1459
+ loader->read(loader->context, read_buf.data(), read_buf.size());
1460
+ }
1461
+
1462
+ ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
1463
+ }
1464
 
1465
  //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1024.0/1024.0);
1466
  total_size += ggml_nbytes(tensor);
1467
  model.n_loaded++;
1468
  }
1469
 
1470
+ WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
1471
 
1472
  if (model.n_loaded == 0) {
1473
+ WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
1474
  } else if (model.n_loaded != (int) model.tensors.size()) {
1475
+ WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
1476
  return false;
1477
  }
1478
  }
1479
 
1480
+ ggml_allocr_free(alloc);
1481
+
1482
  wctx.t_load_us = ggml_time_us() - t_start_us;
1483
 
1484
  return true;
 
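The load loop above keeps two upload paths: when the backend buffer is host-visible (CPU, and Metal's shared buffers) the file bytes are read straight into tensor->data, while for device-only buffers such as CUDA they are staged in a temporary host vector and pushed with ggml_backend_tensor_set. A condensed, hedged sketch of the same idea (illustrative helper name, not the whisper.cpp loader):

    #include <fstream>
    #include <vector>
    #include "ggml.h"
    #include "ggml-backend.h"

    // illustrative helper: upload one tensor's weights from an open model file
    static void upload_tensor(ggml_backend_t backend, struct ggml_tensor * t, std::ifstream & fin) {
        if (ggml_backend_is_cpu(backend)) {
            // host-visible buffer: read directly into the tensor's memory
            fin.read((char *) t->data, ggml_nbytes(t));
        } else {
            // device buffer (e.g. CUDA): stage in host memory, then copy to the device
            std::vector<char> staging(ggml_nbytes(t));
            fin.read(staging.data(), staging.size());
            ggml_backend_tensor_set(t, staging.data(), 0, staging.size());
        }
    }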
1534
  if (!ggml_allocr_is_measure(alloc)) {
1535
  assert(mel_inp.n_mel == n_mels);
1536
 
1537
+ wstate.inp_mel.resize(ggml_nelements(mel));
1538
+
1539
+ float * dst = wstate.inp_mel.data();
1540
  memset(dst, 0, ggml_nbytes(mel));
1541
 
1542
+ const int i0 = std::min(mel_offset, mel_inp.n_len);
1543
  const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
1544
 
1545
  for (int j = 0; j < mel_inp.n_mel; ++j) {
 
1547
  dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
1548
  }
1549
  }
1550
+
1551
+ ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
1552
  }
1553
 
1554
  struct ggml_tensor * cur = nullptr;
 
1557
  // convolution + gelu
1558
  {
1559
  cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
1560
+ cur = ggml_add(ctx0, cur, model.e_conv_1_b);
1561
+ //cur = ggml_add(ctx0,
1562
+ // ggml_repeat(ctx0,
1563
+ // model.e_conv_1_b,
1564
+ // cur),
1565
+ // cur);
1566
 
1567
  cur = ggml_gelu(ctx0, cur);
1568
 
1569
  cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
1570
+ cur = ggml_add(ctx0, cur, model.e_conv_2_b);
1571
+ //cur = ggml_add(ctx0,
1572
+ // ggml_repeat(ctx0,
1573
+ // model.e_conv_2_b,
1574
+ // cur),
1575
+ // cur);
1576
 
1577
  cur = ggml_gelu(ctx0, cur);
1578
  }
1579
 
1580
+ ggml_set_name(cur, "embd_conv");
1581
  wstate.embd_conv = cur;
1582
  } else {
1583
  #ifdef WHISPER_USE_COREML
 
1597
  }
1598
  #endif
1599
 
1600
+ ggml_set_name(cur, "embd_enc");
1601
  wstate.embd_enc = cur;
1602
  }
1603
 
 
1631
 
1632
  ggml_allocr * alloc = wstate.alloc_encode.alloc;
1633
 
1634
+ //struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_state);
1635
+ //ggml_allocr_alloc(alloc, cur);
1636
+
1637
+ //if (!ggml_allocr_is_measure(alloc)) {
1638
+ // ggml_backend_tensor_copy(wstate.embd_conv, cur);
1639
+ //}
1640
+ struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
1641
+
1642
  struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
1643
  ggml_allocr_alloc(alloc, KQscale);
1644
 
1645
  if (!ggml_allocr_is_measure(alloc)) {
1646
+ const float val = 1.0f/sqrtf(float(n_state)/n_head);
1647
+ ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
1648
  }
1649
 
 
 
1650
  // ===================================================================
1651
  // NOTE: experimenting with partial evaluation of the encoder (ignore)
1652
  //static int iter = -1;
 
1665
  const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
1666
 
1667
  struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
 
1668
  cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
1669
 
1670
  // ===================================================================
 
1886
 
1887
  ggml_allocr * alloc = wstate.alloc_cross.alloc;
1888
 
1889
+ //struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
1890
+ //ggml_allocr_alloc(alloc, cur);
1891
+
1892
+ //if (!ggml_allocr_is_measure(alloc)) {
1893
+ // ggml_backend_tensor_copy(wstate.embd_enc, cur);
1894
+ //}
1895
  struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
1896
 
1897
  struct ggml_tensor * Kscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
1898
  ggml_allocr_alloc(alloc, Kscale);
1899
 
1900
  if (!ggml_allocr_is_measure(alloc)) {
1901
+ const float val = pow(float(n_state) / n_head, -0.25);
1902
+ ggml_backend_tensor_set(Kscale, &val, 0, sizeof(float));
1903
  }
1904
 
1905
  for (int il = 0; il < model.hparams.n_text_layer; ++il) {
 
1970
  ggml_allocr_alloc_graph(alloc, gf);
1971
 
1972
  if (!whisper_encode_external(wstate)) {
1973
+ ggml_graph_compute_helper(wstate.backend, gf, n_threads);
1974
  }
1975
  }
1976
 
 
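All three graphs (conv, encoder, cross) now go through ggml_graph_compute_helper with the state's backend instead of a CPU-only compute plan. The helper's body is outside this hunk; a plausible sketch of what such a dispatcher does (an assumption, not the verbatim implementation) is:

    // assumed shape of a backend-dispatching compute helper
    static void graph_compute_on_backend(ggml_backend_t backend, struct ggml_cgraph * gf, int n_threads) {
        if (ggml_backend_is_cpu(backend)) {
            // only the CPU backend takes an explicit thread count
            ggml_backend_cpu_set_n_threads(backend, n_threads);
        }
        // CUDA and Metal execute the whole graph on the device
        ggml_backend_graph_compute(backend, gf);
    }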
1984
 
1985
  ggml_allocr_alloc_graph(alloc, gf);
1986
 
1987
+ ggml_graph_compute_helper(wstate.backend, gf, n_threads);
1988
  }
1989
 
1990
  // cross
 
1997
 
1998
  ggml_allocr_alloc_graph(alloc, gf);
1999
 
2000
+ ggml_graph_compute_helper(wstate.backend, gf, n_threads);
2001
  }
2002
 
 
 
2003
  wstate.t_encode_us += ggml_time_us() - t_start_us;
2004
  wstate.n_encode++;
2005
 
 
2046
  ggml_allocr_alloc(alloc, embd);
2047
 
2048
  if (!ggml_allocr_is_measure(alloc)) {
2049
+ ggml_backend_tensor_set(embd, tokens, 0, N*ggml_element_size(embd));
2050
  }
2051
 
2052
  struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
 
2054
 
2055
  if (!ggml_allocr_is_measure(alloc)) {
2056
  for (int i = 0; i < N; ++i) {
2057
+ const int32_t val = n_past + i;
2058
+ ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
2059
  }
2060
  }
2061
 
 
2063
  ggml_allocr_alloc(alloc, KQscale);
2064
 
2065
  if (!ggml_allocr_is_measure(alloc)) {
2066
+ const float val = pow(float(n_state)/n_head, -0.25);
2067
+ ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
2068
  }
2069
 
2070
  // token encoding + position encoding
 
2388
 
2389
  logits = gf->nodes[gf->n_nodes - 1];
2390
 
2391
+ ggml_graph_compute_helper(wstate.backend, gf, n_threads);
2392
  }
2393
 
2394
  // extract logits for all N tokens
2395
  //logits_out.resize(n_tokens*n_vocab);
2396
  //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
2397
+ //ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab);
2398
 
2399
  // extract logits only for the last token
2400
  logits_out.resize(n_vocab);
2401
+ //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
2402
+ ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
2403
 
2404
  if (n_tokens > 1) {
2405
  //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
 
2765
  --j;
2766
  }
2767
  if (!found) {
2768
+ WHISPER_LOG_ERROR("unknown token\n");
2769
  ++i;
2770
  }
2771
  }
 
2828
 
2829
  struct whisper_state * whisper_init_state(whisper_context * ctx) {
2830
  fill_sin_cos_table();
2831
+
2832
  whisper_state * state = new whisper_state;
2833
 
2834
+ state->backend = whisper_backend_init(ctx->params);
2835
+
2836
+ if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
2837
+ WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
2838
  delete state;
2839
  return nullptr;
2840
  }
2841
 
2842
  {
2843
  const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
2844
+ WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2845
  }
2846
 
2847
+ if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
2848
+ WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
2849
  delete state;
2850
  return nullptr;
2851
  }
2852
 
2853
  {
2854
  const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v);
2855
+ WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
2856
  }
2857
 
2858
  #ifdef WHISPER_USE_COREML
2859
  const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
2860
 
2861
+ WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
2862
+ WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
2863
 
2864
  state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
2865
  if (!state->ctx_coreml) {
2866
+ WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
2867
  #ifndef WHISPER_COREML_ALLOW_FALLBACK
2868
  delete state;
2869
  return nullptr;
2870
  #endif
2871
  } else {
2872
+ WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__);
2873
  }
2874
  #endif
2875
 
 
2886
 
2887
  // conv allocator
2888
  {
2889
+ whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
2890
  [&]() {
2891
  return whisper_build_graph_conv(*ctx, *state, 0);
2892
  });
2893
 
2894
+ WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
2895
  }
2896
 
2897
  // encoder allocator
2898
  if (!whisper_encode_external(*state)) {
2899
+ whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
2900
  [&]() {
2901
  return whisper_build_graph_encoder(*ctx, *state);
2902
  });
2903
 
2904
+ WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
2905
  }
2906
 
2907
  // cross allocator
2908
  {
2909
+ whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
2910
  [&]() {
2911
  return whisper_build_graph_cross(*ctx, *state);
2912
  });
2913
 
2914
+ WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
2915
  }
2916
 
2917
  // decoder allocator
2918
  {
2919
+ whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
2920
  [&]() {
2921
  const auto & hparams = ctx->model.hparams;
2922
 
 
2927
  return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
2928
  });
2929
 
2930
+ WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
2931
  }
2932
 
2933
+ whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend);
2934
+ whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend);
2935
+ whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
2936
+ whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
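Each compute buffer is sized with a measure pass over its graph (whisper_allocr_graph_init) and then re-created in backend memory (whisper_allocr_graph_realloc); neither helper is shown in this hunk. The generic ggml-alloc pattern they build on looks roughly like this sketch (CPU backend, toy graph, a small alignment margin added to the measured size):

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"

    // toy graph builder: the real code builds the conv/encoder/cross/decoder graphs
    static struct ggml_cgraph * build_graph(struct ggml_context * ctx) {
        struct ggml_tensor * a  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
        struct ggml_tensor * b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, ggml_add(ctx, a, b));
        return gf;
    }

    int main() {
        ggml_backend_t backend = ggml_backend_cpu_init();

        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16*ggml_tensor_overhead() + ggml_graph_overhead(),
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true, // tensor data will live in the backend buffer
        };

        // 1) measure pass: no real memory, only the required size is recorded
        struct ggml_context * ctx0 = ggml_init(ip);
        ggml_allocr * alloc = ggml_allocr_new_measure(32);
        const size_t size = ggml_allocr_alloc_graph(alloc, build_graph(ctx0)) + 64; // + alignment slack
        ggml_allocr_free(alloc);
        ggml_free(ctx0);

        // 2) real pass: rebuild the graph and allocate it inside a backend buffer
        ggml_backend_buffer_t buf = ggml_backend_alloc_buffer(backend, size);
        struct ggml_context * ctx1 = ggml_init(ip);
        alloc = ggml_allocr_new_from_buffer(buf);
        struct ggml_cgraph * gf = build_graph(ctx1);
        ggml_allocr_alloc_graph(alloc, gf);

        ggml_backend_graph_compute(backend, gf);

        ggml_allocr_free(alloc);
        ggml_free(ctx1);
        ggml_backend_buffer_free(buf);
        ggml_backend_free(backend);
        return 0;
    }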
2937
 
2938
  state->rng = std::mt19937(0);
2939
 
 
2954
  return 1;
2955
  #else
2956
  if (!model_path && ctx->path_model.empty()) {
2957
+ WHISPER_LOG_ERROR("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
2958
  return 1;
2959
  }
2960
 
 
2974
  path_cache = cache_dir;
2975
  }
2976
 
2977
+ WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
2978
+ WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
2979
 
2980
  ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
2981
  if (!ctx->state->ctx_openvino) {
2982
+ WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
2983
  return 1;
2984
  } else {
2985
+ WHISPER_LOG_INFO("%s: OpenVINO model loaded\n", __func__);
2986
  }
2987
 
2988
  return 0;
 
2997
  }
2998
 
2999
  struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
3000
+ WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
3001
 
3002
  auto fin = std::ifstream(path_model, std::ios::binary);
3003
  if (!fin) {
3004
+ WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
3005
  return nullptr;
3006
  }
3007
 
 
3043
 
3044
  buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
3045
 
3046
+ WHISPER_LOG_INFO("%s: loading model from buffer\n", __func__);
3047
 
3048
  whisper_model_loader loader = {};
3049
 
 
3079
 
3080
  if (!whisper_model_load(loader, *ctx)) {
3081
  loader->close(loader->context);
3082
+ WHISPER_LOG_ERROR("%s: failed to load model\n", __func__);
3083
  delete ctx;
3084
  return nullptr;
3085
  }
 
3174
  }
3175
  #endif
3176
 
3177
  #ifdef WHISPER_USE_OPENVINO
3178
  if (state->ctx_openvino != nullptr) {
3179
  whisper_openvino_free(state->ctx_openvino);
 
3182
  #endif
3183
 
3184
  whisper_allocr_free(state->alloc_conv);
 
 
3185
  whisper_allocr_free(state->alloc_encode);
3186
+ whisper_allocr_free(state->alloc_cross);
3187
+ whisper_allocr_free(state->alloc_decode);
3188
+
3189
+ ggml_backend_free(state->backend);
3190
 
3191
  delete state;
3192
  }
 
3197
  if (ctx->model.ctx) {
3198
  ggml_free(ctx->model.ctx);
3199
  }
3200
+
3201
+ if (ctx->model.buffer) {
3202
+ ggml_backend_buffer_free(ctx->model.buffer);
3203
  }
3204
 
3205
  whisper_free_state(ctx->state);
3206
 
3207
+ ggml_backend_free(ctx->backend);
3208
+
3209
  delete ctx;
3210
  }
3211
  }
 
3224
 
3225
  int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
3226
  if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
3227
+ WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
3228
  return -1;
3229
  }
3230
 
 
3238
  // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
3239
  int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
3240
  if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
3241
+ WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
3242
  return -1;
3243
  }
3244
 
 
3266
  int n_len,
3267
  int n_mel) {
3268
  if (n_mel != ctx->model.filters.n_mel) {
3269
+ WHISPER_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
3270
  return -1;
3271
  }
3272
 
 
3290
 
3291
  int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
3292
  if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
3293
+ WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
3294
  return -1;
3295
  }
3296
 
 
3299
 
3300
  int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
3301
  if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
3302
+ WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
3303
  return -1;
3304
  }
3305
 
 
3310
  const int selected_decoder_id = 0;
3311
 
3312
  if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
3313
+ WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
3314
  return 1;
3315
  }
3316
 
 
3322
  const int selected_decoder_id = 0;
3323
 
3324
  if (ctx->state == nullptr) {
3325
+ WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
3326
  return false;
3327
  }
3328
 
3329
  if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
3330
+ WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
3331
  return 1;
3332
  }
3333
 
 
3338
  const auto res = tokenize(ctx->vocab, text);
3339
 
3340
  if (n_max_tokens < (int) res.size()) {
3341
+ WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
3342
  return -1;
3343
  }
3344
 
 
3366
  }
3367
  }
3368
 
3369
+ WHISPER_LOG_ERROR("%s: unknown language '%s'\n", __func__, lang);
3370
  return -1;
3371
  }
3372
  return g_lang.at(lang).first;
 
3379
  }
3380
  }
3381
 
3382
+ WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
3383
  return nullptr;
3384
  }
3385
 
 
3392
  const int seek = offset_ms/10;
3393
 
3394
  if (seek < 0) {
3395
+ WHISPER_LOG_ERROR("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
3396
  return -1;
3397
  }
3398
 
3399
  if (seek >= state->mel.n_len_org) {
3400
+ WHISPER_LOG_ERROR("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
3401
  return -2;
3402
  }
3403
 
3404
  // run the encoder
3405
  if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) {
3406
+ WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
3407
  return -6;
3408
  }
3409
 
3410
  const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
3411
 
3412
  if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
3413
+ WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
3414
  return -7;
3415
  }
3416
 
 
3610
  void whisper_print_timings(struct whisper_context * ctx) {
3611
  const int64_t t_end_us = ggml_time_us();
3612
 
3613
+ WHISPER_LOG_INFO("\n");
3614
+ WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
3615
  if (ctx->state != nullptr) {
3616
 
3617
  const int32_t n_sample = std::max(1, ctx->state->n_sample);
 
3619
  const int32_t n_decode = std::max(1, ctx->state->n_decode);
3620
  const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
3621
 
3622
+ WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
3623
+ WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
3624
+ WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
3625
+ WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
3626
+ WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
3627
+ WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
3628
  }
3629
+ WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
3630
  }
3631
 
3632
  void whisper_reset_timings(struct whisper_context * ctx) {
 
3678
  s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
3679
  s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
3680
  s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
3681
+ s += "CUDA = " + std::to_string(ggml_cpu_has_cublas()) + " | ";
3682
  s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
3683
  s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
3684
 
 
3973
  const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
3974
  const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
3975
 
3976
+ //WHISPER_LOG_INFO("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
3977
 
3978
  if (last_was_timestamp) {
3979
  if (penultimate_was_timestamp) {
 
4049
 
4050
  const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
4051
 
4052
+ //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
4053
 
4054
  if (timestamp_logprob > max_text_token_logprob) {
4055
  for (int i = 0; i < vocab.token_beg; ++i) {
 
4344
  for (auto & i : two_copy) {
4345
  // make a copy of KV caches
4346
  WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
4347
+ //memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
4348
+ //memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
4349
+ ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size());
4350
+ ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size());
4351
  }
4352
 
4353
  // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
 
4360
  if (two_copy.find(view[i]) != two_copy.end()) {
4361
  // modify KV caches of decoder using data from kv_swap_bufs
4362
  WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
4363
+ //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
4364
+ //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
4365
+ ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
4366
+ ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
4367
  } else {
4368
  // modify KV caches of decoder using data from correspond decoder KV caches directly
4369
  WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
4370
+ //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
4371
+ //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
4372
+ ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
4373
+ ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
4374
  }
4375
  }
4376
 
 
4384
  if (two_copy.find(view[i]) != two_copy.end()) {
4385
  // modify KV caches of decoder using data from kv_swap_bufs
4386
  WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
4387
+ //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
4388
+ //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
4389
+ ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
4390
+ ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
4391
  } else {
4392
  // modify KV caches of decoder using data from correspond decoder KV caches directly
4393
  WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
4394
+ //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
4395
+ //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
4396
+ ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
4397
+ ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
4398
  }
4399
  }
4400
 
 
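With the KV tensors possibly resident in device memory, the beam-search swap above can no longer memcpy between raw ->data pointers; it stages through host buffers with ggml_backend_tensor_get/set, or copies tensor-to-tensor with ggml_backend_tensor_copy when no host round-trip is needed. The same primitives in isolation (illustrative names, not the whisper.cpp decoder structures):

    #include <cstdint>
    #include <vector>
    #include "ggml.h"
    #include "ggml-backend.h"

    // illustrative: move the contents of one KV tensor into another,
    // regardless of whether the tensors live in host or device memory
    static void copy_kv_tensor(struct ggml_tensor * src, struct ggml_tensor * dst, bool via_host) {
        if (via_host) {
            std::vector<uint8_t> staging(ggml_nbytes(src));
            ggml_backend_tensor_get(src, staging.data(), 0, staging.size()); // device -> host
            ggml_backend_tensor_set(dst, staging.data(), 0, staging.size()); // host -> device
        } else {
            ggml_backend_tensor_copy(src, dst); // direct copy, no host round-trip
        }
    }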
4422
  // compute log mel spectrogram
4423
  if (params.speed_up) {
4424
  // TODO: Replace PV with more advanced algorithm
4425
+ WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
4426
  return -1;
4427
  } else {
4428
  if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
4429
+ WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
4430
  return -2;
4431
  }
4432
  }
 
4438
 
4439
  const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
4440
  if (lang_id < 0) {
4441
+ WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__);
4442
  return -3;
4443
  }
4444
  state->lang_id = lang_id;
4445
  params.language = whisper_lang_str(lang_id);
4446
 
4447
+ WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
4448
  if (params.detect_language) {
4449
  return 0;
4450
  }
 
4502
 
4503
  if (decoder.kv_self.ctx == nullptr) {
4504
  decoder.kv_self = state->decoders[0].kv_self;
4505
+ if (!kv_cache_reinit(decoder.kv_self, ctx->backend)) {
4506
+ WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
4507
  return -4;
4508
  }
4509
 
 
4514
  decoder.probs.resize (ctx->vocab.n_vocab);
4515
  decoder.logits.resize (ctx->vocab.n_vocab);
4516
  decoder.logprobs.resize(ctx->vocab.n_vocab);
 
4517
  }
4518
  }
4519
 
 
4547
 
4548
  // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
4549
  if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
4550
+ WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
4551
  return -5;
4552
  }
4553
  state->exp_n_audio_ctx = params.audio_ctx;
 
4572
  // distilled models require the "no_timestamps" token
4573
  // TODO: add input parameter (#1229)
4574
  if (is_distil) {
4575
+ WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
4576
  prompt_init.push_back(whisper_token_not(ctx));
4577
  }
4578
  }
 
4609
 
4610
  if (params.encoder_begin_callback) {
4611
  if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
4612
+ WHISPER_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__);
4613
  break;
4614
  }
4615
  }
4616
 
4617
  // encode audio features starting at offset seek
4618
  if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
4619
+ WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
4620
  return -6;
4621
  }
4622
 
 
4699
  WHISPER_PRINT_DEBUG("\n\n");
4700
 
4701
  if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
4702
+ WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
4703
  return -7;
4704
  }
4705
 
 
4713
  for (int j = 1; j < n_decoders_cur; ++j) {
4714
  auto & decoder = state->decoders[j];
4715
 
4716
+ // TODO: fix CUDA
4717
+ //memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
4718
+ //memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
4719
+ ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k);
4720
+ ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v);
4721
 
4722
  decoder.kv_self.n += prompt.size();
4723
 
 
4926
  //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
4927
 
4928
  if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
4929
+ WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
4930
  return -8;
4931
  }
4932
 
 
5252
  ctx->state->t_decode_us /= n_processors;
5253
 
5254
  // print information about the audio boundaries
5255
+ WHISPER_LOG_WARN("\n");
5256
+ WHISPER_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
5257
  for (int i = 0; i < n_processors - 1; ++i) {
5258
+ WHISPER_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
5259
  }
5260
+ WHISPER_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__);
5261
 
5262
  return ret;
5263
  }
 
5499
  double tsum = 0.0;
5500
 
5501
  // heat-up
5502
+ ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
5503
 
5504
  for (int i = 0; i < n_max; ++i) {
5505
  const int64_t t0 = ggml_time_us();
5506
 
5507
+ ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
5508
 
5509
  const int64_t t1 = ggml_time_us();
5510
 
 
5622
  const int n_samples = state.energy.size();
5623
 
5624
  if (n_samples == 0) {
5625
+ WHISPER_LOG_ERROR("%s: no signal data available\n", __func__);
5626
  return;
5627
  }
5628
 
 
5843
  //}
5844
  }
5845
 
5846
+ void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
5847
+ g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
5848
+ g_state.log_callback_user_data = user_data;
5849
+ }
5850
+
5851
+ static void whisper_log_internal_v(ggml_log_level level, const char * format, va_list args) {
5852
+ va_list args_copy;
5853
+ va_copy(args_copy, args);
5854
+ char buffer[128];
5855
+ int len = vsnprintf(buffer, 128, format, args);
5856
+ if (len < 128) {
5857
+ g_state.log_callback(level, buffer, g_state.log_callback_user_data);
5858
+ } else {
5859
+ char* buffer2 = new char[len+1];
5860
+ vsnprintf(buffer2, len+1, format, args_copy);
5861
+ buffer2[len] = 0;
5862
+ g_state.log_callback(level, buffer2, g_state.log_callback_user_data);
5863
+ delete[] buffer2;
5864
+ }
5865
+ va_end(args_copy);
5866
+ }
5867
+
5868
+ static void whisper_log_internal(ggml_log_level level, const char * format, ...) {
5869
+ va_list args;
5870
+ va_start(args, format);
5871
+ whisper_log_internal_v(level, format, args);
5872
+ va_end(args);
5873
+ }
5874
+
5875
+ static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
5876
+ (void) level;
5877
+ (void) user_data;
5878
+ fputs(text, stderr);
5879
+ fflush(stderr);
5880
  }
whisper.h CHANGED
@@ -1,6 +1,8 @@
1
  #ifndef WHISPER_H
2
  #define WHISPER_H
3
 
 
 
4
  #include <stddef.h>
5
  #include <stdint.h>
6
  #include <stdbool.h>
@@ -110,15 +112,15 @@ extern "C" {
110
  // Various functions for loading a ggml whisper model.
111
  // Allocate (almost) all memory needed for the model.
112
  // Return NULL on failure
113
- WHISPER_API struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params);
114
- WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params);
115
- WHISPER_API struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params);
116
 
117
  // These are the same as the above, but the internal state of the context is not allocated automatically
118
  // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
119
- WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);
120
- WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);
121
- WHISPER_API struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params);
122
 
123
  WHISPER_DEPRECATED(
124
  WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model),
@@ -570,8 +572,7 @@ extern "C" {
570
 
571
  // Control logging output; default behavior is to print to stderr
572
 
573
- typedef void (*whisper_log_callback)(const char * line);
574
- WHISPER_API void whisper_set_log_callback(whisper_log_callback callback);
575
 
576
  #ifdef __cplusplus
577
  }
 
1
  #ifndef WHISPER_H
2
  #define WHISPER_H
3
 
4
+ #include "ggml.h"
5
+
6
  #include <stddef.h>
7
  #include <stdint.h>
8
  #include <stdbool.h>
 
112
  // Various functions for loading a ggml whisper model.
113
  // Allocate (almost) all memory needed for the model.
114
  // Return NULL on failure
115
+ WHISPER_API struct whisper_context * whisper_init_from_file_with_params (const char * path_model, struct whisper_context_params params);
116
+ WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params);
117
+ WHISPER_API struct whisper_context * whisper_init_with_params (struct whisper_model_loader * loader, struct whisper_context_params params);
118
 
119
  // These are the same as the above, but the internal state of the context is not allocated automatically
120
  // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
121
+ WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state (const char * path_model, struct whisper_context_params params);
122
+ WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);
123
+ WHISPER_API struct whisper_context * whisper_init_with_params_no_state (struct whisper_model_loader * loader, struct whisper_context_params params);
124
 
125
  WHISPER_DEPRECATED(
126
  WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model),
 
572
 
573
  // Control logging output; default behavior is to print to stderr
574
 
575
+ WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);
 
576
 
577
  #ifdef __cplusplus
578
  }
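Taken together, the public API changes amount to: contexts are created with whisper_context_params (including use_gpu) and logging is redirected through whisper_log_set with a ggml_log_callback. A hedged end-to-end usage sketch (the model path and the callback behavior are placeholders):

    #include <cstdio>
    #include "whisper.h"

    // placeholder callback: forward only errors, drop info/warn output
    static void my_log(enum ggml_log_level level, const char * text, void * user_data) {
        (void) user_data;
        if (level == GGML_LOG_LEVEL_ERROR) {
            fputs(text, stderr);
        }
    }

    int main() {
        whisper_log_set(my_log, nullptr);

        struct whisper_context_params cparams = whisper_context_default_params();
        cparams.use_gpu = true; // full CUDA/Metal offloading when compiled in

        struct whisper_context * ctx = whisper_init_from_file_with_params("models/ggml-base.en.bin", cparams);
        if (ctx == nullptr) {
            return 1;
        }

        // ... run whisper_full(ctx, ...) exactly as before ...

        whisper_free(ctx);
        return 0;
    }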