whisper : add full CUDA and Metal offloading (#1472)
* whisper : migrate to ggml-backend
* whisper : fix logit reading
* whisper : fix tensor allocation during load
* whisper : fix beam-search with CUDA
* whisper : free backends + fix compile warning
* whisper : print when CUDA is enabled
* whisper : fix CoreML
* make : clean-up
* talk : fix compile warning
* whisper : support ggml_conv with CUDA and Metal (#1473)
* ggml : add CUDA support for ggml_conv
* whisper : remove ggml_repeat for conv bias + single backend
* cuda : fix im2col kernel
* metal : add im2col support + mul mat-vec f16 x f16
* bench-all : add q4 models
* whisper : clean-up
* quantize-all : fix
* ggml : im2col opts
* whisper : avoid whisper_model_data wrapper
* whisper : add note that ggml_mul_mat_pad does not work with CUDA
* whisper : factor out graph compute in common function
* whisper : fixes
* whisper : fix UB with measure buffers
* whisper : try to fix the parallel whisper_state functionality (#1479)
* whisper : try to fix the parallel whisper_state functionality
* whisper : fix multi-state Metal
* whisper : free backend instances in whisper_state
- .gitignore +1 -0
- Makefile +21 -21
- examples/common.h +1 -1
- examples/talk/gpt-2.cpp +4 -4
- extra/bench-all.sh +9 -5
- extra/quantize-all.sh +7 -27
- ggml-cuda.cu +95 -1
- ggml-metal.h +1 -1
- ggml-metal.m +74 -6
- ggml-metal.metal +107 -1
- ggml.c +201 -1085
- ggml.h +13 -6
- whisper.cpp +486 -541
- whisper.h +9 -8
.gitignore

@@ -8,6 +8,7 @@
 .DS_Store
 
 build/
+build-coreml/
 build-em/
 build-debug/
 build-release/
Makefile

@@ -307,7 +307,7 @@ ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h
 ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
 	$(CC) $(CFLAGS) -c $< -o $@
 
-WHISPER_OBJ += ggml-alloc.o ggml-backend.o ggml-quants.o
+WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o
 
 whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@

@@ -331,11 +331,11 @@ ggml-metal.o: ggml-metal.m ggml-metal.h
 WHISPER_OBJ += ggml-metal.o
 endif
 
-libwhisper.a: ggml.o $(WHISPER_OBJ)
-	$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
+libwhisper.a: $(WHISPER_OBJ)
+	$(AR) rcs libwhisper.a $(WHISPER_OBJ)
 
-libwhisper.so: ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
+libwhisper.so: $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) -shared -o libwhisper.so $(WHISPER_OBJ) $(LDFLAGS)
 
 clean:
 	rm -f *.o main stream command talk talk-llama bench quantize lsp libwhisper.a libwhisper.so

@@ -349,30 +349,30 @@ CC_SDL=`sdl2-config --cflags --libs`
 SRC_COMMON     = examples/common.cpp examples/common-ggml.cpp
 SRC_COMMON_SDL = examples/common-sdl.cpp
 
-main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
+main: examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o main $(LDFLAGS)
 	./main -h
 
-bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS)
+bench: examples/bench/bench.cpp $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/bench/bench.cpp $(WHISPER_OBJ) -o bench $(LDFLAGS)
 
-quantize: examples/quantize/quantize.cpp ggml.o $(WHISPER_OBJ) $(SRC_COMMON)
-	$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o quantize $(LDFLAGS)
+quantize: examples/quantize/quantize.cpp $(WHISPER_OBJ) $(SRC_COMMON)
+	$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o quantize $(LDFLAGS)
 
-stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
+stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
 
-command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
+command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
 
-lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
+lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
 
-talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
+talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
 
-talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
+talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
 
 #
 # Audio samples
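With this change $(WHISPER_OBJ) carries everything the library needs, including ggml.o, so libwhisper.a and libwhisper.so become self-contained link targets. A minimal consumer, as a sketch in C (the model path and the extra link flags are placeholders; it assumes the whisper_init_from_file/whisper_free API from whisper.h of this release):

    /* sketch: link against libwhisper.a produced by the Makefile above;
     * build e.g.: cc -I. demo.c libwhisper.a -lstdc++ -lm (platform flags vary) */
    #include <stdio.h>
    #include "whisper.h"

    int main(void) {
        /* placeholder model path */
        struct whisper_context * ctx = whisper_init_from_file("models/ggml-base.en.bin");
        if (ctx == NULL) {
            fprintf(stderr, "failed to load model\n");
            return 1;
        }
        /* ... run whisper_full() on 16 kHz PCM here ... */
        whisper_free(ctx);
        return 0;
    }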
examples/common.h

@@ -181,7 +181,7 @@ private:
     // It is assumed that PCM data is normalized to a range from -1 to 1
     bool write_audio(const float * data, size_t length) {
         for (size_t i = 0; i < length; ++i) {
-            const …
+            const int16_t intSample = data[i] * 32767;
             file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
             dataSize += sizeof(int16_t);
         }
examples/talk/gpt-2.cpp

@@ -121,13 +121,13 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
             return false;
         }
 
-        std::string word;
+        char word[129];
+
         for (int i = 0; i < n_vocab; i++) {
             uint32_t len;
             fin.read((char *) &len, sizeof(len));
-
-            word.resize(len);
-            fin.read((char *) word.data(), len);
+            word[len] = '\0';
+            fin.read((char *) word, len);
 
             vocab.token_to_id[word] = i;
             vocab.id_to_token[i] = word;
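The rewritten loader reads each vocab entry into a fixed char word[129] buffer and stores the terminator at word[len] before the payload is read, which works because the subsequent read never touches index len. The pattern is only safe while every token is at most 128 bytes; a plain-C sketch with that guard made explicit (a hypothetical helper, not part of the diff):

    #include <stdio.h>
    #include <stdint.h>

    /* returns 1 on success; mirrors the fixed-buffer read in gpt-2.cpp */
    static int read_token(FILE * f, char word[129]) {
        uint32_t len;
        if (fread(&len, sizeof(len), 1, f) != 1) return 0;
        if (len > 128) return 0;   /* guard the fixed-size buffer */
        word[len] = '\0';          /* terminator can be set before the payload */
        return fread(word, 1, len, f) == len;
    }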
extra/bench-all.sh

@@ -18,11 +18,11 @@ else
 fi
 
 models=( \
-    "tiny"   "tiny-q5_0"   "tiny-q5_1"   "tiny-q8_0" \
-    "base"   "base-q5_0"   "base-q5_1"   "base-q8_0" \
-    "small"  "small-q5_0"  "small-q5_1"  "small-q8_0" \
-    "medium" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
-    "large"  "large-q5_0"  "large-q5_1"  "large-q8_0" \
+    "tiny"   "tiny-q4_0"   "tiny-q4_1"   "tiny-q5_0"   "tiny-q5_1"   "tiny-q8_0" \
+    "base"   "base-q4_0"   "base-q4_1"   "base-q5_0"   "base-q5_1"   "base-q8_0" \
+    "small"  "small-q4_0"  "small-q4_1"  "small-q5_0"  "small-q5_1"  "small-q8_0" \
+    "medium" "medium-q4_0" "medium-q4_1" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
+    "large"  "large-q4_0"  "large-q4_1"  "large-q5_0"  "large-q5_1"  "large-q8_0" \
 )
 
 if [ "$encoder_only" -eq 0 ]; then

@@ -83,6 +83,10 @@ for model in "${models[@]}"; do
         config="$config COREML"
     fi
 
+    if [[ $system_info == *"CUDA = 1"* ]]; then
+        config="$config CUDA"
+    fi
+
     if [[ $system_info == *"METAL = 1"* ]]; then
         config="$config METAL"
     fi
extra/quantize-all.sh

@@ -15,33 +15,13 @@ declare -a filedex
 cd `dirname $0`
 cd ../
 
-…
-        #
-        for f in "$i"/*.bin; do
-            # …
-            newfile=`echo "${f##*/}" | cut -d _ -f 1`;
-            if [ "$newfile" != "q5" ]; then
-                ./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" ${qtype1};
-                ./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" ${qtype0};
-                filedex+=( "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" )
-            fi
-        done
-    fi
-else
-    # It's a file! Let's make sure it's the right type:
-    if [ "${i##*.}" == "bin" ]; then
-        # And we probably want to skip the testing files
-        if [ "${i:9:8}" != "for-test" ]; then
-            # …
-            ./quantize "${i}" "${i:-4}-${qtype1}.bin" ${qtype1};
-            ./quantize "${i}" "${i:-4}-${qtype0}.bin" ${qtype0};
-            filedex+=( "${i:-4}-${qtype1}.bin" "${i:-4}-${qtype0}.bin" )
-        fi
+for i in `ls ./models | grep ^ggml-.*.bin | grep -v "\-q"`; do
+    m="models/$i"
+    if [ -f "$m" ]; then
+        if [ "${m##*.}" == "bin" ]; then
+            ./quantize "${m}" "${m::${#m}-4}-${qtype1}.bin" ${qtype1};
+            ./quantize "${m}" "${m::${#m}-4}-${qtype0}.bin" ${qtype0};
+            filedex+=( "${m::${#m}-4}-${qtype1}.bin" "${m::${#m}-4}-${qtype0}.bin" )
         fi
     fi
 done
ggml-cuda.cu

@@ -4476,6 +4476,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
     *dsti = __float2half(*xi);
 }
 
+static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+    const half * xi = (const half *) cxi;
+    half * dsti = (half *) cdsti;
+
+    *dsti = *xi;
+}
+
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
                                    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,

@@ -4729,6 +4736,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }
 
+static __global__ void im2col_f32_f16(
+        const float * x, half * dst,
+        int ofs0, int ofs1, int IW, int IH, int CHW,
+        int s0, int s1, int p0, int p1, int d0, int d1) {
+    const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
+    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+
+    const int offset_dst =
+        (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
+        (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = __float2half(0.0f);
+    } else {
+        const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
+        dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
+    }
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);

@@ -5618,6 +5644,16 @@ static void ggml_cpy_f32_f16_cuda(
         (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
 }
 
+static void ggml_cpy_f16_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
 static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);

@@ -5701,6 +5737,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }
 
+static void im2col_f32_f16_cuda(const float * x, half * dst,
+    int OH, int IW, int IH, int OW, int IC,
+    int KH, int KW, int N,  int ofs0, int ofs1,
+    int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
+    dim3 block_nums(IC, OH, OW);
+    dim3 block_dims(N,  KH, KW);
+    im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+}
+
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 256

@@ -6483,7 +6528,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
         src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
         to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
     }
-    const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) …
+    const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
     size_t dst_f16_as = 0;
     half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);

@@ -6659,6 +6704,45 @@ inline void ggml_cuda_op_alibi(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_im2col(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+    const int64_t N  = src1->ne[is_2D ? 3 : 2];
+    const int64_t IC = src1->ne[is_2D ? 2 : 1];
+    const int64_t IH = is_2D ? src1->ne[1] : 1;
+    const int64_t IW =         src1->ne[0];
+
+    const int64_t KH = is_2D ? src0->ne[1] : 1;
+    const int64_t KW =         src0->ne[0];
+
+    const int64_t OH = is_2D ? dst->ne[2] : 1;
+    const int64_t OW =         dst->ne[1];
+
+    const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
+    const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+
+    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
+        OH, IW, IH, OW, IC, KH, KW, N,
+        ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
+
+    (void) src0;
+    (void) src0_dd;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {

@@ -7549,6 +7633,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
                               ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+                              ne10, ne11, nb10, nb11, nb12, main_stream);
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));

@@ -7580,6 +7667,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }
 
+void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
+}
+
 static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;

@@ -7943,6 +8034,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
         case GGML_OP_ALIBI:
             func = ggml_cuda_alibi;
             break;
+        case GGML_OP_IM2COL:
+            func = ggml_cuda_im2col;
+            break;
         default:
             return false;
     }
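For orientation: im2col_f32_f16 unrolls every convolution patch into one row of an [N*OH*OW, IC*KH*KW] matrix (stored as F16), after which the convolution itself reduces to a single matrix multiplication. A CPU reference of the same indexing, as a sketch in C (simplified to N = 1 and float output; names mirror the kernel: s stride, p padding, d dilation):

    static void im2col_ref(const float * x, float * dst,
                           int IC, int IH, int IW,
                           int KH, int KW, int OH, int OW,
                           int s0, int s1, int p0, int p1, int d0, int d1) {
        const int CHW = IC * KH * KW;
        for (int ic = 0; ic < IC; ic++) {
            for (int oh = 0; oh < OH; oh++) {
                for (int ow = 0; ow < OW; ow++) {
                    for (int kh = 0; kh < KH; kh++) {
                        for (int kw = 0; kw < KW; kw++) {
                            const int iw = ow * s0 + kw * d0 - p0;
                            const int ih = oh * s1 + kh * d1 - p1;
                            /* one (oh, ow) patch per row, unrolled over IC*KH*KW */
                            const int row = oh * OW + ow;
                            const int col = ic * KH * KW + kh * KW + kw;
                            dst[row * CHW + col] =
                                (ih < 0 || ih >= IH || iw < 0 || iw >= IW)
                                    ? 0.0f                            /* zero padding */
                                    : x[ic * IH * IW + ih * IW + iw];
                        }
                    }
                }
            }
        }
    }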
ggml-metal.h

@@ -26,7 +26,7 @@
 #include <stdbool.h>
 
 // max memory buffers that can be mapped to the device
-#define GGML_METAL_MAX_BUFFERS 16
+#define GGML_METAL_MAX_BUFFERS 64
 #define GGML_METAL_MAX_COMMAND_BUFFERS 32
 
 struct ggml_tensor;
ggml-metal.m

@@ -86,6 +86,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);

@@ -114,6 +115,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rope_f32);
     GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(alibi_f32);
+    GGML_METAL_DECL_KERNEL(im2col_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);

@@ -287,6 +289,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(rms_norm);
     GGML_METAL_ADD_KERNEL(norm);
     GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
     GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
     GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);

@@ -317,6 +320,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(rope_f32);
     GGML_METAL_ADD_KERNEL(rope_f16);
     GGML_METAL_ADD_KERNEL(alibi_f32);
+    GGML_METAL_ADD_KERNEL(im2col_f16);
     GGML_METAL_ADD_KERNEL(cpy_f32_f16);
     GGML_METAL_ADD_KERNEL(cpy_f32_f32);
     GGML_METAL_ADD_KERNEL(cpy_f16_f16);

@@ -386,6 +390,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);

@@ -416,6 +421,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rope_f32);
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
+    GGML_METAL_DEL_KERNEL(im2col_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);

@@ -473,6 +479,10 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
 
     const int64_t tsize = ggml_nbytes(t);
 
+    if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
+        ctx = t->buffer->backend->context;
+    }
+
     // find the view that contains the tensor fully
     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;

@@ -1139,6 +1149,7 @@ void ggml_metal_graph_compute(
                     switch (src0t) {
                         case GGML_TYPE_F32:
                             {
+                                GGML_ASSERT(src1t == GGML_TYPE_F32);
                                 [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
                                 nrows = 4;
                             } break;

@@ -1146,13 +1157,18 @@ void ggml_metal_graph_compute(
                             {
                                 nth0 = 32;
                                 nth1 = 1;
-                                if (ne11 * ne12 < 4) {
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
-                                } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
-                                    nrows = ne11;
+                                if (src1t == GGML_TYPE_F32) {
+                                    if (ne11 * ne12 < 4) {
+                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
+                                    } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
+                                        nrows = ne11;
+                                    } else {
+                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                        nrows = 4;
+                                    }
                                 } else {
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
                                     nrows = 4;
                                 }
                             } break;

@@ -1464,6 +1480,58 @@ void ggml_metal_graph_compute(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
+            case GGML_OP_IM2COL:
+                {
+                    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+                    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+                    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+                    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+                    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+                    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+                    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+                    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+                    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+                    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+                    const int32_t N  = src1->ne[is_2D ? 3 : 2];
+                    const int32_t IC = src1->ne[is_2D ? 2 : 1];
+                    const int32_t IH = is_2D ? src1->ne[1] : 1;
+                    const int32_t IW =         src1->ne[0];
+
+                    const int32_t KH = is_2D ? src0->ne[1] : 1;
+                    const int32_t KW =         src0->ne[0];
+
+                    const int32_t OH = is_2D ? dst->ne[2] : 1;
+                    const int32_t OW =         dst->ne[1];
+
+                    const int32_t CHW = IC * KH * KW;
+
+                    const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+                    const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
+                        case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
+                        default: GGML_ASSERT(false);
+                    };
+
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
+                    [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
+                    [encoder setBytes:&IW   length:sizeof( int32_t) atIndex:4];
+                    [encoder setBytes:&IH   length:sizeof( int32_t) atIndex:5];
+                    [encoder setBytes:&CHW  length:sizeof( int32_t) atIndex:6];
+                    [encoder setBytes:&s0   length:sizeof( int32_t) atIndex:7];
+                    [encoder setBytes:&s1   length:sizeof( int32_t) atIndex:8];
+                    [encoder setBytes:&p0   length:sizeof( int32_t) atIndex:9];
+                    [encoder setBytes:&p1   length:sizeof( int32_t) atIndex:10];
+                    [encoder setBytes:&d0   length:sizeof( int32_t) atIndex:11];
+                    [encoder setBytes:&d1   length:sizeof( int32_t) atIndex:12];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+                } break;
             case GGML_OP_DUP:
             case GGML_OP_CPY:
             case GGML_OP_CONT:
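The mul_mv_f16_f16 pipeline selected above (its kernel body is in the ggml-metal.metal diff that follows) computes one dot product per output element: 32 SIMD lanes stride through the row four half values at a time, with a scalar tail for lengths not divisible by four, and simd_sum combines the lane partials. The same reduction restated on the CPU, as a sketch in C (plain float standing in for half):

    /* reference for the kernel's reduction order; n is the row length (ne00) */
    static float dot_mv_f16_f16_ref(const float * x, const float * y, int n) {
        float all_sum = 0.0f;
        for (int lane = 0; lane < 32; lane++) {      /* one iteration per SIMD lane */
            float sumf = 0.0f;
            for (int i = lane; i < n/4; i += 32) {   /* half4-wide chunks on the GPU */
                for (int k = 0; k < 4; ++k) {
                    sumf += x[4*i + k] * y[4*i + k];
                }
            }
            all_sum += sumf;                         /* simd_sum() on the GPU */
        }
        for (int i = 4*(n/4); i < n; ++i) {          /* scalar tail, done by lane 0 */
            all_sum += x[i] * y[i];
        }
        return all_sum;
    }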
ggml-metal.metal

@@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32(
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint …
+        uint  tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t rb = tgpig.y*N_F32_F32;

@@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32(
         }
     }
 }
 
+#define N_F16_F16 4
+
+kernel void kernel_mul_mv_f16_f16(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F16_F16;
+    const int64_t im = tgpig.z;
+
+    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (half) x[i] * (half) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const half4 * x4 = (device const half4 *)x;
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half  * y  = (device const half  *) (src1 + r1*nb11 + im*nb12);
+            device const half4 * y4 = (device const half4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
 kernel void kernel_mul_mv_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,

@@ -1229,6 +1302,39 @@ kernel void kernel_rope(
 template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
 template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
 
+kernel void kernel_im2col_f16(
+        device const float * x,
+        device        half * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
+    const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
+
+    const int32_t offset_dst =
+        (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+        (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = 0.0f;
+    } else {
+        const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
+        dst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device       half * dst,
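The ggml.c changes below delete the two-stage CONV ops: a convolution is now expressed in the compute graph as ggml_im2col followed by an ordinary ggml_mul_mat plus reshapes. A sketch of that composition for the 2-D case in C, assuming the ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, is_2D) signature this commit introduces; the reshape bookkeeping follows the ne[] comments in the diffs:

    /* kernel a: [OC, IC, KH, KW], input b: [N, IC, IH, IW] -> [N, OC, OH, OW] */
    struct ggml_tensor * conv_2d_via_im2col(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            int s0, int s1, int p0, int p1, int d0, int d1) {
        /* unrolled patches: [N, OH, OW, IC*KH*KW], stored as F16 */
        struct ggml_tensor * col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true);

        /* flatten patches and kernel, then one matrix multiplication */
        struct ggml_tensor * out = ggml_mul_mat(ctx,
                ggml_reshape_2d(ctx, col, col->ne[0], col->ne[3]*col->ne[2]*col->ne[1]),
                ggml_reshape_2d(ctx, a,   a->ne[0]*a->ne[1]*a->ne[2], a->ne[3]));

        /* back to [N, OC, OH, OW] */
        return ggml_reshape_4d(ctx, out, col->ne[1], col->ne[2], a->ne[3], col->ne[3]);
    }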
@@ -1634,13 +1634,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
| 1634 |
"ROPE_BACK",
|
| 1635 |
"ALIBI",
|
| 1636 |
"CLAMP",
|
| 1637 |
-
"CONV_1D",
|
| 1638 |
-
"CONV_1D_STAGE_0",
|
| 1639 |
-
"CONV_1D_STAGE_1",
|
| 1640 |
"CONV_TRANSPOSE_1D",
|
| 1641 |
-
"
|
| 1642 |
-
"CONV_2D_STAGE_0",
|
| 1643 |
-
"CONV_2D_STAGE_1",
|
| 1644 |
"CONV_TRANSPOSE_2D",
|
| 1645 |
"POOL_1D",
|
| 1646 |
"POOL_2D",
|
|
@@ -1671,7 +1666,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
| 1671 |
"CROSS_ENTROPY_LOSS_BACK",
|
| 1672 |
};
|
| 1673 |
|
| 1674 |
-
static_assert(GGML_OP_COUNT ==
|
| 1675 |
|
| 1676 |
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
| 1677 |
"none",
|
|
@@ -1721,13 +1716,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
| 1721 |
"rope_back(x)",
|
| 1722 |
"alibi(x)",
|
| 1723 |
"clamp(x)",
|
| 1724 |
-
"conv_1d(x)",
|
| 1725 |
-
"conv_1d_stage_0(x)",
|
| 1726 |
-
"conv_1d_stage_1(x)",
|
| 1727 |
"conv_transpose_1d(x)",
|
| 1728 |
-
"
|
| 1729 |
-
"conv_2d_stage_0(x)",
|
| 1730 |
-
"conv_2d_stage_1(x)",
|
| 1731 |
"conv_transpose_2d(x)",
|
| 1732 |
"pool_1d(x)",
|
| 1733 |
"pool_2d(x)",
|
|
@@ -1758,7 +1748,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
| 1758 |
"cross_entropy_loss_back(x,y)",
|
| 1759 |
};
|
| 1760 |
|
| 1761 |
-
static_assert(GGML_OP_COUNT ==
|
| 1762 |
|
| 1763 |
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
| 1764 |
|
|
@@ -1786,13 +1776,7 @@ static void ggml_setup_op_has_task_pass(void) {
|
|
| 1786 |
p[GGML_OP_GET_ROWS_BACK ] = true;
|
| 1787 |
p[GGML_OP_DIAG_MASK_INF ] = true;
|
| 1788 |
p[GGML_OP_DIAG_MASK_ZERO ] = true;
|
| 1789 |
-
p[GGML_OP_CONV_1D ] = true;
|
| 1790 |
-
p[GGML_OP_CONV_1D_STAGE_0 ] = true;
|
| 1791 |
-
p[GGML_OP_CONV_1D_STAGE_1 ] = true;
|
| 1792 |
p[GGML_OP_CONV_TRANSPOSE_1D ] = true;
|
| 1793 |
-
p[GGML_OP_CONV_2D ] = true;
|
| 1794 |
-
p[GGML_OP_CONV_2D_STAGE_0 ] = true;
|
| 1795 |
-
p[GGML_OP_CONV_2D_STAGE_1 ] = true;
|
| 1796 |
p[GGML_OP_CONV_TRANSPOSE_2D ] = true;
|
| 1797 |
p[GGML_OP_FLASH_ATTN_BACK ] = true;
|
| 1798 |
p[GGML_OP_CROSS_ENTROPY_LOSS ] = true;
|
|
@@ -5137,82 +5121,6 @@ static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p,
|
|
| 5137 |
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
|
| 5138 |
}
|
| 5139 |
|
| 5140 |
-
// im2col: [N, IC, IL] => [N, OL, IC*K]
|
| 5141 |
-
// a: [OC,IC, K]
|
| 5142 |
-
// b: [N, IC, IL]
|
| 5143 |
-
// result: [N, OL, IC*K]
|
| 5144 |
-
static struct ggml_tensor * ggml_conv_1d_stage_0(
|
| 5145 |
-
struct ggml_context * ctx,
|
| 5146 |
-
struct ggml_tensor * a,
|
| 5147 |
-
struct ggml_tensor * b,
|
| 5148 |
-
int s0,
|
| 5149 |
-
int p0,
|
| 5150 |
-
int d0) {
|
| 5151 |
-
GGML_ASSERT(a->ne[1] == b->ne[1]);
|
| 5152 |
-
bool is_node = false;
|
| 5153 |
-
|
| 5154 |
-
if (a->grad || b->grad) {
|
| 5155 |
-
GGML_ASSERT(false); // TODO: implement backward
|
| 5156 |
-
is_node = true;
|
| 5157 |
-
}
|
| 5158 |
-
|
| 5159 |
-
const int64_t OL = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
|
| 5160 |
-
|
| 5161 |
-
const int64_t ne[4] = {
|
| 5162 |
-
a->ne[1] * a->ne[0],
|
| 5163 |
-
OL,
|
| 5164 |
-
b->ne[2],
|
| 5165 |
-
1,
|
| 5166 |
-
};
|
| 5167 |
-
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
|
| 5168 |
-
|
| 5169 |
-
int32_t params[] = { s0, p0, d0 };
|
| 5170 |
-
ggml_set_op_params(result, params, sizeof(params));
|
| 5171 |
-
|
| 5172 |
-
result->op = GGML_OP_CONV_1D_STAGE_0;
|
| 5173 |
-
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
| 5174 |
-
result->src[0] = a;
|
| 5175 |
-
result->src[1] = b;
|
| 5176 |
-
|
| 5177 |
-
return result;
|
| 5178 |
-
}
|
| 5179 |
-
|
| 5180 |
-
// ggml_conv_1d_stage_1
|
| 5181 |
-
|
| 5182 |
-
// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
|
| 5183 |
-
// a: [OC, IC, K]
|
| 5184 |
-
// b: [N, OL, IC * K]
|
| 5185 |
-
// result: [N, OC, OL]
|
| 5186 |
-
static struct ggml_tensor * ggml_conv_1d_stage_1(
|
| 5187 |
-
struct ggml_context * ctx,
|
| 5188 |
-
struct ggml_tensor * a,
|
| 5189 |
-
struct ggml_tensor * b) {
|
| 5190 |
-
|
| 5191 |
-
bool is_node = false;
|
| 5192 |
-
|
| 5193 |
-
if (a->grad || b->grad) {
|
| 5194 |
-
GGML_ASSERT(false); // TODO: implement backward
|
| 5195 |
-
is_node = true;
|
| 5196 |
-
}
|
| 5197 |
-
|
| 5198 |
-
const int64_t ne[4] = {
|
| 5199 |
-
b->ne[1],
|
| 5200 |
-
a->ne[2],
|
| 5201 |
-
b->ne[2],
|
| 5202 |
-
1,
|
| 5203 |
-
};
|
| 5204 |
-
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
| 5205 |
-
|
| 5206 |
-
result->op = GGML_OP_CONV_1D_STAGE_1;
|
| 5207 |
-
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
| 5208 |
-
result->src[0] = a;
|
| 5209 |
-
result->src[1] = b;
|
| 5210 |
-
|
| 5211 |
-
return result;
|
| 5212 |
-
}
|
| 5213 |
-
|
| 5214 |
-
// ggml_conv_1d
|
| 5215 |
-
|
| 5216 |
GGML_API struct ggml_tensor * ggml_conv_1d(
|
| 5217 |
struct ggml_context * ctx,
|
| 5218 |
struct ggml_tensor * a,
|
|
@@ -5220,43 +5128,17 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
|
|
| 5220 |
int s0,
|
| 5221 |
int p0,
|
| 5222 |
int d0) {
|
| 5223 |
-
struct ggml_tensor *
|
| 5224 |
-
result = ggml_conv_1d_stage_1(ctx, a, result);
|
| 5225 |
-
return result;
|
| 5226 |
-
}
|
| 5227 |
|
| 5228 |
-
|
| 5229 |
-
|
| 5230 |
-
//
|
| 5231 |
-
//
|
| 5232 |
-
// int s0,
|
| 5233 |
-
// int p0,
|
| 5234 |
-
// int d0) {
|
| 5235 |
-
// GGML_ASSERT(ggml_is_matrix(b));
|
| 5236 |
-
// GGML_ASSERT(a->ne[1] == b->ne[1]);
|
| 5237 |
-
// bool is_node = false;
|
| 5238 |
|
| 5239 |
-
|
| 5240 |
-
// GGML_ASSERT(false); // TODO: implement backward
|
| 5241 |
-
// is_node = true;
|
| 5242 |
-
// }
|
| 5243 |
|
| 5244 |
-
|
| 5245 |
-
|
| 5246 |
-
// a->ne[2], 1, 1,
|
| 5247 |
-
// };
|
| 5248 |
-
// struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
|
| 5249 |
-
|
| 5250 |
-
// int32_t params[] = { s0, p0, d0 };
|
| 5251 |
-
// ggml_set_op_params(result, params, sizeof(params));
|
| 5252 |
-
|
| 5253 |
-
// result->op = GGML_OP_CONV_1D;
|
| 5254 |
-
// result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
| 5255 |
-
// result->src[0] = a;
|
| 5256 |
-
// result->src[1] = b;
|
| 5257 |
-
|
| 5258 |
-
// return result;
|
| 5259 |
-
// }
|
| 5260 |
|
| 5261 |
// ggml_conv_1d_ph
|
| 5262 |
|
|
@@ -5319,7 +5201,7 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
|
|
| 5319 |
// a: [OC,IC, KH, KW]
|
| 5320 |
// b: [N, IC, IH, IW]
|
| 5321 |
// result: [N, OH, OW, IC*KH*KW]
|
| 5322 |
-
|
| 5323 |
struct ggml_context * ctx,
|
| 5324 |
struct ggml_tensor * a,
|
| 5325 |
struct ggml_tensor * b,
|
|
@@ -5328,9 +5210,14 @@ static struct ggml_tensor * ggml_conv_2d_stage_0(
|
|
| 5328 |
int p0,
|
| 5329 |
int p1,
|
| 5330 |
int d0,
|
| 5331 |
-
int d1
|
|
|
|
| 5332 |
|
| 5333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5334 |
bool is_node = false;
|
| 5335 |
|
| 5336 |
if (a->grad || b->grad) {
|
|
@@ -5338,81 +5225,51 @@ static struct ggml_tensor * ggml_conv_2d_stage_0(
|
|
| 5338 |
is_node = true;
|
| 5339 |
}
|
| 5340 |
|
| 5341 |
-
const int64_t OH = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
|
| 5342 |
-
const int64_t OW =
|
| 5343 |
|
| 5344 |
const int64_t ne[4] = {
|
| 5345 |
-
a->ne[2] * a->ne[1] * a->ne[0],
|
| 5346 |
OW,
|
| 5347 |
-
OH,
|
| 5348 |
-
b->ne[3],
|
| 5349 |
};
|
| 5350 |
-
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
|
| 5351 |
|
| 5352 |
-
|
|
|
|
| 5353 |
ggml_set_op_params(result, params, sizeof(params));
|
| 5354 |
|
| 5355 |
-
result->op =
|
| 5356 |
-
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
| 5357 |
-
result->src[0] = a;
|
| 5358 |
-
result->src[1] = b;
|
| 5359 |
-
|
| 5360 |
-
return result;
|
| 5361 |
-
|
| 5362 |
-
}
|
| 5363 |
-
|
| 5364 |
-
// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
|
| 5365 |
-
// a: [OC, IC, KH, KW]
|
| 5366 |
-
// b: [N, OH, OW, IC * KH * KW]
|
| 5367 |
-
// result: [N, OC, OH, OW]
|
| 5368 |
-
static struct ggml_tensor * ggml_conv_2d_stage_1(
|
| 5369 |
-
struct ggml_context * ctx,
|
| 5370 |
-
struct ggml_tensor * a,
|
| 5371 |
-
struct ggml_tensor * b) {
|
| 5372 |
-
|
| 5373 |
-
bool is_node = false;
|
| 5374 |
-
|
| 5375 |
-
if (a->grad || b->grad) {
|
| 5376 |
-
GGML_ASSERT(false); // TODO: implement backward
|
| 5377 |
-
is_node = true;
|
| 5378 |
-
}
|
| 5379 |
-
|
| 5380 |
-
const int64_t ne[4] = {
|
| 5381 |
-
b->ne[1],
|
| 5382 |
-
b->ne[2],
|
| 5383 |
-
a->ne[3],
|
| 5384 |
-
b->ne[3],
|
| 5385 |
-
};
|
| 5386 |
-
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
| 5387 |
-
|
| 5388 |
-
result->op = GGML_OP_CONV_2D_STAGE_1;
|
| 5389 |
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
| 5390 |
result->src[0] = a;
|
| 5391 |
result->src[1] = b;
|
| 5392 |
|
| 5393 |
return result;
|
| 5394 |
-
|
| 5395 |
}
|
| 5396 |
|
| 5397 |
// a: [OC,IC, KH, KW]
|
| 5398 |
// b: [N, IC, IH, IW]
|
| 5399 |
// result: [N, OC, OH, OW]
|
| 5400 |
struct ggml_tensor * ggml_conv_2d(
|
| 5401 |
-
|
| 5402 |
-
|
| 5403 |
-
|
| 5404 |
-
|
| 5405 |
-
|
| 5406 |
-
|
| 5407 |
-
|
| 5408 |
-
|
| 5409 |
-
|
|
|
|
| 5410 |
|
| 5411 |
-
struct ggml_tensor * result =
|
| 5412 |
-
|
|
|
|
|
|
|
| 5413 |
|
| 5414 |
-
|
| 5415 |
|
|
|
|
| 5416 |
}
|
| 5417 |
|
| 5418 |
// ggml_conv_2d_sk_p0
|
|
@@ -9507,6 +9364,8 @@ static bool ggml_compute_forward_mul_mat_use_blas(
|
|
| 9507 |
// TODO: find the optimal values for these
|
| 9508 |
if (ggml_is_contiguous(src0) &&
|
| 9509 |
ggml_is_contiguous(src1) &&
|
|
|
|
|
|
|
| 9510 |
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
|
| 9511 |
|
| 9512 |
/*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
|
|
@@ -9517,6 +9376,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
|
|
| 9517 |
}
|
| 9518 |
#endif
|
| 9519 |
|
|
|
|
| 9520 |
static void ggml_compute_forward_mul_mat(
|
| 9521 |
const struct ggml_compute_params * params,
|
| 9522 |
const struct ggml_tensor * src0,
|
|
@@ -9545,7 +9405,7 @@ static void ggml_compute_forward_mul_mat(
|
|
| 9545 |
|
| 9546 |
// we don't support permuted src0 or src1
|
| 9547 |
GGML_ASSERT(nb00 == ggml_type_size(type));
|
| 9548 |
-
GGML_ASSERT(nb10 ==
|
| 9549 |
|
| 9550 |
// dst cannot be transposed or permuted
|
| 9551 |
GGML_ASSERT(nb0 == sizeof(float));
|
|
@@ -11637,9 +11497,9 @@ static void ggml_compute_forward_rope_back(
|
|
| 11637 |
}
|
| 11638 |
}
|
| 11639 |
|
| 11640 |
-
//
|
| 11641 |
|
| 11642 |
-
static void
|
| 11643 |
const struct ggml_compute_params * params,
|
| 11644 |
const struct ggml_tensor * src0,
|
| 11645 |
const struct ggml_tensor * src1,
|
|
@@ -11656,14 +11516,7 @@ static void ggml_compute_forward_conv_1d_f16_f32(
|
|
| 11656 |
const int ith = params->ith;
|
| 11657 |
const int nth = params->nth;
|
| 11658 |
|
| 11659 |
-
const int nk = ne00;
|
| 11660 |
-
|
| 11661 |
-
// size of the convolution row - the kernel size unrolled across all input channels
|
| 11662 |
-
const int ew0 = nk*ne01;
|
| 11663 |
-
|
| 11664 |
-
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
| 11665 |
-
const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
|
| 11666 |
-
const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
|
| 11667 |
|
| 11668 |
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
| 11669 |
GGML_ASSERT(nb10 == sizeof(float));
|
|
@@ -11671,23 +11524,37 @@ static void ggml_compute_forward_conv_1d_f16_f32(
|
|
| 11671 |
if (params->type == GGML_TASK_INIT) {
|
| 11672 |
memset(params->wdata, 0, params->wsize);
|
| 11673 |
|
| 11674 |
-
|
|
|
|
|
|
|
| 11675 |
|
| 11676 |
-
|
| 11677 |
-
|
| 11678 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11679 |
|
| 11680 |
-
|
| 11681 |
-
|
| 11682 |
-
|
|
|
|
| 11683 |
|
| 11684 |
-
|
| 11685 |
-
|
| 11686 |
-
|
|
|
|
| 11687 |
}
|
| 11688 |
}
|
| 11689 |
}
|
| 11690 |
|
|
|
|
|
|
|
|
|
|
| 11691 |
return;
|
| 11692 |
}
|
| 11693 |
|
|
@@ -11695,8 +11562,10 @@ static void ggml_compute_forward_conv_1d_f16_f32(
|
|
| 11695 |
return;
|
| 11696 |
}
|
| 11697 |
|
|
|
|
|
|
|
| 11698 |
// total rows in dst
|
| 11699 |
-
const int nr =
|
| 11700 |
|
| 11701 |
// rows per thread
|
| 11702 |
const int dr = (nr + nth - 1)/nth;
|
|
@@ -11705,22 +11574,26 @@ static void ggml_compute_forward_conv_1d_f16_f32(
|
|
| 11705 |
const int ir0 = dr*ith;
|
| 11706 |
const int ir1 = MIN(ir0 + dr, nr);
|
| 11707 |
|
| 11708 |
-
ggml_fp16_t * const wdata
|
| 11709 |
-
|
| 11710 |
-
for (int i2 = 0; i2 < ne2; i2++) {
|
| 11711 |
-
for (int i1 = ir0; i1 < ir1; i1++) {
|
| 11712 |
-
float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
|
| 11713 |
|
| 11714 |
-
|
| 11715 |
-
|
| 11716 |
-
|
| 11717 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11718 |
}
|
| 11719 |
}
|
| 11720 |
}
|
| 11721 |
}
|
| 11722 |
|
| 11723 |
-
static void
|
| 11724 |
const struct ggml_compute_params * params,
|
| 11725 |
const struct ggml_tensor * src0,
|
| 11726 |
const struct ggml_tensor * src1,
|
|
@@ -11737,13 +11610,7 @@ static void ggml_compute_forward_conv_1d_f32(
|
|
| 11737 |
const int ith = params->ith;
|
| 11738 |
const int nth = params->nth;
|
| 11739 |
|
| 11740 |
-
const int nk = ne00;
|
| 11741 |
-
|
| 11742 |
-
const int ew0 = nk*ne01;
|
| 11743 |
-
|
| 11744 |
-
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
| 11745 |
-
const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
|
| 11746 |
-
const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
|
| 11747 |
|
| 11748 |
GGML_ASSERT(nb00 == sizeof(float));
|
| 11749 |
GGML_ASSERT(nb10 == sizeof(float));
|
|
@@ -11751,23 +11618,37 @@ static void ggml_compute_forward_conv_1d_f32(
|
|
| 11751 |
if (params->type == GGML_TASK_INIT) {
|
| 11752 |
memset(params->wdata, 0, params->wsize);
|
| 11753 |
|
| 11754 |
-
|
|
|
|
|
|
|
| 11755 |
|
| 11756 |
-
|
| 11757 |
-
|
| 11758 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11759 |
|
| 11760 |
-
|
| 11761 |
-
|
| 11762 |
-
|
|
|
|
| 11763 |
|
| 11764 |
-
|
| 11765 |
-
|
| 11766 |
-
|
|
|
|
| 11767 |
}
|
| 11768 |
}
|
| 11769 |
}
|
| 11770 |
|
|
|
|
|
|
|
|
|
|
| 11771 |
return;
|
| 11772 |
}
|
| 11773 |
|
|
@@ -11775,8 +11656,10 @@ static void ggml_compute_forward_conv_1d_f32(
|
|
| 11775 |
return;
|
| 11776 |
}
|
| 11777 |
|
|
|
|
|
|
|
| 11778 |
// total rows in dst
|
| 11779 |
-
const int nr =
|
| 11780 |
|
| 11781 |
// rows per thread
|
| 11782 |
const int dr = (nr + nth - 1)/nth;
|
|
@@ -11785,94 +11668,50 @@ static void ggml_compute_forward_conv_1d_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    float * const wdata
-
-    for (int i2 = 0; i2 < ne2; i2++) {
-        for (int i1 = ir0; i1 < ir1; i1++) {
-            float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
 
            }
        }
    }
 }
 
-        // patch range for this thread
-        m0 = dp*ith;
-        m1 = MIN(m0 + dp, np);
-    } else {
-        m0 = 0;
-        m1 = m;
-
-        // total patches in dst
-        const int np = n;
-
-        // patches per thread
-        const int dp = (np + nth - 1)/nth;
-
-        // patch range for this thread
-        n0 = dp*ith;
-        n1 = MIN(n0 + dp, np);
-    }
-
-    // block-tiling attempt
-    int64_t blck_n = 16;
-    int64_t blck_m = 16;
-
-    // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB
-    // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(ggml_fp16_t) * K);
-    // if (blck_size > 0) {
-    //     blck_0 = 4;
-    //     blck_1 = blck_size / blck_0;
-    //     if (blck_1 < 0) {
-    //         blck_1 = 1;
-    //     }
-    //     // blck_0 = (int64_t)sqrt(blck_size);
-    //     // blck_1 = blck_0;
-    // }
-    // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1);
-
-    for (int j = n0; j < n1; j+=blck_n) {
-        for (int i = m0; i < m1; i+=blck_m) {
-            // printf("i j k => %d %d %d\n", i, j, K);
-            for (int ii = i; ii < i + blck_m && ii < m1; ii++) {
-                for (int jj = j; jj < j + blck_n && jj < n1; jj++) {
-                    ggml_vec_dot_f16(k,
-                                    C + ii*n + jj,
-                                    A + ii * k,
-                                    B + jj * k);
-                }
-            }
-        }
    }
 }
 
-// src0: kernel [OC, IC,
-// src1:
-// dst: result [N,
-static void ggml_compute_forward_conv_1d_stage_0_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -11886,26 +11725,35 @@ static void ggml_compute_forward_conv_1d_stage_0_f32(
 
     GGML_TENSOR_BINARY_OP_LOCALS;
 
-    const
-    const
-    const
-
-    const
-
-    const
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const
-    const
-    const
 
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
     if (params->type == GGML_TASK_INIT) {
-        memset(dst->data, 0, ggml_nbytes(dst));
         return;
     }
 
@@ -11913,23 +11761,30 @@ static void ggml_compute_forward_conv_1d_stage_0_f32(
         return;
     }
 
-    // im2col: [N, IC,
     {
         ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
 
         for (int64_t in = 0; in < N; in++) {
-            for (int64_t
-                for (int64_t
 
             }
         }
     }
@@ -11938,627 +11793,7 @@ static void ggml_compute_forward_conv_1d_stage_0_f32(
     }
 }
 
-
-// src0: [OC, IC, K]
-// src1: [N, OL, IC * K]
-// result: [N, OC, OL]
-static void ggml_compute_forward_conv_1d_stage_1_f16(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F16);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    if (params->type == GGML_TASK_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb0 == sizeof(float));
-
-    const int N = ne12;
-    const int OL = ne11;
-
-    const int OC = ne02;
-    const int IC = ne01;
-    const int K = ne00;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    int64_t m = OC;
-    int64_t n = OL;
-    int64_t k = IC * K;
-
-    // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
-    for (int i = 0; i < N; i++) {
-        ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
-        ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
-        float * C = (float *)dst->data + i * m * n; // [m, n]
-
-        gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
-    }
-}
-
-static void ggml_compute_forward_conv_1d(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch(src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_conv_1d_f32(params, src0, src1, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-static void ggml_compute_forward_conv_1d_stage_0(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch(src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-static void ggml_compute_forward_conv_1d_stage_1(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch(src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-// ggml_compute_forward_conv_transpose_1d
-
-static void ggml_compute_forward_conv_transpose_1d_f16_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nk = ne00*ne01*ne02;
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (params->type == GGML_TASK_INIT) {
-        memset(params->wdata, 0, params->wsize);
-
-        // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
-                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        dst_data[i00*ne02 + i02] = src[i00];
-                    }
-                }
-            }
-        }
-
-        // permute source data (src1) from (L x Cin) to (Cin x L)
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
-            ggml_fp16_t * dst_data = wdata;
-
-            for (int64_t i11 = 0; i11 < ne11; i11++) {
-                const float * const src = (float *)((char *) src1->data + i11*nb11);
-                for (int64_t i10 = 0; i10 < ne10; i10++) {
-                    dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
-                }
-            }
-        }
-
-        // need to zero dst since we are accumulating into it
-        memset(dst->data, 0, ggml_nbytes(dst));
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-
-    // total rows in dst
-    const int nr = ne1;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;
-    ggml_fp16_t * const wdata_src = wdata + nk;
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float * dst_data = (float *)((char *) dst->data + i1*nb1);
-        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
-        for (int i10 = 0; i10 < ne10; i10++) {
-            const int i1n = i10*ne11;
-            for (int i00 = 0; i00 < ne00; i00++) {
-                float v = 0;
-                ggml_vec_dot_f16(ne02, &v,
-                        (ggml_fp16_t *)    wdata_src + i1n,
-                        (ggml_fp16_t *) wdata_kernel + i00*ne02);
-                dst_data[i10*s0 + i00] += v;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_conv_transpose_1d_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nk = ne00*ne01*ne02;
-
-    GGML_ASSERT(nb00 == sizeof(float));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (params->type == GGML_TASK_INIT) {
-        memset(params->wdata, 0, params->wsize);
-
-        // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
-        {
-            float * const wdata = (float *) params->wdata + 0;
-
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
-                    float * dst_data = wdata + i01*ne00*ne02;
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        dst_data[i00*ne02 + i02] = src[i00];
-                    }
-                }
-            }
-        }
-
-        // prepare source data (src1)
-        {
-            float * const wdata = (float *) params->wdata + nk;
-            float * dst_data = wdata;
-
-            for (int64_t i11 = 0; i11 < ne11; i11++) {
-                const float * const src = (float *)((char *) src1->data + i11*nb11);
-                for (int64_t i10 = 0; i10 < ne10; i10++) {
-                    dst_data[i10*ne11 + i11] = src[i10];
-                }
-            }
-        }
-
-        // need to zero dst since we are accumulating into it
-        memset(dst->data, 0, ggml_nbytes(dst));
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-
-    // total rows in dst
-    const int nr = ne1;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * const wdata     = (float *) params->wdata + 0;
-    float * const wdata_src = wdata + nk;
-
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float * dst_data = (float *)((char *) dst->data + i1*nb1);
-        float * wdata_kernel = wdata + i1*ne02*ne00;
-        for (int i10 = 0; i10 < ne10; i10++) {
-            const int i1n = i10*ne11;
-            for (int i00 = 0; i00 < ne00; i00++) {
-                float v = 0;
-                ggml_vec_dot_f32(ne02, &v,
-                        wdata_src + i1n,
-                        wdata_kernel + i00*ne02);
-                dst_data[i10*s0 + i00] += v;
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_conv_transpose_1d(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-// ggml_compute_forward_conv_2d
-
-// src0: kernel [OC, IC, KH, KW]
-// src1: image [N, IC, IH, IW]
-// dst: result [N, OH, OW, IC*KH*KW]
-static void ggml_compute_forward_conv_2d_stage_0_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F16);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const int64_t N = ne13;
-    const int64_t IC = ne12;
-    const int64_t IH = ne11;
-    const int64_t IW = ne10;
-
-    // const int64_t OC = ne03;
-    // const int64_t IC = ne02;
-    const int64_t KH = ne01;
-    const int64_t KW = ne00;
-
-    const int64_t OH = ne2;
-    const int64_t OW = ne1;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (params->type == GGML_TASK_INIT) {
-        memset(dst->data, 0, ggml_nbytes(dst));
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
-    {
-        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
-
-        for (int64_t in = 0; in < N; in++) {
-            for (int64_t ioh = 0; ioh < OH; ioh++) {
-                for (int64_t iow = 0; iow < OW; iow++) {
-                    for (int64_t iic = ith; iic < IC; iic+=nth) {
-
-                        // micro kernel
-                        ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                        const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
-
-                        for (int64_t ikh = 0; ikh < KH; ikh++) {
-                            for (int64_t ikw = 0; ikw < KW; ikw++) {
-                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
-                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
-
-                                if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
-                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
-// src0: [OC, IC, KH, KW]
-// src1: [N, OH, OW, IC * KH * KW]
-// result: [N, OC, OH, OW]
-static void ggml_compute_forward_conv_2d_stage_1_f16(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F16);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    if (params->type == GGML_TASK_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb0 == sizeof(float));
-
-    const int N = ne13;
-    const int OH = ne12;
-    const int OW = ne11;
-
-    const int OC = ne03;
-    const int IC = ne02;
-    const int KH = ne01;
-    const int KW = ne00;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    int64_t m = OC;
-    int64_t n = OH * OW;
-    int64_t k = IC * KH * KW;
-
-    // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
-    for (int i = 0; i < N; i++) {
-        ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
-        ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
-        float * C = (float *)dst->data + i * m * n; // [m, n]
-
-        gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
-    }
-}
-
-static void ggml_compute_forward_conv_2d_f16_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    // src1: image [N, IC, IH, IW]
-    // src0: kernel [OC, IC, KH, KW]
-    // dst: result [N, OC, OH, OW]
-    // ne12: IC
-    // ne0: OW
-    // ne1: OH
-    // nk0: KW
-    // nk1: KH
-    // ne13: N
-
-    const int N = ne13;
-    const int IC = ne12;
-    const int IH = ne11;
-    const int IW = ne10;
-
-    const int OC = ne03;
-    // const int IC = ne02;
-    const int KH = ne01;
-    const int KW = ne00;
-
-    const int OH = ne1;
-    const int OW = ne0;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    // const int nk0 = ne00;
-    // const int nk1 = ne01;
-
-    // size of the convolution row - the kernel size unrolled across all channels
-    // const int ew0 = nk0*nk1*ne02;
-    // ew0: IC*KH*KW
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (params->type == GGML_TASK_INIT) {
-        memset(params->wdata, 0, params->wsize);
-
-        // prepare source data (src1)
-        // im2col: [N, IC, IH, IW] => [N*OH*OW, IC*KH*KW]
-
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-
-            for (int in = 0; in < N; in++) {
-                for (int iic = 0; iic < IC; iic++) {
-                    for (int ioh = 0; ioh < OH; ioh++) {
-                        for (int iow = 0; iow < OW; iow++) {
-
-                            // micro kernel
-                            ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                            const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
-
-                            for (int ikh = 0; ikh < KH; ikh++) {
-                                for (int ikw = 0; ikw < KW; ikw++) {
-                                    const int iiw = iow*s0 + ikw*d0 - p0;
-                                    const int iih = ioh*s1 + ikh*d1 - p1;
-
-                                    if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
-                                        dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-    // wdata: [N*OH*OW, IC*KH*KW]
-    // dst: result [N, OC, OH, OW]
-    // src0: kernel [OC, IC, KH, KW]
-
-    int64_t m = OC;
-    int64_t n = OH * OW;
-    int64_t k = IC * KH * KW;
-
-    // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
-    for (int i = 0; i < N; i++) {
-        ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
-        ggml_fp16_t * B = (ggml_fp16_t *)wdata + i * m * k; // [n, k]
-        float * C = (float *)dst->data + i * m * n; // [m * k]
-
-        gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
-    }
-}
-
-static void ggml_compute_forward_conv_2d(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                //ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
-                GGML_ASSERT(false);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-static void ggml_compute_forward_conv_2d_stage_0(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_2d_stage_0_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                GGML_ASSERT(false);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-static void ggml_compute_forward_conv_2d_stage_1(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -12566,7 +11801,7 @@ static void ggml_compute_forward_conv_2d_stage_1(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_conv_2d_stage_1_f16(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
@@ -14783,33 +14018,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_clamp(params, tensor->src[0], tensor);
             } break;
-        case GGML_OP_CONV_1D:
-            {
-                ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case GGML_OP_CONV_1D_STAGE_0:
-            {
-                ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case GGML_OP_CONV_1D_STAGE_1:
-            {
-                ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor);
             } break;
-        case GGML_OP_CONV_2D:
-            {
-                ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case GGML_OP_CONV_2D_STAGE_0:
-            {
-                ggml_compute_forward_conv_2d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case GGML_OP_CONV_2D_STAGE_1:
             {
-                ggml_compute_forward_conv_2d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
@@ -15780,31 +14995,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_CONV_1D:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_CONV_1D_STAGE_0:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_CONV_1D_STAGE_1:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_CONV_2D:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_CONV_2D_STAGE_0:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_CONV_2D_STAGE_1:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
@@ -16533,31 +15728,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = 1; //TODO
             } break;
-        case GGML_OP_CONV_1D:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_CONV_1D_STAGE_0:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_CONV_1D_STAGE_1:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 n_tasks = n_threads;
             } break;
-        case GGML_OP_CONV_2D:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_CONV_2D_STAGE_0:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_CONV_2D_STAGE_1:
             {
                 n_tasks = n_threads;
             } break;
@@ -16642,6 +15817,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         default:
             {
                 GGML_ASSERT(false);
             } break;
     }
@@ -16844,38 +16020,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                 }
             } break;
-        case GGML_OP_CONV_1D:
-            {
-                GGML_ASSERT(node->src[0]->ne[3] == 1);
-                GGML_ASSERT(node->src[1]->ne[2] == 1);
-                GGML_ASSERT(node->src[1]->ne[3] == 1);
-
-                const int64_t ne00 = node->src[0]->ne[0];
-                const int64_t ne01 = node->src[0]->ne[1];
-                const int64_t ne02 = node->src[0]->ne[2];
-
-                const int64_t ne10 = node->src[1]->ne[0];
-                const int64_t ne11 = node->src[1]->ne[1];
-
-                const int64_t ne0 = node->ne[0];
-                const int64_t ne1 = node->ne[1];
-                const int64_t nk = ne00;
-                const int64_t ew0 = nk * ne01;
-
-                UNUSED(ne02);
-                UNUSED(ne10);
-                UNUSED(ne11);
-
-                if (node->src[0]->type == GGML_TYPE_F16 &&
-                    node->src[1]->type == GGML_TYPE_F32) {
-                    cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
-                } else if (node->src[0]->type == GGML_TYPE_F32 &&
-                           node->src[1]->type == GGML_TYPE_F32) {
-                    cur = sizeof(float)*(ne0*ne1*ew0);
-                } else {
-                    GGML_ASSERT(false);
-                }
-            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 GGML_ASSERT(node->src[0]->ne[3] == 1);
@@ -16901,37 +16045,9 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     GGML_ASSERT(false);
                 }
             } break;
-        case GGML_OP_CONV_2D:
             {
-                const int64_t ne01 = node->src[0]->ne[1]; // H
-                const int64_t ne02 = node->src[0]->ne[2]; // C
-                const int64_t ne03 = node->src[0]->ne[3]; // N
-
-                const int64_t ne10 = node->src[1]->ne[0]; // W
-                const int64_t ne11 = node->src[1]->ne[1]; // H
-                const int64_t ne12 = node->src[1]->ne[2]; // C
-
-                const int64_t ne0 = node->ne[0];
-                const int64_t ne1 = node->ne[1];
-                const int64_t ne2 = node->ne[2];
-                const int64_t ne3 = node->ne[3];
-                const int64_t nk = ne00*ne01;
-                const int64_t ew0 = nk * ne02;
-
-                UNUSED(ne03);
-                UNUSED(ne2);
-
-                if (node->src[0]->type == GGML_TYPE_F16 &&
-                    node->src[1]->type == GGML_TYPE_F32) {
-                    // im2col: [N*OH*OW, IC*KH*KW]
-                    cur = sizeof(ggml_fp16_t)*(ne3*ne0*ne1*ew0);
-                } else if (node->src[0]->type == GGML_TYPE_F32 &&
-                           node->src[1]->type == GGML_TYPE_F32) {
-                    cur = sizeof(float)* (ne10*ne11*ne12);
-                } else {
-                    GGML_ASSERT(false);
-                }
            } break;
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
     "ROPE_BACK",
     "ALIBI",
     "CLAMP",
     "CONV_TRANSPOSE_1D",
+    "IM2COL",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
     "POOL_2D",
 
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
+static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
 
     "rope_back(x)",
     "alibi(x)",
     "clamp(x)",
     "conv_transpose_1d(x)",
+    "im2col(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
     "pool_2d(x)",
 
     "cross_entropy_loss_back(x,y)",
 };
 
+static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
     p[GGML_OP_GET_ROWS_BACK ] = true;
     p[GGML_OP_DIAG_MASK_INF ] = true;
     p[GGML_OP_DIAG_MASK_ZERO ] = true;
     p[GGML_OP_CONV_TRANSPOSE_1D ] = true;
     p[GGML_OP_CONV_TRANSPOSE_2D ] = true;
     p[GGML_OP_FLASH_ATTN_BACK ] = true;
     p[GGML_OP_CROSS_ENTROPY_LOSS ] = true;
 
     return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
 }
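The return expression above is the standard output-length formula for a strided, padded, dilated convolution. A standalone C sketch of the same formula with made-up sizes (the helper name here is hypothetical; in ggml.c the equivalent is the static ggml_calc_conv_output_size):

#include <stdint.h>
#include <stdio.h>

// same formula as above: output length for input size ins,
// kernel size ks, stride s, padding p, dilation d
static int64_t conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
    return (ins + 2*p - d*(ks - 1) - 1)/s + 1;
}

int main(void) {
    // hypothetical whisper-like sizes: 3000 input frames, kernel 3,
    // stride 1, padding 1, dilation 1 -> 3000 output frames
    printf("%lld\n", (long long) conv_output_size(3000, 3, 1, 1, 1)); // 3000
    // stride 2 halves the length -> 1500
    printf("%lld\n", (long long) conv_output_size(3000, 3, 2, 1, 1)); // 1500
    return 0;
}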
 GGML_API struct ggml_tensor * ggml_conv_1d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         int                   s0,
         int                   p0,
         int                   d0) {
+    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
 
+    struct ggml_tensor * result =
+        ggml_mul_mat(ctx,
+                ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
+                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2]));                    // [OC,IC, K] => [OC, IC * K]
 
+    result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
 
+    return result;
+}
 
 // ggml_conv_1d_ph
 
 // a: [OC,IC, KH, KW]
 // b: [N, IC, IH, IW]
 // result: [N, OH, OW, IC*KH*KW]
+struct ggml_tensor * ggml_im2col(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         int                   s0,
         int                   s1,
         int                   p0,
         int                   p1,
         int                   d0,
+        int                   d1,
+        bool                  is_2D) {
 
+    if(is_2D) {
+        GGML_ASSERT(a->ne[2] == b->ne[2]);
+    } else {
+        GGML_ASSERT(a->ne[1] == b->ne[1]);
+    }
     bool is_node = false;
 
     if (a->grad || b->grad) {
         is_node = true;
     }
 
+    const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
+    const int64_t OW =         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
 
     const int64_t ne[4] = {
+        is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
         OW,
+        is_2D ? OH : b->ne[2],
+        is_2D ? b->ne[3] : 1,
     };
 
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
+    int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
     ggml_set_op_params(result, params, sizeof(params));
 
+    result->op = GGML_OP_IM2COL;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
 
     return result;
 }
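The ne[4] initializer above packs two layouts into one op: [N, OH, OW, IC*KH*KW] when is_2D is set, [1, N, OL, IC*K] otherwise. A small standalone sketch (plain arrays instead of ggml tensors, hypothetical sizes) that reproduces the 2D branch of that shape computation:

#include <stdint.h>
#include <stdio.h>

// sketch of the dst shape chosen by ggml_im2col above (2D case);
// a_ne/b_ne mirror the kernel/image ne[] fields
int main(void) {
    // kernel a: ne = { KW, KH, IC, OC }, image b: ne = { IW, IH, IC, N }
    const int64_t a_ne[4] = { 3, 3, 8, 16 };
    const int64_t b_ne[4] = { 32, 32, 8, 2 };

    const int s0 = 1, s1 = 1, p0 = 1, p1 = 1, d0 = 1, d1 = 1;

    const int64_t OH = (b_ne[1] + 2*p1 - d1*(a_ne[1] - 1) - 1)/s1 + 1; // 32
    const int64_t OW = (b_ne[0] + 2*p0 - d0*(a_ne[0] - 1) - 1)/s0 + 1; // 32

    // is_2D = true branch: dst is [N, OH, OW, IC*KH*KW]
    const int64_t ne[4] = { a_ne[2]*a_ne[1]*a_ne[0], OW, OH, b_ne[3] };

    printf("im2col dst ne = [%lld, %lld, %lld, %lld]\n", // [72, 32, 32, 2]
           (long long) ne[0], (long long) ne[1], (long long) ne[2], (long long) ne[3]);
    // for is_2D = false the shape collapses to [IC*K, OL, N, 1] instead
    return 0;
}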
 // a: [OC,IC, KH, KW]
 // b: [N, IC, IH, IW]
 // result: [N, OC, OH, OW]
 struct ggml_tensor * ggml_conv_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   s1,
+        int                   p0,
+        int                   p1,
+        int                   d0,
+        int                   d1) {
+    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
 
+    struct ggml_tensor * result =
+        ggml_mul_mat(ctx,
+                ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
+                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3]));                       // [OC,IC, KH, KW] => [OC, IC * KH * KW]
 
+    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW]
 
+    return result;
 }
 
 // ggml_conv_2d_sk_p0
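A sketch of how a caller might use the new path (shapes and values are hypothetical; error handling is omitted):

#include "ggml.h"

// building a 2D convolution through the new im2col-based lowering
static struct ggml_tensor * build_conv2d_example(struct ggml_context * ctx) {
    // kernel [OC=16, IC=8, KH=3, KW=3] as F16, image [N=1, IC=8, IH=32, IW=32] as F32
    struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,  3,  3, 8, 16);
    struct ggml_tensor * image  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 32, 32, 8,  1);

    // stride 1, padding 1, dilation 1 -> result [N=1, OC=16, OH=32, OW=32];
    // internally this is ggml_im2col followed by ggml_mul_mat, so it runs on
    // any backend that implements those two ops
    return ggml_conv_2d(ctx, kernel, image, 1, 1, 1, 1, 1, 1);
}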
     // TODO: find the optimal values for these
     if (ggml_is_contiguous(src0) &&
         ggml_is_contiguous(src1) &&
+        src0->type == GGML_TYPE_F32 &&
+        src1->type == GGML_TYPE_F32 &&
         (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
 
         /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
 
     }
 #endif
 
+
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == ggml_type_size(type));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
     }
 }
 
+// ggml_compute_forward_conv_transpose_1d
 
+static void ggml_compute_forward_conv_transpose_1d_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
 
     const int ith = params->ith;
     const int nth = params->nth;
 
+    const int nk = ne00*ne01*ne02;
 
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
     if (params->type == GGML_TASK_INIT) {
         memset(params->wdata, 0, params->wsize);
 
+        // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
+                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        dst_data[i00*ne02 + i02] = src[i00];
+                    }
+                }
+            }
+        }
 
+        // permute source data (src1) from (L x Cin) to (Cin x L)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
+            ggml_fp16_t * dst_data = wdata;
 
+            for (int64_t i11 = 0; i11 < ne11; i11++) {
+                const float * const src = (float *)((char *) src1->data + i11*nb11);
+                for (int64_t i10 = 0; i10 < ne10; i10++) {
+                    dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
                 }
             }
         }
 
+        // need to zero dst since we are accumulating into it
+        memset(dst->data, 0, ggml_nbytes(dst));
+
         return;
     }
 
         return;
     }
 
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+
     // total rows in dst
+    const int nr = ne1;
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;
+    ggml_fp16_t * const wdata_src = wdata + nk;
 
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dst_data = (float *)((char *) dst->data + i1*nb1);
+        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
+        for (int i10 = 0; i10 < ne10; i10++) {
+            const int i1n = i10*ne11;
+            for (int i00 = 0; i00 < ne00; i00++) {
+                float v = 0;
+                ggml_vec_dot_f16(ne02, &v,
+                        (ggml_fp16_t *)    wdata_src + i1n,
+                        (ggml_fp16_t *) wdata_kernel + i00*ne02);
+                dst_data[i10*s0 + i00] += v;
             }
         }
     }
 }
 
+static void ggml_compute_forward_conv_transpose_1d_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
 
     const int ith = params->ith;
     const int nth = params->nth;
 
+    const int nk = ne00*ne01*ne02;
 
     GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb10 == sizeof(float));
 
     if (params->type == GGML_TASK_INIT) {
         memset(params->wdata, 0, params->wsize);
 
+        // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
+        {
+            float * const wdata = (float *) params->wdata + 0;
 
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
+                    float * dst_data = wdata + i01*ne00*ne02;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        dst_data[i00*ne02 + i02] = src[i00];
+                    }
+                }
+            }
+        }
 
+        // prepare source data (src1)
+        {
+            float * const wdata = (float *) params->wdata + nk;
+            float * dst_data = wdata;
 
+            for (int64_t i11 = 0; i11 < ne11; i11++) {
+                const float * const src = (float *)((char *) src1->data + i11*nb11);
+                for (int64_t i10 = 0; i10 < ne10; i10++) {
+                    dst_data[i10*ne11 + i11] = src[i10];
                 }
             }
         }
 
+        // need to zero dst since we are accumulating into it
+        memset(dst->data, 0, ggml_nbytes(dst));
+
         return;
     }
 
         return;
     }
 
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+
     // total rows in dst
+    const int nr = ne1;
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    float * const wdata     = (float *) params->wdata + 0;
+    float * const wdata_src = wdata + nk;
 
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dst_data = (float *)((char *) dst->data + i1*nb1);
+        float * wdata_kernel = wdata + i1*ne02*ne00;
+        for (int i10 = 0; i10 < ne10; i10++) {
+            const int i1n = i10*ne11;
+            for (int i00 = 0; i00 < ne00; i00++) {
+                float v = 0;
+                ggml_vec_dot_f32(ne02, &v,
+                        wdata_src + i1n,
+                        wdata_kernel + i00*ne02);
+                dst_data[i10*s0 + i00] += v;
             }
         }
     }
 }
 
+static void ggml_compute_forward_conv_transpose_1d(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
     }
 }
 
+// src0: kernel [OC, IC, KH, KW]
+// src1: image  [N, IC, IH, IW]
+// dst:  result [N, OH, OW, IC*KH*KW]
+static void ggml_compute_forward_im2col_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
 
     GGML_TENSOR_BINARY_OP_LOCALS;
 
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
 
     const int ith = params->ith;
     const int nth = params->nth;
 
+    const int64_t N  = is_2D ? ne13 : ne12;
+    const int64_t IC = is_2D ? ne12 : ne11;
+    const int64_t IH = is_2D ? ne11 : 1;
+    const int64_t IW = ne10;
+
+    const int64_t KH = is_2D ? ne01 : 1;
+    const int64_t KW = ne00;
+
+    const int64_t OH = is_2D ? ne2 : 1;
+    const int64_t OW = ne1;
+
+    int ofs0 = is_2D ? nb13 : nb12;
+    int ofs1 = is_2D ? nb12 : nb11;
 
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
     if (params->type == GGML_TASK_INIT) {
         return;
     }
 
     if (params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
+    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
     {
         ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
 
         for (int64_t in = 0; in < N; in++) {
+            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
+                for (int64_t iow = 0; iow < OW; iow++) {
+                    for (int64_t iic = ith; iic < IC; iic += nth) {
 
+                        // micro kernel
+                        ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
 
+                        for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
+                            for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
 
+                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
+                                } else {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
+                                }
+                            }
                         }
                     }
                 }
             }
         }
     }
 
+static void ggml_compute_forward_im2col(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
+                ggml_compute_forward_im2col_f16(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
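For readability, here is a single-threaded f32 restatement of the indexing used by ggml_compute_forward_im2col_f16 above; it is a reference sketch, not the shipped kernel (which writes f16 and splits the channel loop across threads), and the sizes it takes are hypothetical:

#include <stdint.h>

// f32, single-threaded restatement of the im2col indexing:
// src [N, IC, IH, IW] -> dst [N, OH, OW, IC*KH*KW]
static void im2col_ref(const float * src, float * dst,
                       int64_t N, int64_t IC, int64_t IH, int64_t IW,
                       int64_t KH, int64_t KW, int64_t OH, int64_t OW,
                       int s0, int s1, int p0, int p1, int d0, int d1) {
    for (int64_t in = 0; in < N; in++) {
        for (int64_t ioh = 0; ioh < OH; ioh++) {
            for (int64_t iow = 0; iow < OW; iow++) {
                for (int64_t iic = 0; iic < IC; iic++) {
                    // one output row holds the full [IC, KH, KW] patch
                    float * dst_data = dst + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW);
                    const float * src_data = src + (in*IC + iic)*IH*IW;

                    for (int64_t ikh = 0; ikh < KH; ikh++) {
                        for (int64_t ikw = 0; ikw < KW; ikw++) {
                            const int64_t iiw = iow*s0 + ikw*d0 - p0;
                            const int64_t iih = ioh*s1 + ikh*d1 - p1;

                            // out-of-bounds taps read as zero (padding)
                            dst_data[iic*(KH*KW) + ikh*KW + ikw] =
                                (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)
                                    ? 0.0f : src_data[iih*IW + iiw];
                        }
                    }
                }
            }
        }
    }
}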
             {
                 ggml_compute_forward_clamp(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor);
             } break;
+        case GGML_OP_IM2COL:
             {
+                ggml_compute_forward_im2col(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_IM2COL:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
             {
                 n_tasks = 1; //TODO
             } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 n_tasks = n_threads;
             } break;
+        case GGML_OP_IM2COL:
             {
                 n_tasks = n_threads;
             } break;
             } break;
         default:
             {
+                printf("%s: op %s not implemented\n", __func__, ggml_op_name(node->op));
                 GGML_ASSERT(false);
             } break;
     }
                     cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                 }
             } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 GGML_ASSERT(node->src[0]->ne[3] == 1);
 
                     GGML_ASSERT(false);
                 }
             } break;
+        case GGML_OP_IM2COL:
             {
+                n_tasks = n_threads;
             } break;
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
@@ -403,13 +403,8 @@ extern "C" {
         GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
         GGML_OP_CLAMP,
-        GGML_OP_CONV_1D,
-        GGML_OP_CONV_1D_STAGE_0, // internal
-        GGML_OP_CONV_1D_STAGE_1, // internal
         GGML_OP_CONV_TRANSPOSE_1D,
-        GGML_OP_CONV_2D,
-        GGML_OP_CONV_2D_STAGE_0, // internal
-        GGML_OP_CONV_2D_STAGE_1, // internal
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
@@ -1398,6 +1393,18 @@ extern "C" {
|
|
| 1398 |
float min,
|
| 1399 |
float max);
|
| 1400 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1401 |
GGML_API struct ggml_tensor * ggml_conv_1d(
|
| 1402 |
struct ggml_context * ctx,
|
| 1403 |
struct ggml_tensor * a,
|
|
|
|
| 403 |
GGML_OP_ROPE_BACK,
|
| 404 |
GGML_OP_ALIBI,
|
| 405 |
GGML_OP_CLAMP,
|
|
|
|
|
|
|
|
|
|
| 406 |
GGML_OP_CONV_TRANSPOSE_1D,
|
| 407 |
+
GGML_OP_IM2COL,
|
|
|
|
|
|
|
| 408 |
GGML_OP_CONV_TRANSPOSE_2D,
|
| 409 |
GGML_OP_POOL_1D,
|
| 410 |
GGML_OP_POOL_2D,
|
|
|
|
| 1393 |
float min,
|
| 1394 |
float max);
|
| 1395 |
|
| 1396 |
+
GGML_API struct ggml_tensor * ggml_im2col(
|
| 1397 |
+
struct ggml_context * ctx,
|
| 1398 |
+
struct ggml_tensor * a,
|
| 1399 |
+
struct ggml_tensor * b,
|
| 1400 |
+
int s0,
|
| 1401 |
+
int s1,
|
| 1402 |
+
int p0,
|
| 1403 |
+
int p1,
|
| 1404 |
+
int d0,
|
| 1405 |
+
int d1,
|
| 1406 |
+
bool is_2D);
|
| 1407 |
+
|
| 1408 |
GGML_API struct ggml_tensor * ggml_conv_1d(
|
| 1409 |
struct ggml_context * ctx,
|
| 1410 |
struct ggml_tensor * a,
|
|
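The new ggml_im2col declaration takes a per-axis stride (s0, s1), padding (p0, p1), and dilation (d0, d1), plus an is_2D flag shared by the 1D and 2D convolution paths. A standalone sketch of what those parameters mean in the 2D case — this mirrors the standard im2col definition and does not claim to show ggml's actual memory layout:

#include <cstdio>
#include <vector>

int main() {
    const int W = 5, H = 5;            // input size
    const int KW = 3, KH = 3;          // kernel size
    const int s0 = 1, s1 = 1;          // stride  (width, height)
    const int p0 = 1, p1 = 1;          // padding (width, height)
    const int d0 = 1, d1 = 1;          // dilation (width, height)

    const int OW = (W + 2*p0 - d0*(KW - 1) - 1)/s0 + 1;
    const int OH = (H + 2*p1 - d1*(KH - 1) - 1)/s1 + 1;

    std::vector<float> src(W*H, 1.0f);
    std::vector<float> cols(OW*OH*KW*KH, 0.0f); // one unrolled window per output pixel

    for (int oy = 0; oy < OH; ++oy)
    for (int ox = 0; ox < OW; ++ox)
    for (int ky = 0; ky < KH; ++ky)
    for (int kx = 0; kx < KW; ++kx) {
        const int ix = ox*s0 + kx*d0 - p0;
        const int iy = oy*s1 + ky*d1 - p1;
        if (ix >= 0 && ix < W && iy >= 0 && iy < H) {
            cols[((oy*OW + ox)*KH + ky)*KW + kx] = src[iy*W + ix];
        } // out-of-bounds positions stay 0 (zero padding)
    }

    printf("unrolled %dx%d windows into a %d x %d matrix\n", KW, KH, OW*OH, KW*KH);
}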
whisper.cpp

@@ -1,10 +1,15 @@
 #include "whisper.h"
+
 #ifdef WHISPER_USE_COREML
 #include "coreml/whisper-encoder.h"
 #endif

 #ifdef GGML_USE_METAL
-#  include "ggml-metal.h"
+#include "ggml-metal.h"
+#endif
+
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
 #endif

 #ifdef WHISPER_USE_OPENVINO

@@ -13,6 +18,7 @@

 #include "ggml.h"
 #include "ggml-alloc.h"
+#include "ggml-backend.h"

 #include <algorithm>
 #include <cassert>
@@ -97,10 +103,32 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 #define BYTESWAP_TENSOR(t) do {} while (0)
 #endif

+#ifdef __GNUC__
+#ifdef __MINGW32__
+#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+#else
+#define WHISPER_ATTRIBUTE_FORMAT(...)
+#endif
+
+//
+// logging
+//
+
+WHISPER_ATTRIBUTE_FORMAT(2, 3)
+static void whisper_log_internal        (ggml_log_level level, const char * format, ...);
+static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data);
+
+#define WHISPER_LOG_INFO(...)  whisper_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
+#define WHISPER_LOG_WARN(...)  whisper_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
+#define WHISPER_LOG_ERROR(...) whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
 #define WHISPER_ASSERT(x) \
     do { \
         if (!(x)) { \
-            … \
+            WHISPER_LOG_ERROR("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
             abort(); \
         } \
     } while (0)
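The block above replaces the old free-function logger with a level-aware one whose format strings are compiler-checked. A simplified standalone analogue of the same pattern (the names here are illustrative, not whisper.cpp's own; the __attribute__ is GCC/Clang-specific, which is why the diff guards it behind __GNUC__):

#include <cstdarg>
#include <cstdio>

typedef void (*log_callback)(const char * text);

static void default_callback(const char * text) { fputs(text, stderr); }
static log_callback g_callback = default_callback;

// format-checked variadic logger that forwards to a replaceable callback
__attribute__((format(printf, 1, 2)))
static void log_internal(const char * format, ...) {
    char buf[1024];
    va_list args;
    va_start(args, format);
    vsnprintf(buf, sizeof(buf), format, args);
    va_end(args);
    g_callback(buf);
}

int main() {
    log_internal("loaded %d tensors in %.2f s\n", 167, 0.42);
}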
@@ -127,8 +155,8 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 //

 static void ggml_graph_compute_helper(
+        struct ggml_cgraph * graph,
         std::vector<uint8_t> & buf,
-        ggml_cgraph * graph,
         int n_threads,
         whisper_abort_callback abort_callback,
         void * abort_callback_data) {
@@ -145,6 +173,21 @@ static void ggml_graph_compute_helper(
     ggml_graph_compute(graph, &plan);
 }

+static void ggml_graph_compute_helper(
+        struct ggml_backend * backend,
+        struct ggml_cgraph * graph,
+        int n_threads) {
+    if (ggml_backend_is_cpu(backend)) {
+        ggml_backend_cpu_set_n_threads(backend, n_threads);
+    }
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(backend)) {
+        ggml_backend_metal_set_n_cb(backend, n_threads);
+    }
+#endif
+    ggml_backend_graph_compute(backend, graph);
+}
+
 // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
 // the idea is to represent the original matrix multiplication:
 //
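The new overload gives every call site a single entry point: a CPU backend gets its thread count, the Metal backend gets its command-buffer count, and ggml_backend_graph_compute does the rest. A standalone mock of that dispatch shape — the types here are stand-ins, not ggml's:

#include <cstdio>

enum backend_kind { BACKEND_CPU, BACKEND_METAL };
struct backend { backend_kind kind; };

// one helper applies backend-specific knobs, then runs the graph
static void graph_compute(backend * b, int n_threads) {
    if (b->kind == BACKEND_CPU)   printf("cpu: %d threads\n", n_threads);
    if (b->kind == BACKEND_METAL) printf("metal: %d command buffers\n", n_threads);
    // ... run the graph on b ...
}

int main() {
    backend cpu{BACKEND_CPU}, metal{BACKEND_METAL};
    graph_compute(&cpu, 8);    // one call site, no #ifdef at the caller
    graph_compute(&metal, 8);
}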
@@ -179,6 +222,7 @@ static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct g…
 }

 // TODO: check if other platforms can benefit from this optimization
+ …
 #if defined(GGML_USE_METAL)
 #define ggml_mul_mat ggml_mul_mat_pad
 #endif
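ggml_mul_mat_pad rests on a simple identity: zero-padding the shared (inner) dimension of a matrix product leaves the result unchanged, because the padded zeros contribute nothing to each dot product, while the padded shape satisfies the alignment the fast Metal kernels want. A standalone numeric check:

#include <cstdio>
#include <vector>

int main() {
    const int k = 5, pad = 8;          // original and padded inner dimension
    const float a[k] = {1, 2, 3, 4, 5};
    const float b[k] = {5, 4, 3, 2, 1};

    std::vector<float> a_pad(pad, 0.0f), b_pad(pad, 0.0f);
    for (int i = 0; i < k; ++i) { a_pad[i] = a[i]; b_pad[i] = b[i]; }

    float dot = 0.0f, dot_pad = 0.0f;
    for (int i = 0; i < k;   ++i) dot     += a[i]*b[i];
    for (int i = 0; i < pad; ++i) dot_pad += a_pad[i]*b_pad[i];

    printf("dot = %.1f, padded dot = %.1f\n", dot, dot_pad); // identical: 35.0
}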
@@ -305,75 +349,6 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
     { "yue", { 99, "cantonese", } },
 };

-static const size_t MB = 1ull*1024*1024;
-
-// TODO: avoid using GGUF
-static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
-    { GGML_TYPE_F32,
-        {
-            { MODEL_TINY,     74ull*MB },
-            { MODEL_BASE,    142ull*MB },
-            { MODEL_SMALL,   466ull*MB },
-            { MODEL_MEDIUM, 1464ull*MB },
-            { MODEL_LARGE,  2952ull*MB },
-        },
-    },
-    { GGML_TYPE_F16,
-        {
-            { MODEL_TINY,     74ull*MB },
-            { MODEL_BASE,    142ull*MB },
-            { MODEL_SMALL,   466ull*MB },
-            { MODEL_MEDIUM, 1464ull*MB },
-            { MODEL_LARGE,  2952ull*MB },
-        },
-    },
-    { GGML_TYPE_Q4_0,
-        {
-            { MODEL_TINY,     26ull*MB },
-            { MODEL_BASE,     50ull*MB },
-            { MODEL_SMALL,   154ull*MB },
-            { MODEL_MEDIUM,  470ull*MB },
-            { MODEL_LARGE,   940ull*MB },
-        },
-    },
-    { GGML_TYPE_Q4_1,
-        {
-            { MODEL_TINY,     32ull*MB },
-            { MODEL_BASE,     58ull*MB },
-            { MODEL_SMALL,   182ull*MB },
-            { MODEL_MEDIUM,  562ull*MB },
-            { MODEL_LARGE,  1124ull*MB },
-        },
-    },
-    { GGML_TYPE_Q5_0,
-        {
-            { MODEL_TINY,     30ull*MB },
-            { MODEL_BASE,     54ull*MB },
-            { MODEL_SMALL,   170ull*MB },
-            { MODEL_MEDIUM,  516ull*MB },
-            { MODEL_LARGE,  1034ull*MB },
-        },
-    },
-    { GGML_TYPE_Q5_1,
-        {
-            { MODEL_TINY,     32ull*MB },
-            { MODEL_BASE,     58ull*MB },
-            { MODEL_SMALL,   182ull*MB },
-            { MODEL_MEDIUM,  562ull*MB },
-            { MODEL_LARGE,  1124ull*MB },
-        },
-    },
-    { GGML_TYPE_Q8_0,
-        {
-            { MODEL_TINY,     45ull*MB },
-            { MODEL_BASE,     84ull*MB },
-            { MODEL_SMALL,   268ull*MB },
-            { MODEL_MEDIUM,  834ull*MB },
-            { MODEL_LARGE,  1674ull*MB },
-        },
-    },
-};
-
 struct whisper_mel {
     int n_len;
     int n_len_org;
@@ -554,8 +529,7 @@ struct whisper_kv_cache {

     struct ggml_context * ctx;

-    …
-    std::vector<uint8_t> buf;

     int n; // number of tokens currently in the cache
 };
@@ -594,11 +568,11 @@ struct whisper_model {
     std::vector<whisper_layer_encoder> layers_encoder;
     std::vector<whisper_layer_decoder> layers_decoder;

-    // context
     struct ggml_context * ctx;

-    // the model
-    …

     // tensors
     int n_loaded;
@@ -663,37 +637,47 @@ struct whisper_allocr {
     ggml_allocr * alloc = nullptr;

     std::vector<uint8_t> meta;
-    …
 };

 static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
-    return allocr.meta.size() + allocr.…
 }

 // measure the memory usage of a graph and prepare the allocr's internal data buffer
-static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct ggml_cgraph *()> && get_graph) {
-    …
+ …

-    auto & meta = allocr.meta;
-    auto & data = allocr.data;

     meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());

-    alloc…
-    …
-    alloc…
 }

 static void whisper_allocr_free(struct whisper_allocr & allocr) {
     if (allocr.alloc) {
         ggml_allocr_free(allocr.alloc);
         allocr.alloc = nullptr;
     }
 }
@@ -722,8 +706,7 @@ struct whisper_state {
     // buffer for swapping KV caches between decoders during beam-search
     std::vector<kv_buf> kv_swap_bufs;

-    …
-    std::vector<uint8_t> work_buffer;

     // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers

@@ -737,6 +720,9 @@ struct whisper_state {
     struct ggml_tensor * embd_conv = nullptr;
     struct ggml_tensor * embd_enc = nullptr;

+ …
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
@@ -751,22 +737,21 @@ struct whisper_state {
     int lang_id = 0; // english by default

     std::string path_model; // populated by whisper_init_from_file_with_params()

 #ifdef WHISPER_USE_COREML
     whisper_coreml_context * ctx_coreml = nullptr;
 #endif

-#ifdef GGML_USE_METAL
-    ggml_metal_context * ctx_metal = nullptr;
-#endif
-
 #ifdef WHISPER_USE_OPENVINO
     whisper_openvino_context * ctx_openvino = nullptr;
 #endif

     // [EXPERIMENTAL] token-level timestamps data
-    int64_t t_beg…
     int64_t t_last = 0;
     whisper_token tid_last;
     std::vector<float> energy; // PCM signal energy

     // [EXPERIMENTAL] speed-up techniques
@@ -780,35 +765,25 @@ struct whisper_context {
     ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 / FP16 / QX)
     ggml_type itype = ggml_type::GGML_TYPE_F16; // intermediate type (FP32 or FP16)

     whisper_model model;
     whisper_vocab vocab;

     whisper_state * state = nullptr;

     std::string path_model; // populated by whisper_init_from_file_with_params()
-    whisper_context_params params;
 };

-    …
-
-static …
-
-#ifdef __GNUC__
-#ifdef __MINGW32__
-__attribute__((gnu_format(printf, 1, 2)))
-#else
-__attribute__((format(printf, 1, 2)))
-#endif
-#endif
-static void log(const char * fmt, ...) {
-    if (!whisper_log) return;
-    char buf[1024];
-    va_list args;
-    va_start(args, fmt);
-    vsnprintf(buf, sizeof(buf), fmt, args);
-    whisper_log(buf);
-}

 template<typename T>
 static void read_safe(whisper_model_loader * loader, T & dest) {
@@ -819,6 +794,7 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
 static bool kv_cache_init(
         const struct whisper_hparams & hparams,
         struct whisper_kv_cache & cache,
+ …
         ggml_type wtype,
         int n_ctx) {
     const int64_t n_text_state = hparams.n_text_state;
@@ -827,30 +803,41 @@ static bool kv_cache_init(
     const int64_t n_mem = n_text_layer*n_ctx;
     const int64_t n_elements = n_text_state*n_mem;

-    const size_t mem_bytes = 2*(ggml_type_size(wtype)*n_elements + ggml_tensor_overhead());
-
-    cache.buf.resize(mem_bytes);
-
     struct ggml_init_params params = {
-        /*.mem_size   =*/ …
-        /*.mem_buffer =*/ …
-        /*.no_alloc   =*/ …
     };

     cache.ctx = ggml_init(params);

     if (!cache.ctx) {
-        …
         return false;
     }

     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);

+ …
     return true;
 }

-…
     WHISPER_ASSERT(cache.ctx);

     const int n_elements = ggml_nelements(cache.k);
@@ -859,34 +846,78 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
     const ggml_type wtype = cache.k->type;
     WHISPER_ASSERT(wtype == cache.v->type);

-    WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_sizef(wtype));
-
     struct ggml_init_params params = {
-        /*.mem_size   =*/ …
-        /*.mem_buffer =*/ …
-        /*.no_alloc   =*/ …
     };

     cache.ctx = ggml_init(params);

     if (!cache.ctx) {
-        …
         return false;
     }

     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);

+ …
     return true;
 }

 static void kv_cache_free(struct whisper_kv_cache & cache) {
     if (cache.ctx) {
         ggml_free(cache.ctx);
         cache.ctx = nullptr;
     }
 }

 // load the model from a ggml file
 //
 // file format:
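For scale, the sizing arithmetic in kv_cache_init is easy to work through by hand. A standalone sketch using whisper's published base-model decoder hyper-parameters (n_text_state = 512, n_text_layer = 6) and the standard text context of 448 — treat these numbers as illustrative assumptions, not values taken from this diff:

#include <cstdio>

int main() {
    const long long n_text_state = 512;   // assumed: base model
    const long long n_text_layer = 6;     // assumed: base model
    const long long n_ctx        = 448;   // assumed: standard text context

    const long long n_mem      = n_text_layer*n_ctx;  // cache slots per tensor
    const long long n_elements = n_text_state*n_mem;  // per K (and per V)

    const long long bytes = 2 /* K and V */ * 2 /* F16 bytes */ * n_elements;
    printf("KV cache: %lld elements each, ~%.2f MB total\n",
           n_elements, bytes/(1024.0*1024.0));
}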
@@ -899,7 +930,7 @@ static void kv_cache_free(struct whisper_kv_cache & cache) {
 // see the convert-pt-to-ggml.py script for details
 //
 static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
-    …

     const int64_t t_start_us = ggml_time_us();

@@ -913,7 +944,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
     uint32_t magic;
     read_safe(loader, magic);
     if (magic != GGML_FILE_MAGIC) {
-        …
         return false;
     }
 }
@@ -970,41 +1001,23 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
     // in order to save memory and also to speed up the computation
     wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
     if (wctx.wtype == GGML_TYPE_COUNT) {
-        …
         return false;
     }

-    …
-    log("%s: qntvr = %d\n", __func__, qntvr);
-    log("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str());
-
-    // print memory requirements
-    {
-        // TODO
-        //log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
-        //        mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
-    }
-
-    // initialize all memory buffers
-    // always have at least one decoder
-
-    wctx.model.buf = new std::vector<uint8_t>();
-    wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type));
-
-    // we skip initialization of the state until it is needed
-    // because it might be that state will always be provided externally.
 }

 // load mel filters
@@ -1025,7 +1038,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
         read_safe(loader, n_vocab);

         //if (n_vocab != model.hparams.n_vocab) {
-        //    …
         //        __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
         //    return false;
         //}

@@ -1045,7 +1058,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
             word.assign(&tmp[0], tmp.size());
         } else {
             // seems like we have an empty-string token in multi-language models (i = 50256)
-            //…
             word = "";
         }

@@ -1073,7 +1086,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
         }

         if (n_vocab < model.hparams.n_vocab) {
-            …
             for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
                 if (i > vocab.token_beg) {
                     word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
@@ -1099,140 +1112,35 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
         }
     }

-    …
 }

-    size_t ctx_size = 0;
-
     const ggml_type wtype = wctx.wtype;
     const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type

     {
         const auto & hparams = model.hparams;

-        const int n_vocab = hparams.n_vocab;
-
-        const int n_audio_ctx   = hparams.n_audio_ctx;
-        const int n_audio_state = hparams.n_audio_state;
         const int n_audio_layer = hparams.n_audio_layer;

-        const int n_text_ctx   = hparams.n_text_ctx;
-        const int n_text_state = hparams.n_text_state;
-        const int n_text_layer = hparams.n_text_layer;
-
-        const int n_mels = hparams.n_mels;
-
-        // encoder
-        {
-            ctx_size += n_audio_ctx*n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_pe;
-
-            ctx_size += 3*n_mels*n_audio_state*ggml_type_sizef(vtype); // e_conv_1_w
-            ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_1_b
-
-            ctx_size += 3*n_audio_state*n_audio_state*ggml_type_sizef(vtype); // e_conv_2_w
-            ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_2_b
-
-            ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_w;
-            ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_b;
-        }
-
-        // decoder
-        {
-            ctx_size += n_text_ctx*n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_pe;
-
-            ctx_size += n_vocab*n_text_state*ggml_type_sizef(wtype); // d_te;
-
-            ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_w;
-            ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_b;
-        }
-
-        // encoder layers
-        {
-            ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
-            ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
-
-            ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_0_w
-            ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
-
-            ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_1_w
-            ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
-
-            ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
-            ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_q_w
-            ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_k_w
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_v_w
-            ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
-
-            ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_ln_1_w
-            ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
-        }
-
-        // decoder layers
-        {
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
-
-            ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_0_w
-            ctx_size += n_text_layer*( 4*n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
-
-            ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_1_w
-            ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
-
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_q_w
-            ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_k_w
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_v_w
-            ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_ln_1_w
-            ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
-            //
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_w
-            ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_q_w
-            ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_q_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_k_w
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_v_w
-            ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_v_b
-
-            ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_ln_1_w
-            ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_1_b
-        }
-
-        ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*512; // object overhead
-
-        log("%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
-    }

-    // create the ggml context
-    {
         struct ggml_init_params params = {
-            /*.mem_size   =*/ …
-            /*.mem_buffer =*/ …
-            /*.no_alloc   =*/ …
         };

         model.ctx = ggml_init(params);
         if (!model.ctx) {
-            …
             return false;
         }
     }

-    // prepare …
     {
         auto & ctx = model.ctx;
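The deleted block above hand-computed ctx_size one tensor at a time; with ggml-alloc the requirement is measured from the graph instead. One term as a worked example — the encoder positional embedding e_pe, using the assumed base-model shapes n_audio_ctx = 1500 and n_audio_state = 512 in F32:

#include <cstdio>

int main() {
    const long long n_audio_ctx   = 1500; // assumed: base model
    const long long n_audio_state = 512;  // assumed: base model
    const long long type_size     = 4;    // sizeof(float), i.e. GGML_TYPE_F32

    const long long bytes = n_audio_ctx*n_audio_state*type_size;
    printf("e_pe: %lld bytes (%.2f MB)\n", bytes, bytes/(1024.0*1024.0)); // ~2.93 MB
}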
@@ -1255,16 +1163,16 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {

         // encoder
         {
-            model.e_pe…

-            model.e_conv_1_w…
-            model.e_conv_1_b…

-            model.e_conv_2_w…
-            model.e_conv_2_b…

-            model.e_ln_w…
-            model.e_ln_b…

             // map by name
             model.tensors["encoder.positional_embedding"] = model.e_pe;
@@ -1428,12 +1336,37 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
         }
     }

+ …
     // load weights
     {
         size_t total_size = 0;

         model.n_loaded = 0;

+ …
         while (true) {
             int32_t n_dims;
             int32_t length;
@@ -1460,50 +1393,92 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
             name.assign(&tmp[0], tmp.size());

             if (model.tensors.find(name) == model.tensors.end()) {
-                …
                 return false;
             }

             auto tensor = model.tensors[name.data()];
-            if (ggml_nelements(tensor) != nelements) {
-                log("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
-                log("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
-                    __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
-                return false;
-            }

-            …
-                log("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
-                    __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
-                return false;
-            }

-            …
+ …
             }

-            …
+ …

             //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1024.0/1024.0);
             total_size += ggml_nbytes(tensor);
             model.n_loaded++;
         }

-        …

         if (model.n_loaded == 0) {
-            …
         } else if (model.n_loaded != (int) model.tensors.size()) {
-            …
             return false;
         }
     }

+ …
     wctx.t_load_us = ggml_time_us() - t_start_us;

     return true;
@@ -1559,10 +1534,12 @@ static struct ggml_cgraph * whisper_build_graph_conv(
     if (!ggml_allocr_is_measure(alloc)) {
         assert(mel_inp.n_mel == n_mels);

-        …
+ …
         memset(dst, 0, ggml_nbytes(mel));

-        const int i0 = std::min(mel_offset, …
         const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);

         for (int j = 0; j < mel_inp.n_mel; ++j) {

@@ -1570,6 +1547,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
                 dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
             }
         }
+ …
     }

     struct ggml_tensor * cur = nullptr;
@@ -1578,24 +1557,27 @@ static struct ggml_cgraph * whisper_build_graph_conv(
         // convolution + gelu
         {
             cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
-            cur = ggml_add(ctx0, …
-                …
+ …

             cur = ggml_gelu(ctx0, cur);

             cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
-            cur = ggml_add(ctx0, …
-                …
+ …

             cur = ggml_gelu(ctx0, cur);
         }

         wstate.embd_conv = cur;
     } else {
 #ifdef WHISPER_USE_COREML

@@ -1615,6 +1597,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
     }
 #endif

+ …
     wstate.embd_enc = cur;
 }
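The removed ggml_add lines fed the convolution bias through an explicit intermediate (the change log notes removing ggml_repeat for the conv bias); the replacement relies on the add broadcasting the small bias tensor instead. A standalone sketch of why the two forms agree while the second skips the materialized copy:

#include <cstdio>

int main() {
    const int n_state = 4, n_ctx = 3;
    const float bias[n_state] = {0.1f, 0.2f, 0.3f, 0.4f};

    float repeat_form[n_ctx][n_state];
    float broadcast_form[n_ctx][n_state];

    for (int i = 0; i < n_ctx; ++i) {
        for (int j = 0; j < n_state; ++j) {
            const float conv_out = 1.0f;   // stand-in for the convolution output

            // old shape: tile the bias to the full [n_ctx, n_state] tensor, then add
            const float tiled = bias[j];   // one element of that repeated tensor
            repeat_form[i][j] = conv_out + tiled;

            // new shape: index the small operand directly — no intermediate tensor
            broadcast_form[i][j] = conv_out + bias[j];
        }
    }

    printf("forms agree: %d\n", repeat_form[2][3] == broadcast_form[2][3]); // 1
}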
@@ -1648,15 +1631,22 @@ static struct ggml_cgraph * whisper_build_graph_encoder(

     ggml_allocr * alloc = wstate.alloc_encode.alloc;

+ …
     struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
     ggml_allocr_alloc(alloc, KQscale);

     if (!ggml_allocr_is_measure(alloc)) {
-        …
+ …
     }

-    struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
-
     // ===================================================================
     // NOTE: experimenting with partial evaluation of the encoder (ignore)
     //static int iter = -1;

@@ -1675,7 +1665,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;

     struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
-
     cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));

     // ===================================================================
@@ -1897,13 +1886,20 @@ static struct ggml_cgraph * whisper_build_graph_cross(

     ggml_allocr * alloc = wstate.alloc_cross.alloc;

+ …
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);

     struct ggml_tensor * Kscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
     ggml_allocr_alloc(alloc, Kscale);

     if (!ggml_allocr_is_measure(alloc)) {
-        …
+ …
     }

     for (int il = 0; il < model.hparams.n_text_layer; ++il) {
@@ -1974,7 +1970,7 @@ static bool whisper_encode_internal(
         ggml_allocr_alloc_graph(alloc, gf);

         if (!whisper_encode_external(wstate)) {
-            ggml_graph_compute_helper(wstate.…
+            ggml_graph_compute_helper(wstate.…
         }
     }

@@ -1988,16 +1984,7 @@ static bool whisper_encode_internal(

         ggml_allocr_alloc_graph(alloc, gf);

-#ifdef GGML_USE_METAL
-        if (wstate.ctx_metal) {
-            ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
-            ggml_metal_graph_compute(wstate.ctx_metal, gf);
-        } else {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-        }
-#else
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-#endif
+ …
     }

     // cross

@@ -2010,20 +1997,9 @@ static bool whisper_encode_internal(

         ggml_allocr_alloc_graph(alloc, gf);

-#ifdef GGML_USE_METAL
-        if (wstate.ctx_metal) {
-            ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
-            ggml_metal_graph_compute(wstate.ctx_metal, gf);
-        } else {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-        }
-#else
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-#endif
+ …
     }

-    // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
-
     wstate.t_encode_us += ggml_time_us() - t_start_us;
     wstate.n_encode++;
@@ -2070,7 +2046,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     ggml_allocr_alloc(alloc, embd);

     if (!ggml_allocr_is_measure(alloc)) {
-        …
+ …
     }

     struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);

@@ -2078,7 +2054,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(

     if (!ggml_allocr_is_measure(alloc)) {
         for (int i = 0; i < N; ++i) {
-            …
+ …
         }
     }

@@ -2086,7 +2063,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     ggml_allocr_alloc(alloc, KQscale);

     if (!ggml_allocr_is_measure(alloc)) {
-        …
+ …
     }

     // token encoding + position encoding
@@ -2410,25 +2388,18 @@ static bool whisper_decode_internal(

     logits = gf->nodes[gf->n_nodes - 1];

-#ifdef GGML_USE_METAL
-    if (wstate.ctx_metal) {
-        ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
-        ggml_metal_graph_compute(wstate.ctx_metal, gf);
-    } else {
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-    }
-#else
-    ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
-#endif
+ …
     }

     // extract logits for all N tokens
     //logits_out.resize(n_tokens*n_vocab);
     //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);

     // extract logits only for the last token
     logits_out.resize(n_vocab);
-    memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
+ …

     if (n_tokens > 1) {
         //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
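The removed memcpy copied sizeof(float)*n_vocab bytes starting at the beginning of the logits buffer; the decode output is documented earlier in this diff as a [n_tokens][n_vocab] array, so when the graph produces more than one row, the last token's logits live at an offset. A worked sketch of that row addressing — how the replacement actually reads it (presumably through a backend-aware copy) is not visible in this rendering:

#include <cstdio>

int main() {
    const int n_tokens = 3, n_vocab = 5;
    float logits[n_tokens*n_vocab];
    for (int i = 0; i < n_tokens*n_vocab; ++i) logits[i] = (float) i;

    // row-major [n_tokens][n_vocab]: the last token's logits start here
    const float * last = logits + (n_tokens - 1)*n_vocab;
    printf("last row starts at element %d, first value %.0f\n",
           (n_tokens - 1)*n_vocab, last[0]); // element 10, value 10
}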
@@ -2794,7 +2765,7 @@ static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
                 --j;
             }
             if (!found) {
-                …
                 ++i;
             }
         }
@@ -2857,45 +2828,48 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {

 struct whisper_state * whisper_init_state(whisper_context * ctx) {
     fill_sin_cos_table();

     whisper_state * state = new whisper_state;

-    …
+ …
         delete state;
         return nullptr;
     }

     {
         const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
-        …
+        …
     }

-    if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
-        …
+ …
         delete state;
         return nullptr;
     }

     {
         const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v);
-        …
+        …
     }

 #ifdef WHISPER_USE_COREML
     const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);

-    …
+ …

     state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
     if (!state->ctx_coreml) {
-        …
+        …
 #ifndef WHISPER_COREML_ALLOW_FALLBACK
         delete state;
         return nullptr;
 #endif
     } else {
-        …
+        …
     }
 #endif
@@ -2912,37 +2886,37 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {

     // conv allocator
     {
-        whisper_allocr_graph_init(state->alloc_conv,
+ …
             [&]() {
                 return whisper_build_graph_conv(*ctx, *state, 0);
             });

-        …
+        …
     }

     // encoder allocator
     if (!whisper_encode_external(*state)) {
-        whisper_allocr_graph_init(state->alloc_encode,
+ …
             [&]() {
                 return whisper_build_graph_encoder(*ctx, *state);
             });

-        …
+        …
     }

     // cross allocator
     {
-        whisper_allocr_graph_init(state->alloc_cross,
+ …
             [&]() {
                 return whisper_build_graph_cross(*ctx, *state);
             });

-        …
+        …
     }

     // decoder allocator
     {
-        whisper_allocr_graph_init(state->alloc_decode,
+ …
             [&]() {
                 const auto & hparams = ctx->model.hparams;
@@ -2953,69 +2927,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
                 return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
             });

-        …
+        …
     }

-#ifdef GGML_USE_METAL
-    …
-            log("%s: ggml_metal_init() failed\n", __func__);
-            delete state;
-            return nullptr;
-        }
-    }
-
-    if (state->ctx_metal) {
-        log("%s: Metal context initialized\n", __func__);
-
-        // this allocates all Metal resources and memory buffers
-
-        void * data_ptr  = NULL;
-        size_t data_size = 0;
-
-        // TODO: add mmap support
-        //if (params.use_mmap) {
-        //    data_ptr  = ctx->model.mapping->addr;
-        //    data_size = ctx->model.mapping->size;
-        //} else {
-        //    data_ptr  = ggml_get_mem_buffer(ctx->model.ctx);
-        //    data_size = ggml_get_mem_size  (ctx->model.ctx);
-        //}
-
-        data_ptr  = ggml_get_mem_buffer(ctx->model.ctx);
-        data_size = ggml_get_mem_size  (ctx->model.ctx);
-
-        const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
-
-        log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
-
-#define WHISPER_METAL_CHECK_BUF(result) \
-        if (!(result)) { \
-            log("%s: failed to add metal buffer\n", __func__); \
-            delete state; \
-            return nullptr; \
-        }
-
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
-
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv",   state->alloc_conv.meta.data(),   state->alloc_conv.meta.size(),   0));
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross",  state->alloc_cross.meta.data(),  state->alloc_cross.meta.size(),  0));
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
-
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv",   state->alloc_conv.data.data(),   state->alloc_conv.data.size(),   0));
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross",  state->alloc_cross.data.data(),  state->alloc_cross.data.size(),  0));
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
-
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
-
-        WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
-#undef WHISPER_METAL_CHECK_BUF
-    }
-#endif

     state->rng = std::mt19937(0);
@@ -3036,7 +2954,7 @@ int whisper_ctx_init_openvino_encoder(
     return 1;
 #else
     if (!model_path && ctx->path_model.empty()) {
-        …
        return 1;
     }

@@ -3056,15 +2974,15 @@ int whisper_ctx_init_openvino_encoder(
         path_cache = cache_dir;
     }

-    …
+ …

     ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
     if (!ctx->state->ctx_openvino) {
-        …
         return 1;
     } else {
-        …
     }

     return 0;

@@ -3079,11 +2997,11 @@ struct whisper_context_params whisper_context_default_params() {
 }

 struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
-    …

     auto fin = std::ifstream(path_model, std::ios::binary);
     if (!fin) {
-        …
         return nullptr;
     }

@@ -3125,7 +3043,7 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params) {

     buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };

-    …

     whisper_model_loader loader = {};

@@ -3161,7 +3079,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {

     if (!whisper_model_load(loader, *ctx)) {
         loader->close(loader->context);
-        …
         delete ctx;
         return nullptr;
     }

@@ -3256,13 +3174,6 @@ void whisper_free_state(struct whisper_state * state)
     }
 #endif

-#ifdef GGML_USE_METAL
-    if (state->ctx_metal) {
-        ggml_metal_free(state->ctx_metal);
-        state->ctx_metal = nullptr;
-    }
-#endif
-
 #ifdef WHISPER_USE_OPENVINO
     if (state->ctx_openvino != nullptr) {
         whisper_openvino_free(state->ctx_openvino);

@@ -3271,9 +3182,11 @@ void whisper_free_state(struct whisper_state * state)
 #endif

     whisper_allocr_free(state->alloc_conv);
-    whisper_allocr_free(state->alloc_decode);
-    whisper_allocr_free(state->alloc_cross);
     whisper_allocr_free(state->alloc_encode);
+ …

     delete state;
 }
@@ -3284,12 +3197,15 @@ void whisper_free(struct whisper_context * ctx) {
         if (ctx->model.ctx) {
             ggml_free(ctx->model.ctx);
         }
-        …
        }

         whisper_free_state(ctx->state);

+ …
         delete ctx;
     }
 }

@@ -3308,7 +3224,7 @@ void whisper_free_params(struct whisper_full_params * params) {

 int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
     if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
-        …
         return -1;
     }

@@ -3322,7 +3238,7 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
 // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
 int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
     if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
-        …
         return -1;
     }

@@ -3350,7 +3266,7 @@ int whisper_set_mel_with_state(
     int n_len,
     int n_mel) {
     if (n_mel != ctx->model.filters.n_mel) {
-        …
         return -1;
     }

@@ -3374,7 +3290,7 @@ int whisper_set_mel(

 int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
     if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
-        …
         return -1;
     }

@@ -3383,7 +3299,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {

 int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
     if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
-        …
         return -1;
     }

@@ -3394,7 +3310,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
     const int selected_decoder_id = 0;

     if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
-        …
         return 1;
     }

@@ -3406,12 +3322,12 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
     const int selected_decoder_id = 0;

     if (ctx->state == nullptr) {
-        …
         return false;
     }

     if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
-        …
         return 1;
     }

@@ -3422,7 +3338,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
     const auto res = tokenize(ctx->vocab, text);

     if (n_max_tokens < (int) res.size()) {
-        …
         return -1;
     }

@@ -3450,7 +3366,7 @@ int whisper_lang_id(const char * lang) {
         }
     }

-        …
         return -1;
     }
     return g_lang.at(lang).first;

@@ -3463,7 +3379,7 @@ const char * whisper_lang_str(int id) {
         }
     }

-    …
     return nullptr;
 }
@@ -3476,25 +3392,25 @@ int whisper_lang_auto_detect_with_state(
     const int seek = offset_ms/10;

     if (seek < 0) {
-        …
         return -1;
     }

     if (seek >= state->mel.n_len_org) {
-        …
         return -2;
     }

     // run the encoder
     if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) {
-        …
         return -6;
     }

     const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };

     if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
-        …
         return -7;
     }
@@ -3694,8 +3610,8 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
 void whisper_print_timings(struct whisper_context * ctx) {
     const int64_t t_end_us = ggml_time_us();

-    …
     if (ctx->state != nullptr) {

         const int32_t n_sample = std::max(1, ctx->state->n_sample);

@@ -3703,14 +3619,14 @@ void whisper_print_timings(struct whisper_context * ctx) {
     const int32_t n_decode = std::max(1, ctx->state->n_decode);
     const int32_t n_prompt = std::max(1, ctx->state->n_prompt);

-    …
+ …
     }
-    …
 }
@@ -3762,6 +3678,7 @@ const char * whisper_print_system_info(void) {
     s += "SSE3 = "     + std::to_string(ggml_cpu_has_sse3())     + " | ";
     s += "SSSE3 = "    + std::to_string(ggml_cpu_has_ssse3())    + " | ";
     s += "VSX = "      + std::to_string(ggml_cpu_has_vsx())      + " | ";
+ …
     s += "COREML = "   + std::to_string(whisper_has_coreml())    + " | ";
     s += "OPENVINO = " + std::to_string(whisper_has_openvino())  + " | ";
@@ -4056,7 +3973,7 @@ static void whisper_process_logits(
     const bool last_was_timestamp        = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
     const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;

-    //…
+    //…

     if (last_was_timestamp) {
         if (penultimate_was_timestamp) {

@@ -4132,7 +4049,7 @@ static void whisper_process_logits(

     const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);

-    //…
+    //…

     if (timestamp_logprob > max_text_token_logprob) {
         for (int i = 0; i < vocab.token_beg; ++i) {
@@ -4427,8 +4344,10 @@ static bool whisper_kv_swap_fast(
     for (auto & i : two_copy) {
         // make a copy of KV caches
         WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
-        memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
-        memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
+ …
     }

     // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first

@@ -4441,13 +4360,17 @@ static bool whisper_kv_swap_fast(
         if (two_copy.find(view[i]) != two_copy.end()) {
             // modify KV caches of decoder using data from kv_swap_bufs
             WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+ …
         } else {
             // modify KV caches of decoder using data from correspond decoder KV caches directly
             WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
-            memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
+ …
         }
     }

@@ -4461,13 +4384,17 @@ static bool whisper_kv_swap_fast(
         if (two_copy.find(view[i]) != two_copy.end()) {
             // modify KV caches of decoder using data from kv_swap_bufs
             WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+ …
         } else {
             // modify KV caches of decoder using data from correspond decoder KV caches directly
             WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
-            memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
-            memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
+ …
         }
     }
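The replacement lines for these memcpy calls are not visible in this rendering, but the motivation follows from the migration: once tensors live in backend buffers, tensor->data may not be host-addressable at all (e.g. device memory), so every copy has to go through accessor functions the backend can implement as transfers. A standalone mock of that accessor pattern — the tensor and get/set types below are stand-ins, not ggml's:

#include <cstring>
#include <cstdio>
#include <vector>

struct mock_tensor { std::vector<unsigned char> device_mem; };

// stand-ins for backend-aware tensor read/write
static void tensor_get(const mock_tensor & t, void * dst, size_t off, size_t n) {
    std::memcpy(dst, t.device_mem.data() + off, n); // a GPU backend would DMA here
}
static void tensor_set(mock_tensor & t, const void * src, size_t off, size_t n) {
    std::memcpy(t.device_mem.data() + off, src, n);
}

int main() {
    mock_tensor k; k.device_mem.assign(16, 0xAB);
    std::vector<unsigned char> swap(16);

    tensor_get(k, swap.data(), 0, swap.size()); // store KV cache into swap
    tensor_set(k, swap.data(), 0, swap.size()); // restore it later
    printf("swapped %zu bytes through the accessor API\n", swap.size());
}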
@@ -4495,11 +4422,11 @@ int whisper_full_with_state(
     // compute log mel spectrogram
     if (params.speed_up) {
         // TODO: Replace PV with more advanced algorithm
-        …
         return -1;
     } else {
         if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
-            …
             return -2;
         }
     }

@@ -4511,13 +4438,13 @@ int whisper_full_with_state(

     const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
     if (lang_id < 0) {
-        …
         return -3;
     }
     state->lang_id = lang_id;
     params.language = whisper_lang_str(lang_id);

-    …
     if (params.detect_language) {
         return 0;
     }

@@ -4575,8 +4502,8 @@ int whisper_full_with_state(

     if (decoder.kv_self.ctx == nullptr) {
         decoder.kv_self = state->decoders[0].kv_self;
-        if (!kv_cache_reinit(decoder.kv_self)) {
-            …
+ …
             return -4;
         }

@@ -4587,23 +4514,6 @@ int whisper_full_with_state(
         decoder.probs.resize   (ctx->vocab.n_vocab);
         decoder.logits.resize  (ctx->vocab.n_vocab);
         decoder.logprobs.resize(ctx->vocab.n_vocab);
-
-        // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
-#ifdef GGML_USE_METAL
-        if (state->ctx_metal) {
-#define WHISPER_METAL_CHECK_BUF(result) \
-            if (!(result)) { \
-                log("%s: failed to add metal buffer\n", __func__); \
-                return 0; \
-            }
-
-            const std::string kv_name = "kv_self_" + std::to_string(j);
-            auto & kv_self = decoder.kv_self;
-
-            WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
-#undef WHISPER_METAL_CHECK_BUF
-        }
-#endif
     }
 }

@@ -4637,7 +4547,7 @@ int whisper_full_with_state(

     // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
     if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
-        …
         return -5;
     }
     state->exp_n_audio_ctx = params.audio_ctx;

@@ -4662,7 +4572,7 @@ int whisper_full_with_state(
     // distilled models require the "no_timestamps" token
     // TODO: add input parameter (#1229)
     if (is_distil) {
-        …
         prompt_init.push_back(whisper_token_not(ctx));
     }
 }
@@ -4699,14 +4609,14 @@ int whisper_full_with_state(

     if (params.encoder_begin_callback) {
         if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
-            …
             break;
         }
     }

     // encode audio features starting at offset seek
     if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-        …
         return -6;
     }

@@ -4789,7 +4699,7 @@ int whisper_full_with_state(
     WHISPER_PRINT_DEBUG("\n\n");

     if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-        …
         return -7;
     }

@@ -4803,8 +4713,11 @@ int whisper_full_with_state(
     for (int j = 1; j < n_decoders_cur; ++j) {
         auto & decoder = state->decoders[j];

-        …
-        memcpy(decoder.kv_self.…
+ …

         decoder.kv_self.n += prompt.size();
@@ -5013,7 +4926,7 @@ int whisper_full_with_state(
     //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);

     if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-        …
         return -8;
     }

@@ -5339,12 +5252,12 @@ int whisper_full_parallel(
     ctx->state->t_decode_us /= n_processors;

     // print information about the audio boundaries
-    …
     for (int i = 0; i < n_processors - 1; ++i) {
-        …
     }
-    …

     return ret;
 }
@@ -5586,12 +5499,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
     double tsum = 0.0;

     // heat-up
-    ggml_graph_compute_helper(…
+    ggml_graph_compute_helper(…

     for (int i = 0; i < n_max; ++i) {
         const int64_t t0 = ggml_time_us();

-        ggml_graph_compute_helper(…
+        ggml_graph_compute_helper(…

         const int64_t t1 = ggml_time_us();
@@ -5709,7 +5622,7 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
| 5709 |
const int n_samples = state.energy.size();
|
| 5710 |
|
| 5711 |
if (n_samples == 0) {
|
| 5712 |
-
|
| 5713 |
return;
|
| 5714 |
}
|
| 5715 |
|
|
@@ -5930,6 +5843,38 @@ static void whisper_exp_compute_token_level_timestamps(
|
|
| 5930 |
//}
|
| 5931 |
}
|
| 5932 |
|
| 5933 |
-
void
|
| 5934 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5935 |
}
|
|
|
|
 #include "whisper.h"
+
 #ifdef WHISPER_USE_COREML
 #include "coreml/whisper-encoder.h"
 #endif
 
 #ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
 #endif
 
 #ifdef WHISPER_USE_OPENVINO
...
 
 #include "ggml.h"
 #include "ggml-alloc.h"
+#include "ggml-backend.h"
 
 #include <algorithm>
 #include <cassert>
 #define BYTESWAP_TENSOR(t) do {} while (0)
 #endif
 
+#ifdef __GNUC__
+#ifdef __MINGW32__
+#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+#else
+#define WHISPER_ATTRIBUTE_FORMAT(...)
+#endif
+
+//
+// logging
+//
+
+WHISPER_ATTRIBUTE_FORMAT(2, 3)
+static void whisper_log_internal        (ggml_log_level level, const char * format, ...);
+static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data);
+
+#define WHISPER_LOG_INFO(...)  whisper_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
+#define WHISPER_LOG_WARN(...)  whisper_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
+#define WHISPER_LOG_ERROR(...) whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
 #define WHISPER_ASSERT(x) \
     do { \
         if (!(x)) { \
+            WHISPER_LOG_ERROR("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
             abort(); \
         } \
     } while (0)
 //
 
 static void ggml_graph_compute_helper(
+          struct ggml_cgraph * graph,
        std::vector<uint8_t> & buf,
...
                         int   n_threads,
      whisper_abort_callback   abort_callback,
                        void * abort_callback_data) {
...
     ggml_graph_compute(graph, &plan);
 }
 
+static void ggml_graph_compute_helper(
+       struct ggml_backend * backend,
+        struct ggml_cgraph * graph,
+                       int   n_threads) {
+    if (ggml_backend_is_cpu(backend)) {
+        ggml_backend_cpu_set_n_threads(backend, n_threads);
+    }
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(backend)) {
+        ggml_backend_metal_set_n_cb(backend, n_threads);
+    }
+#endif
+    ggml_backend_graph_compute(backend, graph);
+}
+
 // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
 // the idea is to represent the original matrix multiplication:
 //
...
 }
 
 // TODO: check if other platforms can benefit from this optimization
+// TODO: CUDA is currently broken - seems ggml_mul_mat does not handle views correctly
 #if defined(GGML_USE_METAL)
 #define ggml_mul_mat ggml_mul_mat_pad
 #endif
 
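Note: the new backend-aware ggml_graph_compute_helper() is all that call sites need; the CPU/Metal thread setup happens inside it. A minimal usage sketch follows (the tiny add-graph and the buffer size are illustrative, not part of the patch; it is valid for the CPU backend, where tensor data is host-addressable):

    // sketch: build and run c = a + b through the backend helper
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false, // data lives in the context (CPU only)
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * c = ggml_add(ctx, a, b);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);

    ggml_backend_t backend = ggml_backend_cpu_init();
    ggml_graph_compute_helper(backend, gf, /*n_threads =*/ 4);

    ggml_backend_free(backend);
    ggml_free(ctx);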
     { "yue",  { 99,  "cantonese",     } },
 };
 
 struct whisper_mel {
     int n_len;
     int n_len_org;
...
 
     struct ggml_context * ctx;
 
+    ggml_backend_buffer_t buffer;
 
     int n; // number of tokens currently in the cache
 };
...
     std::vector<whisper_layer_encoder> layers_encoder;
     std::vector<whisper_layer_decoder> layers_decoder;
 
+    // ggml context that contains all the meta information about the model tensors
     struct ggml_context * ctx;
 
+    // the model backend data is read-only and can be shared between processors
+    struct ggml_backend_buffer * buffer;
 
     // tensors
     int n_loaded;
...
     ggml_allocr * alloc = nullptr;
 
     std::vector<uint8_t> meta;
+
+    ggml_backend_buffer_t buffer;
 };
 
 static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
+    return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
 }
 
 // measure the memory usage of a graph and prepare the allocr's internal data buffer
+static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
+    auto & alloc = allocr.alloc;
+    auto & meta  = allocr.meta;
 
+    alloc = ggml_allocr_new_measure_from_backend(backend);
 
     meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
 
+    ggml_allocr_alloc_graph(alloc, get_graph());
+}
 
+static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
+    if (allocr.alloc == nullptr) {
+        // this can be null if we use external encoder like CoreML or OpenVINO
+        return;
+    }
 
+    auto & alloc  = allocr.alloc;
+    auto & buffer = allocr.buffer;
 
+    size_t size = ggml_allocr_max_size(alloc);
 
+    ggml_allocr_free(alloc);
+
+    buffer = ggml_backend_alloc_buffer(backend, size);
+    alloc  = ggml_allocr_new_from_buffer(buffer);
 }
 
 static void whisper_allocr_free(struct whisper_allocr & allocr) {
     if (allocr.alloc) {
         ggml_allocr_free(allocr.alloc);
+        ggml_backend_buffer_free(allocr.buffer);
         allocr.alloc = nullptr;
     }
 }
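Note: whisper_allocr_graph_init()/whisper_allocr_graph_realloc() implement the standard two-phase ggml-alloc pattern: a measure pass records the worst-case size of the graph's intermediate tensors, then the measure allocator is replaced with one backed by a real backend buffer of exactly that size. A sketch of the lifecycle (build_graph is a hypothetical builder callback, not part of the patch):

    whisper_allocr allocr;

    // 1) measure: dry-run the allocation to learn the peak size
    whisper_allocr_graph_init(allocr, backend, [&]() { return build_graph(); });

    // 2) realloc: back the allocator with a real buffer of that size
    whisper_allocr_graph_realloc(allocr, backend);

    // per evaluation: reset, rebuild, replay the allocation, then compute
    ggml_allocr_reset(allocr.alloc);
    struct ggml_cgraph * gf = build_graph();
    ggml_allocr_alloc_graph(allocr.alloc, gf);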
...
     // buffer for swapping KV caches between decoders during beam-search
     std::vector<kv_buf> kv_swap_bufs;
 
+    ggml_backend_t backend = nullptr;
 
     // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers
...
     struct ggml_tensor * embd_conv = nullptr;
     struct ggml_tensor * embd_enc  = nullptr;
 
+    // helper for GPU offloading
+    std::vector<float> inp_mel;
+
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
 
...
     int lang_id = 0; // english by default
 
     std::string path_model; // populated by whisper_init_from_file_with_params()
+
 #ifdef WHISPER_USE_COREML
     whisper_coreml_context * ctx_coreml = nullptr;
 #endif
 
...
 #ifdef WHISPER_USE_OPENVINO
     whisper_openvino_context * ctx_openvino = nullptr;
 #endif
 
     // [EXPERIMENTAL] token-level timestamps data
+    int64_t t_beg  = 0;
     int64_t t_last = 0;
+
     whisper_token tid_last;
+
     std::vector<float> energy; // PCM signal energy
 
     // [EXPERIMENTAL] speed-up techniques
...
     ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 / FP16 / QX)
     ggml_type itype = ggml_type::GGML_TYPE_F16; // intermediate type (FP32 or FP16)
 
+    whisper_context_params params;
+
     whisper_model model;
     whisper_vocab vocab;
+
     whisper_state * state = nullptr;
 
+    ggml_backend_t backend = nullptr;
+
     std::string path_model; // populated by whisper_init_from_file_with_params()
 };
 
+struct whisper_global {
+    // We save the log callback globally
+    ggml_log_callback log_callback = whisper_log_callback_default;
+    void * log_callback_user_data = nullptr;
+};
 
+static whisper_global g_state;
 
 
 template<typename T>
 static void read_safe(whisper_model_loader * loader, T & dest) {
...
 static bool kv_cache_init(
         const struct whisper_hparams & hparams,
              struct whisper_kv_cache & cache,
+                       ggml_backend_t   backend,
                             ggml_type   wtype,
                                   int   n_ctx) {
     const int64_t n_text_state = hparams.n_text_state;
...
     const int64_t n_mem      = n_text_layer*n_ctx;
     const int64_t n_elements = n_text_state*n_mem;
 
     struct ggml_init_params params = {
+        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
     };
 
     cache.ctx = ggml_init(params);
 
     if (!cache.ctx) {
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
         return false;
     }
 
     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
 
+    const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v);
+
+    cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes);
+
+    // allocate the tensors into the backend buffer
+    {
+        ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer);
+
+        ggml_allocr_alloc(alloc, cache.k);
+        ggml_allocr_alloc(alloc, cache.v);
+
+        ggml_allocr_free(alloc);
+    }
+
     return true;
 }
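Note: since cache.k/cache.v now live in a backend buffer (possibly GPU memory), host code can no longer poke tensor->data directly; it has to round-trip through the backend API, exactly as the KV-swap changes later in this patch do. A minimal sketch:

    // device -> host -> device round-trip for a backend-resident tensor
    std::vector<uint8_t> tmp(ggml_nbytes(cache.k));
    ggml_backend_tensor_get(cache.k, tmp.data(), 0, tmp.size()); // read to host
    // ... inspect or modify tmp on the host ...
    ggml_backend_tensor_set(cache.k, tmp.data(), 0, tmp.size()); // write back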
 
+// TODO: remove after batched decoding
+static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t backend) {
     WHISPER_ASSERT(cache.ctx);
 
     const int n_elements = ggml_nelements(cache.k);
...
     const ggml_type wtype = cache.k->type;
     WHISPER_ASSERT(wtype == cache.v->type);
 
     struct ggml_init_params params = {
+        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
     };
 
     cache.ctx = ggml_init(params);
 
     if (!cache.ctx) {
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
         return false;
     }
 
     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
 
+    const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v);
+
+    cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes);
+
+    // allocate the tensors into the backend buffer
+    {
+        ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer);
+
+        ggml_allocr_alloc(alloc, cache.k);
+        ggml_allocr_alloc(alloc, cache.v);
+
+        ggml_allocr_free(alloc);
+    }
+
     return true;
 }
 
 static void kv_cache_free(struct whisper_kv_cache & cache) {
     if (cache.ctx) {
         ggml_free(cache.ctx);
+        ggml_backend_buffer_free(cache.buffer);
         cache.ctx = nullptr;
     }
 }
 
+static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
+    ggml_backend_t backend_gpu = NULL;
+
+    // initialize the backends
+#ifdef GGML_USE_CUBLAS
+    if (params.use_gpu) {
+        WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
+        backend_gpu = ggml_backend_cuda_init();
+        if (!backend_gpu) {
+            WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
+        }
+    }
+#endif
+
+#ifdef GGML_USE_METAL
+    if (params.use_gpu) {
+        WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
+        ggml_metal_log_set_callback(whisper_log_callback_default, nullptr);
+        backend_gpu = ggml_backend_metal_init();
+        if (!backend_gpu) {
+            WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
+        }
+    }
+#endif
+
+    if (backend_gpu) {
+        return backend_gpu;
+    }
+    return ggml_backend_cpu_init();
+}
+
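Note: whisper_backend_init() prefers CUDA or Metal when compiled in and params.use_gpu is set, and always falls back to the CPU backend (including when GPU initialization fails). A caller can force CPU execution at runtime; a sketch, assuming the whisper_context_default_params() helper from whisper.h (model path illustrative):

    struct whisper_context_params cparams = whisper_context_default_params();
    cparams.use_gpu = false; // whisper_backend_init() will return the CPU backend

    struct whisper_context * wctx =
        whisper_init_from_file_with_params("models/ggml-base.en.bin", cparams);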
 // load the model from a ggml file
 //
 // file format:
...
 // see the convert-pt-to-ggml.py script for details
 //
 static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
+    WHISPER_LOG_INFO("%s: loading model\n", __func__);
 
     const int64_t t_start_us = ggml_time_us();
...
         uint32_t magic;
         read_safe(loader, magic);
         if (magic != GGML_FILE_MAGIC) {
+            WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
             return false;
         }
     }
...
         // in order to save memory and also to speed up the computation
         wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
         if (wctx.wtype == GGML_TYPE_COUNT) {
+            WHISPER_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
             return false;
         }
 
+        WHISPER_LOG_INFO("%s: n_vocab       = %d\n", __func__, hparams.n_vocab);
+        WHISPER_LOG_INFO("%s: n_audio_ctx   = %d\n", __func__, hparams.n_audio_ctx);
+        WHISPER_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
+        WHISPER_LOG_INFO("%s: n_audio_head  = %d\n", __func__, hparams.n_audio_head);
+        WHISPER_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
+        WHISPER_LOG_INFO("%s: n_text_ctx    = %d\n", __func__, hparams.n_text_ctx);
+        WHISPER_LOG_INFO("%s: n_text_state  = %d\n", __func__, hparams.n_text_state);
+        WHISPER_LOG_INFO("%s: n_text_head   = %d\n", __func__, hparams.n_text_head);
+        WHISPER_LOG_INFO("%s: n_text_layer  = %d\n", __func__, hparams.n_text_layer);
+        WHISPER_LOG_INFO("%s: n_mels        = %d\n", __func__, hparams.n_mels);
+        WHISPER_LOG_INFO("%s: ftype         = %d\n", __func__, model.hparams.ftype);
+        WHISPER_LOG_INFO("%s: qntvr         = %d\n", __func__, qntvr);
+        WHISPER_LOG_INFO("%s: type          = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str());
 }
 
     // load mel filters
...
         read_safe(loader, n_vocab);
 
         //if (n_vocab != model.hparams.n_vocab) {
+        //    WHISPER_LOG_ERROR("%s: invalid model file '%s' (bad vocab size %d != %d)\n",
         //            __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
         //    return false;
         //}
...
                 word.assign(&tmp[0], tmp.size());
             } else {
                 // seems like we have an empty-string token in multi-language models (i = 50256)
+                //WHISPER_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
                 word = "";
             }
...
         }
 
         if (n_vocab < model.hparams.n_vocab) {
+            WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
             for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
                 if (i > vocab.token_beg) {
                     word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
...
             }
         }
 
+        WHISPER_LOG_INFO("%s: n_langs       = %d\n", __func__, vocab.num_languages());
     }
 
     const ggml_type wtype = wctx.wtype;
     const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type
 
+    // create the ggml context
     {
         const auto & hparams = model.hparams;
 
         const int n_audio_layer = hparams.n_audio_layer;
+        const int n_text_layer  = hparams.n_text_layer;
 
+        const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
 
         struct ggml_init_params params = {
+            /*.mem_size   =*/ n_tensors*ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
         };
 
         model.ctx = ggml_init(params);
         if (!model.ctx) {
+            WHISPER_LOG_ERROR("%s: ggml_init() failed\n", __func__);
             return false;
         }
     }
 
+    // prepare tensors for the weights
     {
         auto & ctx = model.ctx;
 
...
 
         // encoder
         {
+            model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
 
+            model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype,         3, n_mels,        n_audio_state);
+            model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,    2*n_audio_ctx, n_audio_state);
 
+            model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype,         3, n_audio_state, n_audio_state);
+            model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,      n_audio_ctx, n_audio_state);
 
+            model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
+            model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
 
             // map by name
             model.tensors["encoder.positional_embedding"] = model.e_pe;
...
         }
     }
 
+    wctx.backend = whisper_backend_init(wctx.params);
+
+    {
+        size_t size_main = 0;
+
+        for (const auto & t : model.tensors) {
+            size_main += ggml_nbytes(t.second) + ggml_tensor_overhead();
+        }
+
+        model.buffer = ggml_backend_alloc_buffer(wctx.backend, size_main);
+
+        WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1024.0 / 1024.0);
+    }
+
+    ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer);
+
+    // allocate tensors in the backend buffers
+    {
+        for (const auto & t : model.tensors) {
+            ggml_allocr_alloc(alloc, t.second);
+        }
+    }
+
     // load weights
     {
         size_t total_size = 0;
 
         model.n_loaded = 0;
 
+        std::vector<char> read_buf;
+
         while (true) {
             int32_t n_dims;
             int32_t length;
...
             name.assign(&tmp[0], tmp.size());
 
             if (model.tensors.find(name) == model.tensors.end()) {
+                WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
                 return false;
             }
 
             auto tensor = model.tensors[name.data()];
 
+            const bool is_conv_bias = (name == "encoder.conv1.bias" || name == "encoder.conv2.bias");
 
+            if (!is_conv_bias) {
+                if (ggml_nelements(tensor) != nelements) {
+                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                    WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
+                            __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
+                    return false;
+                }
 
+                if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
+                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
+                            __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
+                    return false;
+                }
+
+                const size_t bpe = ggml_type_size(ggml_type(ttype));
+
+                if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                            __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+                    return false;
+                }
             }
 
+            ggml_backend_t backend = wctx.backend;
+
+            //printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
+
+            if ((ggml_backend_is_cpu(backend)
+#ifdef GGML_USE_METAL
+                || ggml_backend_is_metal(backend)
+#endif
+                ) && !is_conv_bias) {
+                // for the CPU and Metal backend, we can read directly into the tensor
+                loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
+                BYTESWAP_TENSOR(tensor);
+            } else {
+                // read into a temporary buffer first, then copy to device memory
+                read_buf.resize(ggml_nbytes(tensor));
+
+                // we repeat the 2 bias tensors along dim 0:
+                // [1, 512] -> [3000, 512] (conv1.bias)
+                // [1, 512] -> [1500, 512] (conv2.bias)
+                if (is_conv_bias) {
+                    loader->read(loader->context, read_buf.data(), read_buf.size() / tensor->ne[0]);
+
+                    float * data_f32 = (float *) read_buf.data();
+                    for (int64_t y = 0; y < tensor->ne[1]; ++y) {
+                        const int64_t yy = tensor->ne[1] - y - 1;
+                        const float val = data_f32[yy];
+
+                        for (int64_t x = 0; x < tensor->ne[0]; ++x) {
+                            data_f32[yy*tensor->ne[0] + x] = val;
+                        }
+                    }
+                } else {
+                    loader->read(loader->context, read_buf.data(), read_buf.size());
+                }
+
+                ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
+            }
 
             //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1024.0/1024.0);
             total_size += ggml_nbytes(tensor);
             model.n_loaded++;
         }
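Note: the conv-bias expansion above runs back-to-front (yy = ne[1] - y - 1) because it broadcasts in place: the ne[1] packed source values sit at the front of read_buf, so filling rows from the end ensures no packed value is overwritten before it is read. The same trick in isolation (a standalone sketch, not part of the patch):

    // broadcast n_rows packed floats at the front of buf into an
    // n_rows x n_cols matrix, in place; iterate rows backwards so the
    // packed inputs survive until they are consumed
    static void expand_rows_in_place(float * buf, int64_t n_rows, int64_t n_cols) {
        for (int64_t y = n_rows - 1; y >= 0; --y) {
            const float val = buf[y];          // packed source value
            for (int64_t x = 0; x < n_cols; ++x) {
                buf[y*n_cols + x] = val;       // fill the whole row
            }
        }
    }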
 
+        WHISPER_LOG_INFO("%s: model size    = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
 
         if (model.n_loaded == 0) {
+            WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
         } else if (model.n_loaded != (int) model.tensors.size()) {
+            WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
             return false;
         }
     }
 
+    ggml_allocr_free(alloc);
+
     wctx.t_load_us = ggml_time_us() - t_start_us;
 
     return true;
 
     if (!ggml_allocr_is_measure(alloc)) {
         assert(mel_inp.n_mel == n_mels);
 
+        wstate.inp_mel.resize(ggml_nelements(mel));
+
+        float * dst = wstate.inp_mel.data();
         memset(dst, 0, ggml_nbytes(mel));
 
+        const int i0 = std::min(mel_offset, mel_inp.n_len);
         const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
 
         for (int j = 0; j < mel_inp.n_mel; ++j) {
...
                 dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
             }
         }
+
+        ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
     }
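Note: the staged input window is 2*n_ctx mel frames starting at mel_offset, zero-padded past the end of the spectrogram. With the standard Whisper hyperparameters (n_audio_ctx = 1500, 10 ms hop) that is 3000 frames, i.e. the usual 30-second chunk; the ggml_backend_tensor_set() call then uploads the wstate.inp_mel staging buffer into the (possibly GPU-resident) mel tensor, replacing the direct tensor->data writes of the old code.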
 
     struct ggml_tensor * cur = nullptr;
 
     // convolution + gelu
     {
         cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
+        cur = ggml_add(ctx0, cur, model.e_conv_1_b);
+        //cur = ggml_add(ctx0,
+        //        ggml_repeat(ctx0,
+        //            model.e_conv_1_b,
+        //            cur),
+        //        cur);
 
         cur = ggml_gelu(ctx0, cur);
 
         cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
+        cur = ggml_add(ctx0, cur, model.e_conv_2_b);
+        //cur = ggml_add(ctx0,
+        //        ggml_repeat(ctx0,
+        //            model.e_conv_2_b,
+        //            cur),
+        //        cur);
 
         cur = ggml_gelu(ctx0, cur);
     }
 
+    ggml_set_name(cur, "embd_conv");
     wstate.embd_conv = cur;
 } else {
 #ifdef WHISPER_USE_COREML
...
     }
 #endif
 
+    ggml_set_name(cur, "embd_enc");
     wstate.embd_enc = cur;
 }
 
 
     ggml_allocr * alloc = wstate.alloc_encode.alloc;
 
+    //struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_state);
+    //ggml_allocr_alloc(alloc, cur);
+
+    //if (!ggml_allocr_is_measure(alloc)) {
+    //    ggml_backend_tensor_copy(wstate.embd_conv, cur);
+    //}
+    struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
+
     struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
     ggml_allocr_alloc(alloc, KQscale);
 
     if (!ggml_allocr_is_measure(alloc)) {
+        const float val = 1.0f/sqrtf(float(n_state)/n_head);
+        ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
     }
 
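Note: KQscale holds 1/sqrt(n_state/n_head), the usual attention scaling by the per-head dimension. For example, for the base model (n_audio_state = 512, n_audio_head = 8) the head dimension is 64 and the scale is 1/8 = 0.125. Because the tensor is placed by ggml-alloc in backend memory, the scalar is written with ggml_backend_tensor_set() rather than by touching tensor data on the host.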
     // ===================================================================
     // NOTE: experimenting with partial evaluation of the encoder (ignore)
     //static int iter = -1;
...
     const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
 
     struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
 
     cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
 
     // ===================================================================
 
 
     ggml_allocr * alloc = wstate.alloc_cross.alloc;
 
+    //struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+    //ggml_allocr_alloc(alloc, cur);
+
+    //if (!ggml_allocr_is_measure(alloc)) {
+    //    ggml_backend_tensor_copy(wstate.embd_enc, cur);
+    //}
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
 
     struct ggml_tensor * Kscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
     ggml_allocr_alloc(alloc, Kscale);
 
     if (!ggml_allocr_is_measure(alloc)) {
+        const float val = pow(float(n_state) / n_head, -0.25);
+        ggml_backend_tensor_set(Kscale, &val, 0, sizeof(float));
     }
 
     for (int il = 0; il < model.hparams.n_text_layer; ++il) {
...
         ggml_allocr_alloc_graph(alloc, gf);
 
         if (!whisper_encode_external(wstate)) {
+            ggml_graph_compute_helper(wstate.backend, gf, n_threads);
         }
     }
...
 
         ggml_allocr_alloc_graph(alloc, gf);
 
+        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
     // cross
...
 
         ggml_allocr_alloc_graph(alloc, gf);
 
+        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
     wstate.t_encode_us += ggml_time_us() - t_start_us;
     wstate.n_encode++;
 
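Note: conv, encoder and cross are now three separate graphs, each evaluated with the same pattern: replay the allocation into the preallocated backend buffer with ggml_allocr_alloc_graph(), then hand the graph to ggml_graph_compute_helper(wstate.backend, ...). When an external encoder (CoreML/OpenVINO) supplies the embeddings, the corresponding graph computation is skipped via the whisper_encode_external() guard.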
     ggml_allocr_alloc(alloc, embd);
 
     if (!ggml_allocr_is_measure(alloc)) {
+        ggml_backend_tensor_set(embd, tokens, 0, N*ggml_element_size(embd));
     }
 
     struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
...
 
     if (!ggml_allocr_is_measure(alloc)) {
         for (int i = 0; i < N; ++i) {
+            const int32_t val = n_past + i;
+            ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
         }
     }
...
     ggml_allocr_alloc(alloc, KQscale);
 
     if (!ggml_allocr_is_measure(alloc)) {
+        const float val = pow(float(n_state)/n_head, -0.25);
+        ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
     }
 
     // token encoding + position encoding
 
     logits = gf->nodes[gf->n_nodes - 1];
 
+    ggml_graph_compute_helper(wstate.backend, gf, n_threads);
 }
 
     // extract logits for all N tokens
     //logits_out.resize(n_tokens*n_vocab);
     //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
+    //ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab);
 
     // extract logits only for the last token
     logits_out.resize(n_vocab);
+    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
+    ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
 
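Note: logits are now read back from backend memory with ggml_backend_tensor_get(); the offset/size pair selects a row of the output. If the graph produced logits for all n_tokens (as the commented-out variant assumes), the row for token i could be read like this (a sketch under that assumption):

    // hypothetical: copy the logits row of token i to the host
    std::vector<float> row(n_vocab);
    ggml_backend_tensor_get(logits, row.data(),
            (size_t) i*n_vocab*sizeof(float),  // byte offset of row i
            n_vocab*sizeof(float));            // one row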
     if (n_tokens > 1) {
         //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
...
                 --j;
             }
             if (!found) {
+                WHISPER_LOG_ERROR("unknown token\n");
                 ++i;
             }
         }
 
 struct whisper_state * whisper_init_state(whisper_context * ctx) {
     fill_sin_cos_table();
+
     whisper_state * state = new whisper_state;
 
+    state->backend = whisper_backend_init(ctx->params);
+
+    if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
+        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         delete state;
         return nullptr;
     }
 
     {
         const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
+        WHISPER_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
+    if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
+        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
         delete state;
         return nullptr;
     }
 
     {
         const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v);
+        WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
 #ifdef WHISPER_USE_COREML
     const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
 
+    WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
+    WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
 
     state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
     if (!state->ctx_coreml) {
+        WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
 #ifndef WHISPER_COREML_ALLOW_FALLBACK
         delete state;
         return nullptr;
 #endif
     } else {
+        WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__);
     }
 #endif
 
 
     // conv allocator
     {
+        whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
                 [&]() {
                     return whisper_build_graph_conv(*ctx, *state, 0);
                 });
 
+        WHISPER_LOG_INFO("%s: compute buffer (conv)   = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0);
     }
 
     // encoder allocator
     if (!whisper_encode_external(*state)) {
+        whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
                 [&]() {
                     return whisper_build_graph_encoder(*ctx, *state);
                 });
 
+        WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0);
     }
 
     // cross allocator
     {
+        whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
                 [&]() {
                     return whisper_build_graph_cross(*ctx, *state);
                 });
 
+        WHISPER_LOG_INFO("%s: compute buffer (cross)  = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0);
     }
 
     // decoder allocator
     {
+        whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
                 [&]() {
                     const auto & hparams = ctx->model.hparams;
 
...
                     return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
                 });
 
+        WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
     }
 
+    whisper_allocr_graph_realloc(state->alloc_conv,   ctx->backend);
+    whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend);
+    whisper_allocr_graph_realloc(state->alloc_cross,  ctx->backend);
+    whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
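Note: all four graphs are measured first, and only then is each allocator re-backed with a real backend buffer; for a GPU backend, these whisper_allocr_graph_realloc() calls are what actually reserve device memory for the compute buffers whose sizes are logged above.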
 
     state->rng = std::mt19937(0);
 
...
     return 1;
 #else
     if (!model_path && ctx->path_model.empty()) {
+        WHISPER_LOG_ERROR("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
         return 1;
     }
 
...
         path_cache = cache_dir;
     }
 
+    WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
+    WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
 
     ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
     if (!ctx->state->ctx_openvino) {
+        WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
         return 1;
     } else {
+        WHISPER_LOG_INFO("%s: OpenVINO model loaded\n", __func__);
     }
 
     return 0;
 }
 
 struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
+    WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
 
     auto fin = std::ifstream(path_model, std::ios::binary);
     if (!fin) {
+        WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
         return nullptr;
     }
 
...
 
     buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
 
+    WHISPER_LOG_INFO("%s: loading model from buffer\n", __func__);
 
     whisper_model_loader loader = {};
 
...
 
     if (!whisper_model_load(loader, *ctx)) {
         loader->close(loader->context);
+        WHISPER_LOG_ERROR("%s: failed to load model\n", __func__);
         delete ctx;
         return nullptr;
     }
...
     }
 #endif
 
 #ifdef WHISPER_USE_OPENVINO
     if (state->ctx_openvino != nullptr) {
         whisper_openvino_free(state->ctx_openvino);
...
 #endif
 
     whisper_allocr_free(state->alloc_conv);
     whisper_allocr_free(state->alloc_encode);
+    whisper_allocr_free(state->alloc_cross);
+    whisper_allocr_free(state->alloc_decode);
+
+    ggml_backend_free(state->backend);
 
     delete state;
 }
...
     if (ctx->model.ctx) {
         ggml_free(ctx->model.ctx);
     }
+
+    if (ctx->model.buffer) {
+        ggml_backend_buffer_free(ctx->model.buffer);
     }
 
     whisper_free_state(ctx->state);
 
+    ggml_backend_free(ctx->backend);
+
     delete ctx;
 }
 }
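Note: with the model weights in a shared read-only buffer and each whisper_state owning its own backend instance, KV caches and compute buffers, several states can serve one context. A usage sketch (model path illustrative, error handling omitted):

    struct whisper_context_params cparams = whisper_context_default_params();

    struct whisper_context * wctx =
        whisper_init_from_file_with_params_no_state("models/ggml-base.en.bin", cparams);

    // each state gets its own backend, KV caches and compute buffers,
    // so the two can process different audio independently
    struct whisper_state * s1 = whisper_init_state(wctx);
    struct whisper_state * s2 = whisper_init_state(wctx);

    // ... whisper_full_with_state(wctx, s1, ...); whisper_full_with_state(wctx, s2, ...); ...

    whisper_free_state(s1);
    whisper_free_state(s2);
    whisper_free(wctx);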
 
 int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
     if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
+        WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
 
...
 // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
 int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
     if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
+        WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
         return -1;
     }
 
...
         int n_len,
         int n_mel) {
     if (n_mel != ctx->model.filters.n_mel) {
+        WHISPER_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
         return -1;
     }
 
 
 int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
     if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
+        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return -1;
     }
 
...
 int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
     if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
+        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return -1;
     }
 
...
     const int selected_decoder_id = 0;
 
     if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
+        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
 
...
     const int selected_decoder_id = 0;
 
     if (ctx->state == nullptr) {
+        WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
         return false;
     }
 
     if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
+        WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
 
...
     const auto res = tokenize(ctx->vocab, text);
 
     if (n_max_tokens < (int) res.size()) {
+        WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
         return -1;
     }
 
...
             }
         }
 
+        WHISPER_LOG_ERROR("%s: unknown language '%s'\n", __func__, lang);
         return -1;
     }
     return g_lang.at(lang).first;
 
...
         }
     }
 
+    WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
     return nullptr;
 }
 
     const int seek = offset_ms/10;
 
     if (seek < 0) {
+        WHISPER_LOG_ERROR("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
         return -1;
     }
 
     if (seek >= state->mel.n_len_org) {
+        WHISPER_LOG_ERROR("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
         return -2;
     }
 
     // run the encoder
     if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) {
+        WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
         return -6;
     }
 
     const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
 
     if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
+        WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
         return -7;
     }
 
 void whisper_print_timings(struct whisper_context * ctx) {
     const int64_t t_end_us = ggml_time_us();
 
+    WHISPER_LOG_INFO("\n");
+    WHISPER_LOG_INFO("%s:     load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
     if (ctx->state != nullptr) {
 
         const int32_t n_sample = std::max(1, ctx->state->n_sample);
...
         const int32_t n_decode = std::max(1, ctx->state->n_decode);
         const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
 
+        WHISPER_LOG_INFO("%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
+        WHISPER_LOG_INFO("%s:      mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
+        WHISPER_LOG_INFO("%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
+        WHISPER_LOG_INFO("%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
+        WHISPER_LOG_INFO("%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+        WHISPER_LOG_INFO("%s:   prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
     }
+    WHISPER_LOG_INFO("%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
 }
 
 void whisper_reset_timings(struct whisper_context * ctx) {
...
     s += "SSE3 = "     + std::to_string(ggml_cpu_has_sse3())    + " | ";
     s += "SSSE3 = "    + std::to_string(ggml_cpu_has_ssse3())   + " | ";
     s += "VSX = "      + std::to_string(ggml_cpu_has_vsx())     + " | ";
+    s += "CUDA = "     + std::to_string(ggml_cpu_has_cublas())  + " | ";
     s += "COREML = "   + std::to_string(whisper_has_coreml())   + " | ";
     s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
 
...
     const bool last_was_timestamp        = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
     const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;
 
+    //WHISPER_LOG_INFO("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);
 
     if (last_was_timestamp) {
         if (penultimate_was_timestamp) {
...
 
     const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);
 
+    //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
 
     if (timestamp_logprob > max_text_token_logprob) {
         for (int i = 0; i < vocab.token_beg; ++i) {
...
     for (auto & i : two_copy) {
         // make a copy of KV caches
         WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
+        //memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
+        //memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
+        ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size());
+        ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size());
     }
 
     // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
...
         if (two_copy.find(view[i]) != two_copy.end()) {
             // modify KV caches of decoder using data from kv_swap_bufs
             WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
+            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
+            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+            ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
+            ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
         } else {
             // modify KV caches of decoder using data from correspond decoder KV caches directly
             WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
+            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
+            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
+            ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
+            ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
         }
     }
 
...
         if (two_copy.find(view[i]) != two_copy.end()) {
             // modify KV caches of decoder using data from kv_swap_bufs
             WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
+            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
+            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
+            ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
+            ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
         } else {
             // modify KV caches of decoder using data from correspond decoder KV caches directly
             WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
+            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
+            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
+            ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
+            ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
         }
     }
 
     // compute log mel spectrogram
     if (params.speed_up) {
         // TODO: Replace PV with more advanced algorithm
+        WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
         return -1;
     } else {
         if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
+            WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
             return -2;
         }
     }
...
 
     const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
     if (lang_id < 0) {
+        WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__);
         return -3;
     }
     state->lang_id = lang_id;
     params.language = whisper_lang_str(lang_id);
 
+    WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
     if (params.detect_language) {
         return 0;
     }
 
     if (decoder.kv_self.ctx == nullptr) {
         decoder.kv_self = state->decoders[0].kv_self;
+        if (!kv_cache_reinit(decoder.kv_self, ctx->backend)) {
+            WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
             return -4;
         }
...
         decoder.probs.resize   (ctx->vocab.n_vocab);
         decoder.logits.resize  (ctx->vocab.n_vocab);
         decoder.logprobs.resize(ctx->vocab.n_vocab);
...
     }
 }
 
     // overwrite audio_ctx, max allowed is hparams.n_audio_ctx
     if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
+        WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
         return -5;
     }
     state->exp_n_audio_ctx = params.audio_ctx;
...
     // distilled models require the "no_timestamps" token
     // TODO: add input parameter (#1229)
     if (is_distil) {
+        WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
         prompt_init.push_back(whisper_token_not(ctx));
     }
 }
...
     if (params.encoder_begin_callback) {
         if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
+            WHISPER_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__);
             break;
         }
     }
 
     // encode audio features starting at offset seek
     if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+        WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
         return -6;
     }
...
     WHISPER_PRINT_DEBUG("\n\n");
 
     if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+        WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
         return -7;
     }
...
     for (int j = 1; j < n_decoders_cur; ++j) {
         auto & decoder = state->decoders[j];
 
+        // TODO: fix CUDA
+        //memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
+        //memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
+        ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k);
+        ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v);
 
         decoder.kv_self.n += prompt.size();
...
     //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
 
     if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+        WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
         return -8;
     }
...
     ctx->state->t_decode_us /= n_processors;
 
     // print information about the audio boundaries
+    WHISPER_LOG_WARN("\n");
+    WHISPER_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
     for (int i = 0; i < n_processors - 1; ++i) {
+        WHISPER_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
     }
+    WHISPER_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__);
 
     return ret;
 }
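Note: to_timestamp() takes a time in 10 ms units, hence the factor 100/WHISPER_SAMPLE_RATE in the split computation. As a worked example, splitting 60 s of 16 kHz audio across 2 processors gives n_samples_per_processor = 480000, so the single boundary is reported at 100*480000/16000 = 3000 units, i.e. 00:00:30.000 (plus any offset_t).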
|
|
|
|
| 5499 |     double tsum = 0.0;
| 5500 |
| 5501 |     // heat-up
| 5502 | +   ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
| 5503 |
| 5504 |     for (int i = 0; i < n_max; ++i) {
| 5505 |         const int64_t t0 = ggml_time_us();
| 5506 |
| 5507 | +       ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
| 5508 |
| 5509 |         const int64_t t1 = ggml_time_us();
| 5510 |
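The untimed heat-up call is deliberate: the first graph evaluation pays one-off costs such as thread-pool spin-up and cold caches, so it runs once outside the loop and only the n_max timed iterations contribute to tsum.
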
| 5622 |     const int n_samples = state.energy.size();
| 5623 |
| 5624 |     if (n_samples == 0) {
| 5625 | +       WHISPER_LOG_ERROR("%s: no signal data available\n", __func__);
| 5626 |         return;
| 5627 |     }
| 5628 |
| 5843 |     //}
| 5844 | }
| 5845 |
| 5846 | + void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
| 5847 | +     g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
| 5848 | +     g_state.log_callback_user_data = user_data;
| 5849 | + }
| 5850 | +
| 5851 | + static void whisper_log_internal_v(ggml_log_level level, const char * format, va_list args) {
| 5852 | +     va_list args_copy;
| 5853 | +     va_copy(args_copy, args);
| 5854 | +     char buffer[128];
| 5855 | +     int len = vsnprintf(buffer, 128, format, args);
| 5856 | +     if (len < 128) {
| 5857 | +         g_state.log_callback(level, buffer, g_state.log_callback_user_data);
| 5858 | +     } else {
| 5859 | +         char* buffer2 = new char[len+1];
| 5860 | +         vsnprintf(buffer2, len+1, format, args_copy);
| 5861 | +         buffer2[len] = 0;
| 5862 | +         g_state.log_callback(level, buffer2, g_state.log_callback_user_data);
| 5863 | +         delete[] buffer2;
| 5864 | +     }
| 5865 | +     va_end(args_copy);
| 5866 | + }
| 5867 | +
| 5868 | + static void whisper_log_internal(ggml_log_level level, const char * format, ...) {
| 5869 | +     va_list args;
| 5870 | +     va_start(args, format);
| 5871 | +     whisper_log_internal_v(level, format, args);
| 5872 | +     va_end(args);
| 5873 | + }
| 5874 | +
| 5875 | + static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
| 5876 | +     (void) level;
| 5877 | +     (void) user_data;
| 5878 | +     fputs(text, stderr);
| 5879 | +     fflush(stderr);
| 5880 | }
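Two details of the logging implementation worth noting: vsnprintf consumes the va_list, so args_copy (made with va_copy) is what allows the second formatting pass; and when the message does not fit the 128-byte stack buffer, the return value of the first vsnprintf — the length the full string would have needed — sizes the heap fallback.
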
whisper.h
@@ -1,6 +1,8 @@
 #ifndef WHISPER_H
 #define WHISPER_H
 
+#include "ggml.h"
+
 #include <stddef.h>
 #include <stdint.h>
 #include <stdbool.h>

@@ -110,15 +112,15 @@ extern "C" {
     // Various functions for loading a ggml whisper model.
     // Allocate (almost) all memory needed for the model.
     // Return NULL on failure
-    WHISPER_API struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params);
-    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params);
-    WHISPER_API struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_file_with_params  (const char * path_model,              struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size,    struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_with_params            (struct whisper_model_loader * loader, struct whisper_context_params params);
 
     // These are the same as the above, but the internal state of the context is not allocated automatically
     // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
-    WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);
-    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);
-    WHISPER_API struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state  (const char * path_model,              struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size,    struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_with_params_no_state            (struct whisper_model_loader * loader, struct whisper_context_params params);
 
     WHISPER_DEPRECATED(
         WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model),

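A minimal sketch of the params-based initialization these declarations expose. It assumes the whisper_context_params_default() helper and the use_gpu field from the same header, plus a placeholder model path:

    #include "whisper.h"

    int main(void) {
        struct whisper_context_params cparams = whisper_context_params_default();
        cparams.use_gpu = true; // take advantage of the new CUDA/Metal offloading

        // "models/ggml-base.en.bin" is a placeholder path
        struct whisper_context * ctx = whisper_init_from_file_with_params("models/ggml-base.en.bin", cparams);
        if (ctx == NULL) {
            return 1; // the loaders return NULL on failure, per the comment above
        }

        // ... whisper_full(...), read segments ...

        whisper_free(ctx);
        return 0;
    }
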
@@ -570,8 +572,7 @@ extern "C" {
 
     // Control logging output; default behavior is to print to stderr
 
-
-    WHISPER_API void whisper_set_log_callback(whisper_log_callback callback);
+    WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);
 
 #ifdef __cplusplus
 }
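To close the loop on the new logging API, here is a minimal sketch of routing library output somewhere other than stderr. The callback signature comes from ggml.h (hence the new include at the top of the header); the file handle is hypothetical application code:

    #include "whisper.h"

    #include <cstdio>

    static void log_to_file(enum ggml_log_level level, const char * text, void * user_data) {
        (void) level;
        fputs(text, (FILE *) user_data);
    }

    int main(void) {
        FILE * fp = fopen("whisper.log", "w");
        if (fp == NULL) {
            return 1;
        }

        whisper_log_set(log_to_file, fp); // all WHISPER_LOG_* output now goes to the file
        // ...
        whisper_log_set(NULL, NULL);      // passing NULL restores the default stderr callback

        fclose(fp);
        return 0;
    }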