talk-llama : sync llama.cpp
- examples/talk-llama/llama-arch.cpp +106 -0
- examples/talk-llama/llama-arch.h +5 -0
- examples/talk-llama/llama-batch.cpp +76 -70
- examples/talk-llama/llama-batch.h +24 -18
- examples/talk-llama/llama-chat.cpp +43 -1
- examples/talk-llama/llama-chat.h +2 -0
- examples/talk-llama/llama-context.cpp +182 -108
- examples/talk-llama/llama-context.h +26 -16
- examples/talk-llama/llama-cparams.h +3 -2
- examples/talk-llama/llama-graph.cpp +203 -39
- examples/talk-llama/llama-graph.h +147 -72
- examples/talk-llama/llama-hparams.cpp +40 -0
- examples/talk-llama/llama-hparams.h +10 -2
- examples/talk-llama/llama-kv-cache-unified-iswa.cpp +11 -5
- examples/talk-llama/llama-kv-cache-unified-iswa.h +3 -0
- examples/talk-llama/llama-kv-cache-unified.cpp +698 -302
- examples/talk-llama/llama-kv-cache-unified.h +89 -31
- examples/talk-llama/llama-memory-hybrid.cpp +1 -0
- examples/talk-llama/llama-memory-recurrent.cpp +16 -1
- examples/talk-llama/llama-model.cpp +0 -0
- examples/talk-llama/llama-model.h +3 -4
- examples/talk-llama/llama-quant.cpp +1 -2
- examples/talk-llama/llama-vocab.cpp +363 -8
- examples/talk-llama/llama-vocab.h +2 -0
- examples/talk-llama/llama.h +15 -7
- examples/talk-llama/unicode.cpp +207 -0
- examples/talk-llama/unicode.h +2 -0
examples/talk-llama/llama-arch.cpp
CHANGED
@@ -34,6 +34,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_PHI3,         "phi3"         },
     { LLM_ARCH_PHIMOE,       "phimoe"       },
     { LLM_ARCH_PLAMO,        "plamo"        },
+    { LLM_ARCH_PLAMO2,       "plamo2"       },
     { LLM_ARCH_CODESHELL,    "codeshell"    },
     { LLM_ARCH_ORION,        "orion"        },
     { LLM_ARCH_INTERNLM2,    "internlm2"    },
@@ -67,6 +68,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_JAIS,         "jais"         },
     { LLM_ARCH_NEMOTRON,     "nemotron"     },
     { LLM_ARCH_EXAONE,       "exaone"       },
+    { LLM_ARCH_EXAONE4,      "exaone4"      },
     { LLM_ARCH_RWKV6,        "rwkv6"        },
     { LLM_ARCH_RWKV6QWEN2,   "rwkv6qwen2"   },
     { LLM_ARCH_RWKV7,        "rwkv7"        },
@@ -81,9 +83,11 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DOTS1,        "dots1"        },
     { LLM_ARCH_ARCEE,        "arcee"        },
     { LLM_ARCH_ERNIE4_5,     "ernie4_5"     },
+    { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
     { LLM_ARCH_HUNYUAN_MOE,  "hunyuan-moe"  },
     { LLM_ARCH_SMOLLM3,      "smollm3"      },
     { LLM_ARCH_LFM2,         "lfm2"         },
+    { LLM_ARCH_DREAM,        "dream"        },
     { LLM_ARCH_UNKNOWN,      "(unknown)"    },
 };

@@ -784,6 +788,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP,   "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_PLAMO2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+            { LLM_TENSOR_OUTPUT,         "output" },
+            { LLM_TENSOR_ROPE_FREQS,     "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,       "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_Q_NORM,    "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM,    "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,  "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
+            { LLM_TENSOR_SSM_IN,         "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D,     "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_X,          "blk.%d.ssm_x" },
+            { LLM_TENSOR_SSM_DT,         "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_A,          "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_D,          "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_OUT,        "blk.%d.ssm_out" },
+            { LLM_TENSOR_SSM_DT_NORM,    "blk.%d.ssm_dt_norm" },
+            { LLM_TENSOR_SSM_B_NORM,     "blk.%d.ssm_b_norm" },
+            { LLM_TENSOR_SSM_C_NORM,     "blk.%d.ssm_c_norm" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_POST_NORM,  "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_CODESHELL,
         {
@@ -1477,6 +1511,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP,   "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_EXAONE4,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+            { LLM_TENSOR_OUTPUT,         "output" },
+            { LLM_TENSOR_ROPE_FREQS,     "rope_freqs" },
+            { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,    "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,    "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_GATE,       "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM,  "blk.%d.post_ffw_norm" },
+        }
+    },
     {
         LLM_ARCH_RWKV6,
         {
@@ -1793,6 +1847,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP,   "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_ERNIE4_5_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_SHEXP,  "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP,  "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,    "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
+        },
+    },
     {
         LLM_ARCH_HUNYUAN_MOE,
         {
@@ -1854,6 +1933,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
         }
     },
+    {
+        LLM_ARCH_DREAM,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,  "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT,      "output" },
+            { LLM_TENSOR_ATTN_NORM,   "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,      "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,      "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,      "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,    "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,    "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,    "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,    "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,      "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -2094,6 +2190,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
     switch (arch) {
         case LLM_ARCH_JAMBA:
        case LLM_ARCH_FALCON_H1:
+        case LLM_ARCH_PLAMO2:
         case LLM_ARCH_GRANITE_HYBRID:
         case LLM_ARCH_LFM2:
             return true;
@@ -2101,3 +2198,12 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
             return false;
     }
 }
+
+bool llm_arch_is_diffusion(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_DREAM:
+            return true;
+        default:
+            return false;
+    }
+}
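The new llm_arch_is_diffusion() predicate gives callers one place to ask whether a model generates by iterative denoising instead of left-to-right sampling. Below is a minimal dispatch sketch; the two generate helpers are placeholders for illustration only, not functions from this commit:

#include "llama-arch.h"

// placeholder generators, for illustration only
static void run_diffusion_generate()    { /* denoise a fully masked sequence */ }
static void run_autoregressive_decode() { /* sample one token at a time */ }

// sketch: branch on the new predicate
static void generate(llm_arch arch) {
    if (llm_arch_is_diffusion(arch)) {
        // LLM_ARCH_DREAM is currently the only diffusion architecture
        run_diffusion_generate();
    } else {
        run_autoregressive_decode();
    }
}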
examples/talk-llama/llama-arch.h
CHANGED
@@ -38,6 +38,7 @@ enum llm_arch {
     LLM_ARCH_PHI3,
     LLM_ARCH_PHIMOE,
     LLM_ARCH_PLAMO,
+    LLM_ARCH_PLAMO2,
     LLM_ARCH_CODESHELL,
     LLM_ARCH_ORION,
     LLM_ARCH_INTERNLM2,
@@ -71,6 +72,7 @@ enum llm_arch {
     LLM_ARCH_JAIS,
     LLM_ARCH_NEMOTRON,
     LLM_ARCH_EXAONE,
+    LLM_ARCH_EXAONE4,
     LLM_ARCH_RWKV6,
     LLM_ARCH_RWKV6QWEN2,
     LLM_ARCH_RWKV7,
@@ -85,9 +87,11 @@ enum llm_arch {
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
     LLM_ARCH_ERNIE4_5,
+    LLM_ARCH_ERNIE4_5_MOE,
     LLM_ARCH_HUNYUAN_MOE,
     LLM_ARCH_SMOLLM3,
     LLM_ARCH_LFM2,
+    LLM_ARCH_DREAM,
     LLM_ARCH_UNKNOWN,
 };

@@ -478,3 +482,4 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);

 bool llm_arch_is_recurrent(const llm_arch & arch);
 bool llm_arch_is_hybrid   (const llm_arch & arch);
+bool llm_arch_is_diffusion(const llm_arch & arch);
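Every enumerator added here must stay paired with a string entry in LLM_ARCH_NAMES over in llama-arch.cpp. A self-contained sketch of that table-driven pairing, including a reverse lookup that scans the forward table so the two views cannot drift apart (demo types only, not llama.cpp's actual helpers):

#include <map>
#include <string>

enum llm_arch_demo { ARCH_EXAONE4, ARCH_DREAM, ARCH_UNKNOWN };

// the forward table is the single source of truth
static const std::map<llm_arch_demo, const char *> NAMES = {
    { ARCH_EXAONE4, "exaone4"   },
    { ARCH_DREAM,   "dream"     },
    { ARCH_UNKNOWN, "(unknown)" },
};

// reverse lookup derived from the forward table
static llm_arch_demo arch_from_string(const std::string & name) {
    for (const auto & kv : NAMES) {
        if (name == kv.second) {
            return kv.first;
        }
    }
    return ARCH_UNKNOWN;
}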
examples/talk-llama/llama-batch.cpp
CHANGED
@@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
         const llama_vocab & vocab,
         const llama_memory_i * memory,
         uint32_t n_embd,
+        uint32_t n_seq_max,
         bool output_all) {
     clear();

@@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
     // validate input batch
     //

+    if (n_seq_max > LLAMA_MAX_SEQ) {
+        LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
+        return false;
+    }
+
     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
@@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
     if (batch.seq_id) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
                     return false;
                 }
             }
@@ -86,7 +92,7 @@ bool llama_batch_allocr::init(

     // initialize the starting position for each sequence based on the positions in the memory
     llama_pos p0[LLAMA_MAX_SEQ];
-    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (!memory) {
             // if no memory -> start from 0
             p0[s] = 0;
@@ -143,13 +149,16 @@ bool llama_batch_allocr::init(
     // compute stats
     //

-    this->n_embd = n_embd;
+    this->n_embd    = n_embd;
+    this->n_seq_max = n_seq_max;

     // count the outputs in this batch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         n_outputs += batch.logits[i] != 0;
     }

+    has_cpl = false;
+
     // determine coupled sequences
     // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -189,7 +198,7 @@ bool llama_batch_allocr::init(
         seq_set_map[cur].push_back(i);
     }

-    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
             seq_idx[s] = seq_id_unq.size();
             seq_id_unq.push_back(s);
@@ -201,7 +210,7 @@ bool llama_batch_allocr::init(
         LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);

         llama_ubatch ubatch {
-            /*.equal_seqs   =*/ false,
+            /*.b_equal_seqs =*/ false,
             /*.n_tokens     =*/ (uint32_t) batch.n_tokens,
             /*.n_seq_tokens =*/ (uint32_t) 1,
             /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
@@ -214,6 +223,7 @@ bool llama_batch_allocr::init(
             /*.seq_id_unq   =*/ this->seq_id_unq.data(),
             /*.seq_idx      =*/ this->seq_idx.data(),
             /*.output       =*/ batch.logits,
+            /*.data         =*/ {},
         };

         ubatch_print(ubatch, debug);
@@ -241,7 +251,7 @@ bool llama_batch_allocr::init(
     // consistency checks
     //

-    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_pos[s].empty()) {
             continue;
         }
@@ -284,8 +294,8 @@ bool llama_batch_allocr::init(
     }

     if (memory) {
-        for (uint32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
-            for (uint32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
+        for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
+            for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
                 if (seq_cpl[s0][s1]) {
                     if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                         memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
@@ -316,12 +326,12 @@ bool llama_batch_allocr::init(
     //
     {
         seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
-        for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             cur_seq_set[s].set();
         }

         llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
-        for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             cur_seq_pos[s] = -1;
         }

@@ -357,39 +367,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
     clear();
     split_reset();

-    ubatches.emplace_back();
-
-    auto & ubatch = ubatches.back();
-
-    ubatch.token     .resize(n_tokens);
-    ubatch.embd      .clear();
-    ubatch.pos       .resize(n_tokens);
-    ubatch.n_seq_id  .resize(n_tokens);
-    ubatch.seq_id    .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output    .resize(n_tokens);
+    auto udata = std::make_shared<llama_ubatch::data_t>();
+
+    udata->token     .resize(n_tokens);
+    udata->embd      .clear();
+    udata->pos       .resize(n_tokens);
+    udata->n_seq_id  .resize(n_tokens);
+    udata->seq_id    .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    udata->output    .resize(n_tokens);

     for (uint32_t s = 0; s < n_seqs; ++s) {
-        ubatch.seq_idx[s] = s;
-        ubatch.seq_id_unq.push_back(s);
+        udata->seq_idx[s] = s;
+        udata->seq_id_unq.push_back(s);
     }

     llama_ubatch res {
-        /*.equal_seqs   =*/ true,
+        /*.b_equal_seqs =*/ true,
         /*.n_tokens     =*/ n_tokens,
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ n_seqs,

-        /*.token        =*/ ubatch.token.data(),
+        /*.token        =*/ udata->token.data(),
         /*.embd         =*/ nullptr,
-        /*.pos          =*/ ubatch.pos.data(),
-        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
-        /*.seq_id       =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx      =*/ ubatch.seq_idx.data(),
-        /*.output       =*/ ubatch.output.data(),
+        /*.pos          =*/ udata->pos.data(),
+        /*.n_seq_id     =*/ udata->n_seq_id.data(),
+        /*.seq_id       =*/ udata->seq_id.data(),
+        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
+        /*.seq_idx      =*/ udata->seq_idx.data(),
+        /*.output       =*/ udata->output.data(),
+        /*.data         =*/ std::move(udata),
     };

     return res;
@@ -430,8 +439,6 @@ void llama_batch_allocr::split_reset() {

     used.clear();
     used.resize(get_n_tokens(), false);
-
-    ubatches.clear();
 }

 llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
@@ -646,78 +653,77 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {

     assert(n_tokens%n_seqs == 0);

-    ubatches.emplace_back();
-
-    auto & ubatch = ubatches.back();
+    auto udata = std::make_shared<llama_ubatch::data_t>();

     const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;

     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_cur;

-    ubatch.token     .resize(n_tokens);
-    ubatch.embd      .resize(n_embd_all);
-    ubatch.pos       .resize(n_pos_all);
-    ubatch.n_seq_id  .resize(n_tokens);
-    ubatch.seq_id    .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output    .resize(n_tokens);
+    udata->token     .resize(n_tokens);
+    udata->embd      .resize(n_embd_all);
+    udata->pos       .resize(n_pos_all);
+    udata->n_seq_id  .resize(n_tokens);
+    udata->seq_id    .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    udata->output    .resize(n_tokens);

     seq_set_t seq_set_unq;

     for (size_t i = 0; i < idxs.size(); ++i) {
         if (batch.token) {
-            ubatch.token[i] = batch.token[idxs[i]];
+            udata->token[i] = batch.token[idxs[i]];
         }

         if (batch.embd) {
-            memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
+            memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
         }

         for (int j = 0; j < n_pos_cur; ++j) {
-            ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
+            udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
         }

-        ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
-        ubatch.seq_id[i]   = batch.seq_id[idxs[i]];
-        ubatch.output[i]   = batch.logits[idxs[i]];
+        udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
+        udata->seq_id[i]   = batch.seq_id[idxs[i]];
+        udata->output[i]   = batch.logits[idxs[i]];

-        for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
-            seq_set_unq.set(ubatch.seq_id[i][s]);
+        for (int s = 0; s < udata->n_seq_id[i]; ++s) {
+            seq_set_unq.set(udata->seq_id[i][s]);
         }

-        if (ubatch.output[i]) {
+        if (udata->output[i]) {
             out_ids.push_back(idxs[i]);
         }
     }

-    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
-            ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
-            ubatch.seq_id_unq.push_back(s);
+            udata->seq_idx[s] = udata->seq_id_unq.size();
+            udata->seq_id_unq.push_back(s);
         }
     }

     llama_ubatch res {
-        /*.equal_seqs   =*/ equal_seqs,
+        /*.b_equal_seqs =*/ equal_seqs,
         /*.n_tokens     =*/ n_tokens,
         /*.n_seq_tokens =*/ n_tokens/n_seqs,
         /*.n_seqs       =*/ n_seqs,
-        /*.n_seqs_unq   =*/ (uint32_t) ubatch.seq_id_unq.size(),
-
-        /*.token        =*/ batch.token ? ubatch.token.data() : nullptr,
-        /*.embd         =*/ batch.embd ? ubatch.embd.data() : nullptr,
-        /*.pos          =*/ ubatch.pos.data(),
-        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
-        /*.seq_id       =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx      =*/ ubatch.seq_idx.data(),
-        /*.output       =*/ ubatch.output.data(),
+        /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
+
+        /*.token        =*/ batch.token ? udata->token.data() : nullptr,
+        /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
+        /*.pos          =*/ udata->pos.data(),
+        /*.n_seq_id     =*/ udata->n_seq_id.data(),
+        /*.seq_id       =*/ udata->seq_id.data(),
+        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
+        /*.seq_idx      =*/ udata->seq_idx.data(),
+        /*.output       =*/ udata->output.data(),
+        /*.data         =*/ std::move(udata),
     };

     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1);
+        LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);

         ubatch_print(res, debug);
     }
@@ -727,7 +733,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {

 void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: equal_seqs   = %d\n", __func__, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s: equal_seqs   = %d\n", __func__, ubatch.equal_seqs());
         LLAMA_LOG_DEBUG("%s: n_tokens     = %d\n", __func__, ubatch.n_tokens);
         LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
         LLAMA_LOG_DEBUG("%s: n_seqs       = %d\n", __func__, ubatch.n_seqs);
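The extra n_seq_max parameter replaces the compile-time LLAMA_MAX_SEQ bound with the context's configured sequence limit when sanitizing a batch. A self-contained sketch of the same two-level validation, using simplified stand-in types rather than the real llama_batch layout:

#include <cstdint>
#include <cstdio>
#include <vector>

using seq_id_t = int32_t; // stand-in for llama_seq_id

static bool validate_seq_ids(const std::vector<std::vector<seq_id_t>> & seq_ids,
                             uint32_t n_seq_max, uint32_t hard_max) {
    if (n_seq_max > hard_max) {
        std::fprintf(stderr, "n_seq_max = %u > %u\n", n_seq_max, hard_max);
        return false; // context asks for more sequences than the build supports
    }
    for (size_t i = 0; i < seq_ids.size(); ++i) {
        for (seq_id_t s : seq_ids[i]) {
            if (s < 0 || (uint32_t) s >= n_seq_max) {
                std::fprintf(stderr, "invalid seq_id[%zu] = %d\n", i, s);
                return false; // reject ids outside the runtime bound
            }
        }
    }
    return true;
}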
examples/talk-llama/llama-batch.h
CHANGED
@@ -8,12 +8,17 @@
 #include <vector>
 #include <set>
 #include <bitset>
+#include <memory>
 #include <unordered_map>

 // keep this struct lightweight
-// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
-    bool equal_seqs;
+    bool equal_seqs() const {
+        return b_equal_seqs != 0;
+    }
+
+    uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
+                           //       otherwise address sanitizer complains
     // TODO: whole_seqs for embeddings?

     uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
@@ -34,6 +39,20 @@ struct llama_ubatch {
     llama_seq_id * seq_id_unq; // [n_seqs_unq]    | s | seq_id
     int32_t      * seq_idx;    // [LLAMA_MAX_SEQ] | - | seq_idx
     int8_t       * output;     // [n_tokens]      | i | -
+
+    struct data_t {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id>   seq_id_unq;
+        std::vector<int32_t>        seq_idx;
+        std::vector<int8_t>         output;
+    };
+
+    // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
+    std::shared_ptr<data_t> data;
 };

 // a helper for sanitizing, fulfilling and splitting a batch
@@ -48,6 +67,7 @@ public:
             const llama_vocab & vocab,
             const llama_memory_i * memory,
             uint32_t n_embd,
+            uint32_t n_seq_max,
             bool output_all);

     const llama_batch & get_batch() const;
@@ -100,6 +120,7 @@ private:
     const uint32_t n_pos_per_embd;

     uint32_t n_embd;
+    uint32_t n_seq_max;
     uint32_t n_outputs;

     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
@@ -115,7 +136,7 @@ private:
     using seq_cpl_t = std::vector<bool>;

     // helper flag to quickly determine if there are any coupled sequences in the batch
-    bool has_cpl;
+    bool has_cpl = false;

     std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
@@ -135,20 +156,5 @@ private:
     // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;

-    // llama_ubatch points to this data:
-    struct ubatch {
-        std::vector<llama_token>    token;
-        std::vector<float>          embd;
-        std::vector<llama_pos>      pos;
-        std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<llama_seq_id>   seq_id_unq;
-        std::vector<int32_t>        seq_idx;
-        std::vector<int8_t>         output;
-    };
-
-    // current splitting state:
-    std::vector<ubatch> ubatches;
-
     int debug;
 };
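The data_t refactor makes each llama_ubatch carry shared ownership of the storage its raw pointers reference, so a ubatch returned by the splitter stays valid even after the allocator resets its splitting state (previously the vectors lived inside llama_batch_allocr and were cleared on split_reset). A minimal sketch of the idiom with simplified types:

#include <cstdint>
#include <memory>
#include <vector>

// simplified version of the llama_ubatch ownership idiom
struct ubatch_view {
    const int32_t * token; // raw pointer handed to graph-building code

    struct data_t {
        std::vector<int32_t> token;
    };
    std::shared_ptr<data_t> data; // keeps the backing storage alive
};

static ubatch_view make_view(std::vector<int32_t> tokens) {
    auto d = std::make_shared<ubatch_view::data_t>();
    d->token = std::move(tokens);
    // the pointer and its shared owner travel together
    return { d->token.data(), std::move(d) };
}

// copies of ubatch_view share ownership, so the token pointer remains valid
// even after the allocator-side object that produced it is reset or destroyed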
examples/talk-llama/llama-chat.cpp
CHANGED
@@ -56,6 +56,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "glmedge",     LLM_CHAT_TEMPLATE_GLMEDGE     },
     { "minicpm",     LLM_CHAT_TEMPLATE_MINICPM     },
     { "exaone3",     LLM_CHAT_TEMPLATE_EXAONE_3    },
+    { "exaone4",     LLM_CHAT_TEMPLATE_EXAONE_4    },
     { "rwkv-world",  LLM_CHAT_TEMPLATE_RWKV_WORLD  },
     { "granite",     LLM_CHAT_TEMPLATE_GRANITE     },
     { "gigachat",    LLM_CHAT_TEMPLATE_GIGACHAT    },
@@ -65,6 +66,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "llama4",      LLM_CHAT_TEMPLATE_LLAMA4      },
     { "smolvlm",     LLM_CHAT_TEMPLATE_SMOLVLM     },
     { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
+    { "kimi-k2",     LLM_CHAT_TEMPLATE_KIMI_K2     },
 };

 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -167,10 +169,13 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
     } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) {
         return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
     } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
+        if (tmpl_contains("[|tool|]")) {
+            return LLM_CHAT_TEMPLATE_EXAONE_4;
+        }
         // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
         // EXAONE-3.0-7.8B-Instruct
         return LLM_CHAT_TEMPLATE_EXAONE_3;
-    } else if (tmpl_contains("rwkv-world")) {
+    } else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) {
         return LLM_CHAT_TEMPLATE_RWKV_WORLD;
     } else if (tmpl_contains("<|start_of_role|>")) {
         return LLM_CHAT_TEMPLATE_GRANITE;
@@ -188,6 +193,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_DOTS1;
     } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
         return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
+    } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
+        return LLM_CHAT_TEMPLATE_KIMI_K2;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -529,6 +536,22 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "[|assistant|]";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_4) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
+            } else if (role == "user") {
+                ss << "[|user|]" << trim(message->content) << "\n";
+            } else if (role == "assistant") {
+                ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
+            } else if (role == "tool") {
+                ss << "[|tool|]" << trim(message->content) << "[|endofturn|]\n";
+            }
+        }
+        if (add_ass) {
+            ss << "[|assistant|]";
+        }
     } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
         // this template requires the model to have "\n\n" as EOT token
         for (size_t i = 0; i < chat.size(); i++) {
@@ -680,6 +703,25 @@ int32_t llm_chat_apply_template(
             ss << "<|startoftext|>" << message->content << "<|extra_0|>";
         }
     }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) {
+        // moonshotai/Kimi-K2-Instruct
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|im_system|>system<|im_middle|>";
+            } else if (role == "user") {
+                ss << "<|im_user|>user<|im_middle|>";
+            } else if (role == "assistant") {
+                ss << "<|im_assistant|>assistant<|im_middle|>";
+            } else if (role == "tool") {
+                ss << "<|im_system|>tool<|im_middle|>";
+            }
+
+            ss << message->content << "<|im_end|>";
+        }
+        if (add_ass) {
+            ss << "<|im_assistant|>assistant<|im_middle|>";
+        }
     } else {
         // template not supported
         return -1;
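As a usage sketch, here is how the new Kimi-K2 branch renders a short exchange. It assumes the llm_chat_apply_template declaration from llama-chat.h and the llama_chat_message struct from llama.h, as understood from this sync; verify the signatures against the headers:

#include "llama-chat.h"
#include "llama.h"
#include <string>
#include <vector>

int main() {
    llama_chat_message msgs[] = {
        { "system", "You are a helpful assistant." },
        { "user",   "Hello!" },
    };
    std::vector<const llama_chat_message *> chat = { &msgs[0], &msgs[1] };

    std::string prompt;
    llm_chat_apply_template(LLM_CHAT_TEMPLATE_KIMI_K2, chat, prompt, /*add_ass=*/true);

    // expected contents of `prompt` (wrapped here for readability; the
    // actual string contains no newlines between segments):
    //   <|im_system|>system<|im_middle|>You are a helpful assistant.<|im_end|>
    //   <|im_user|>user<|im_middle|>Hello!<|im_end|>
    //   <|im_assistant|>assistant<|im_middle|>
    return 0;
}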
examples/talk-llama/llama-chat.h
CHANGED
@@ -35,6 +35,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_GLMEDGE,
     LLM_CHAT_TEMPLATE_MINICPM,
     LLM_CHAT_TEMPLATE_EXAONE_3,
+    LLM_CHAT_TEMPLATE_EXAONE_4,
     LLM_CHAT_TEMPLATE_RWKV_WORLD,
     LLM_CHAT_TEMPLATE_GRANITE,
     LLM_CHAT_TEMPLATE_GIGACHAT,
@@ -45,6 +46,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_SMOLVLM,
     LLM_CHAT_TEMPLATE_DOTS1,
     LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
+    LLM_CHAT_TEMPLATE_KIMI_K2,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
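Detection of the two new templates keys off distinctive marker strings in the Jinja source: [|tool|] distinguishes EXAONE-4 from EXAONE-3, and <|im_assistant|>assistant<|im_middle|> identifies Kimi-K2. A self-contained sketch of that substring-based detection idiom (demo enum, not the real llm_chat_detect_template):

#include <string>

enum demo_tmpl { TMPL_EXAONE_3, TMPL_EXAONE_4, TMPL_KIMI_K2, TMPL_UNKNOWN };

static demo_tmpl detect(const std::string & tmpl) {
    auto contains = [&](const char * s) { return tmpl.find(s) != std::string::npos; };

    if (contains("[|system|]") && contains("[|assistant|]") && contains("[|endofturn|]")) {
        // EXAONE-4 shares EXAONE-3's markers and adds a tool role
        return contains("[|tool|]") ? TMPL_EXAONE_4 : TMPL_EXAONE_3;
    }
    if (contains("<|im_assistant|>assistant<|im_middle|>")) {
        return TMPL_KIMI_K2;
    }
    return TMPL_UNKNOWN;
}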
examples/talk-llama/llama-context.cpp
CHANGED
|
@@ -98,10 +98,20 @@ llama_context::llama_context(
|
|
| 98 |
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
|
| 99 |
cparams.n_batch = GGML_KQ_MASK_PAD;
|
| 100 |
}
|
| 101 |
-
|
| 102 |
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
|
| 103 |
|
| 104 |
cparams.op_offload = params.op_offload;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
| 107 |
|
|
@@ -112,6 +122,7 @@ llama_context::llama_context(
|
|
| 112 |
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
| 113 |
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
|
| 114 |
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
|
|
|
|
| 115 |
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
| 116 |
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
| 117 |
|
|
@@ -227,8 +238,8 @@ llama_context::llama_context(
|
|
| 227 |
|
| 228 |
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
|
| 229 |
|
| 230 |
-
|
| 231 |
-
|
| 232 |
|
| 233 |
// TODO: move these checks to ggml_backend_sched
|
| 234 |
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
|
|
@@ -267,7 +278,7 @@ llama_context::llama_context(
|
|
| 267 |
|
| 268 |
// reserve worst-case graph
|
| 269 |
if (!hparams.vocab_only && memory) {
|
| 270 |
-
const uint32_t n_seqs = cparams.n_seq_max;
|
| 271 |
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
| 272 |
|
| 273 |
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
@@ -287,7 +298,7 @@ llama_context::llama_context(
|
|
| 287 |
|
| 288 |
cross.v_embd.clear();
|
| 289 |
|
| 290 |
-
// reserve pp graph first so that buffers are only allocated once
|
| 291 |
{
|
| 292 |
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
| 293 |
if (!gf) {
|
|
@@ -298,9 +309,9 @@ llama_context::llama_context(
|
|
| 298 |
n_nodes_pp = ggml_graph_n_nodes(gf);
|
| 299 |
}
|
| 300 |
|
| 301 |
-
// reserve with tg graph to get the number of splits and nodes
|
| 302 |
{
|
| 303 |
-
auto * gf = graph_reserve(
|
| 304 |
if (!gf) {
|
| 305 |
throw std::runtime_error("failed to allocate compute tg buffers");
|
| 306 |
}
|
|
@@ -311,6 +322,10 @@ llama_context::llama_context(
|
|
| 311 |
|
| 312 |
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
| 313 |
{
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
| 315 |
if (!gf) {
|
| 316 |
throw std::runtime_error("failed to allocate compute pp buffers");
|
|
@@ -388,10 +403,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
|
|
| 388 |
return sched.get();
|
| 389 |
}
|
| 390 |
|
| 391 |
-
ggml_context * llama_context::get_ctx_compute() const {
|
| 392 |
-
return ctx_compute.get();
|
| 393 |
-
}
|
| 394 |
-
|
| 395 |
uint32_t llama_context::n_ctx() const {
|
| 396 |
return cparams.n_ctx;
|
| 397 |
}
|
|
@@ -463,6 +474,11 @@ bool llama_context::kv_self_update(bool optimize) {
|
|
| 463 |
}
|
| 464 |
}
|
| 465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
if (!mctx->apply()) {
|
| 467 |
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
|
| 468 |
}
|
|
@@ -475,7 +491,7 @@ bool llama_context::kv_self_update(bool optimize) {
|
|
| 475 |
throw std::runtime_error("failed to initialize memory context");
|
| 476 |
}
|
| 477 |
|
| 478 |
-
const uint32_t n_seqs
|
| 479 |
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
| 480 |
|
| 481 |
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
|
@@ -492,12 +508,16 @@ enum llama_pooling_type llama_context::pooling_type() const {
|
|
| 492 |
}
|
| 493 |
|
| 494 |
float * llama_context::get_logits() {
|
|
|
|
|
|
|
| 495 |
return logits;
|
| 496 |
}
|
| 497 |
|
| 498 |
float * llama_context::get_logits_ith(int32_t i) {
|
| 499 |
int64_t j = -1;
|
| 500 |
|
|
|
|
|
|
|
| 501 |
try {
|
| 502 |
if (logits == nullptr) {
|
| 503 |
throw std::runtime_error("no logits");
|
|
@@ -534,12 +554,16 @@ float * llama_context::get_logits_ith(int32_t i) {
|
|
| 534 |
}
|
| 535 |
|
| 536 |
float * llama_context::get_embeddings() {
|
|
|
|
|
|
|
| 537 |
return embd;
|
| 538 |
}
|
| 539 |
|
| 540 |
float * llama_context::get_embeddings_ith(int32_t i) {
|
| 541 |
int64_t j = -1;
|
| 542 |
|
|
|
|
|
|
|
| 543 |
try {
|
| 544 |
if (embd == nullptr) {
|
| 545 |
throw std::runtime_error("no embeddings");
|
|
@@ -678,38 +702,59 @@ bool llama_context::apply_adapter_cvec(
|
|
| 678 |
return cvec.apply(model, data, len, n_embd, il_start, il_end);
|
| 679 |
}
|
| 680 |
|
| 681 |
-
|
| 682 |
if (mctx && !mctx->apply()) {
|
| 683 |
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
|
| 684 |
ret = GGML_STATUS_FAILED;
|
| 685 |
return nullptr;
|
| 686 |
}
|
| 687 |
|
| 688 |
-
auto *
|
| 689 |
-
|
| 690 |
-
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
|
| 691 |
-
ret = GGML_STATUS_FAILED;
|
| 692 |
-
return nullptr;
|
| 693 |
-
}
|
| 694 |
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
ret = GGML_STATUS_FAILED;
|
| 699 |
-
return nullptr;
|
| 700 |
-
}
|
| 701 |
|
| 702 |
-
|
|
|
|
| 703 |
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 708 |
}
|
| 709 |
|
| 710 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 711 |
|
| 712 |
-
const auto status = graph_compute(
|
| 713 |
if (status != GGML_STATUS_SUCCESS) {
|
| 714 |
LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
|
| 715 |
ret = status;
|
|
@@ -731,16 +776,19 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
| 731 |
|
| 732 |
const auto & hparams = model.hparams;
|
| 733 |
|
| 734 |
-
const int64_t n_embd
|
|
|
|
| 735 |
|
| 736 |
// note: during encode, we always pass the full sequence starting from pos = 0
|
| 737 |
-
if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
|
| 738 |
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
| 739 |
return -1;
|
| 740 |
}
|
| 741 |
|
| 742 |
const uint32_t n_tokens = balloc->get_n_tokens();
|
| 743 |
|
|
|
|
|
|
|
| 744 |
const llama_ubatch ubatch = balloc->split_simple(n_tokens);
|
| 745 |
|
| 746 |
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
|
|
@@ -767,9 +815,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
| 767 |
|
| 768 |
n_outputs = n_tokens;
|
| 769 |
|
| 770 |
-
ggml_backend_sched_reset(sched.get());
|
| 771 |
-
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
| 772 |
-
|
| 773 |
const auto causal_attn_org = cparams.causal_attn;
|
| 774 |
|
| 775 |
// always use non-causal attention for encoder graphs
|
|
@@ -778,7 +823,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
| 778 |
cparams.causal_attn = false;
|
| 779 |
|
| 780 |
ggml_status status;
|
| 781 |
-
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
|
| 782 |
|
| 783 |
cparams.causal_attn = causal_attn_org;
|
| 784 |
|
|
@@ -791,10 +836,20 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
| 791 |
}
|
| 792 |
}
|
| 793 |
|
|
|
|
| 794 |
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
| 795 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 796 |
// extract embeddings
|
| 797 |
-
if (t_embd) {
|
| 798 |
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
| 799 |
GGML_ASSERT(backend_embd != nullptr);
|
| 800 |
|
|
@@ -844,9 +899,11 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|
| 844 |
}
|
| 845 |
}
|
| 846 |
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
|
|
|
|
|
|
| 850 |
|
| 851 |
// TODO: hacky solution
|
| 852 |
if (model.arch == LLM_ARCH_T5 && t_embd) {
|
|
@@ -899,7 +956,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
| 899 |
// when computing embeddings, all tokens are output
|
| 900 |
const bool output_all = cparams.embeddings;
|
| 901 |
|
| 902 |
-
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
|
| 903 |
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
| 904 |
return -1;
|
| 905 |
}
|
|
@@ -927,6 +984,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
| 927 |
|
| 928 |
// TODO: this clear of the buffer can easily be forgotten - need something better
|
| 929 |
embd_seq.clear();
|
|
|
|
| 930 |
|
| 931 |
bool did_optimize = false;
|
| 932 |
|
|
@@ -1005,11 +1063,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
| 1005 |
n_outputs = n_outputs_new;
|
| 1006 |
}
|
| 1007 |
|
| 1008 |
-
ggml_backend_sched_reset(sched.get());
|
| 1009 |
-
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
| 1010 |
-
|
| 1011 |
ggml_status status;
|
| 1012 |
-
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
|
| 1013 |
|
| 1014 |
if (!res) {
|
| 1015 |
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
|
@@ -1149,9 +1204,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
| 1149 |
// make the outputs have the same order they had in the user-provided batch
|
| 1150 |
// note: this is mostly relevant for recurrent models atm
|
| 1151 |
if (!sorted_output) {
|
| 1152 |
-
const uint32_t n_vocab = model.vocab.n_tokens();
|
| 1153 |
-
const uint64_t n_embd = model.hparams.n_embd;
|
| 1154 |
-
|
| 1155 |
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
| 1156 |
|
| 1157 |
// TODO: is there something more efficient which also minimizes swaps?
|
|
@@ -1167,16 +1219,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
| 1167 |
continue;
|
| 1168 |
}
|
| 1169 |
std::swap(out_ids[i], out_ids[j_min]);
|
| 1170 |
-
|
| 1171 |
-
|
| 1172 |
-
|
| 1173 |
-
}
|
| 1174 |
-
}
|
| 1175 |
-
if (embd_size > 0) {
|
| 1176 |
-
for (uint32_t k = 0; k < n_embd; k++) {
|
| 1177 |
-
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
|
| 1178 |
-
}
|
| 1179 |
-
}
|
| 1180 |
}
|
| 1181 |
|
| 1182 |
std::fill(output_ids.begin(), output_ids.end(), -1);
|
|
@@ -1190,9 +1235,11 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|
| 1190 |
// wait for the computation to finish (automatically done when obtaining the model output)
|
| 1191 |
//synchronize();
|
| 1192 |
|
| 1193 |
-
|
| 1194 |
-
|
| 1195 |
-
|
|
|
|
|
|
|
| 1196 |
|
| 1197 |
return 0;
|
| 1198 |
}
|
|
@@ -1271,24 +1318,40 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
| 1271 |
return n_outputs_max;
|
| 1272 |
}
|
| 1273 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1274 |
//
|
| 1275 |
// graph
|
| 1276 |
//
|
| 1277 |
|
| 1278 |
-
|
| 1279 |
-
return std::max<
|
| 1280 |
}
|
| 1281 |
|
| 1282 |
-
|
| 1283 |
-
|
| 1284 |
-
/*.mem_size =*/ buf_compute_meta.size(),
|
| 1285 |
-
/*.mem_buffer =*/ buf_compute_meta.data(),
|
| 1286 |
-
/*.no_alloc =*/ true,
|
| 1287 |
-
};
|
| 1288 |
-
|
| 1289 |
-
ctx_compute.reset(ggml_init(params));
|
| 1290 |
-
|
| 1291 |
-
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
|
| 1292 |
}
|
| 1293 |
|
| 1294 |
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
|
|
@@ -1301,6 +1364,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
|
| 1301 |
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
|
| 1302 |
}
|
| 1303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1304 |
// store the n_outputs as it is, and restore it afterwards
|
| 1305 |
// TODO: not sure if needed, might simplify in the future by removing this
|
| 1306 |
const auto save_n_outputs = this->n_outputs;
|
|
@@ -1310,17 +1378,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
|
| 1310 |
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
|
| 1311 |
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
|
| 1312 |
|
| 1313 |
-
auto *
|
| 1314 |
-
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
|
| 1315 |
|
| 1316 |
-
|
| 1317 |
|
| 1318 |
-
|
| 1319 |
-
LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
|
| 1320 |
-
return nullptr;
|
| 1321 |
-
}
|
| 1322 |
|
| 1323 |
-
|
|
|
|
|
|
|
| 1324 |
|
| 1325 |
// initialize scheduler with the specified graph
|
| 1326 |
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
|
@@ -1331,28 +1397,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
|
| 1331 |
return gf;
|
| 1332 |
}
|
| 1333 |
|
| 1334 |
-
|
| 1335 |
-
|
| 1336 |
-
|
| 1337 |
-
|
| 1338 |
-
|
| 1339 |
-
|
| 1340 |
-
|
| 1341 |
-
|
| 1342 |
-
|
| 1343 |
-
|
| 1344 |
-
|
| 1345 |
-
|
| 1346 |
-
|
| 1347 |
-
|
| 1348 |
-
|
| 1349 |
-
|
| 1350 |
-
|
| 1351 |
-
|
| 1352 |
-
|
| 1353 |
-
|
| 1354 |
-
|
| 1355 |
-
}, gf, gtype);
|
| 1356 |
}
|
| 1357 |
|
| 1358 |
ggml_status llama_context::graph_compute(
|
|
@@ -1930,6 +1995,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
|
|
| 1930 |
data.t_eval_ms = 1e-3 * t_eval_us;
|
| 1931 |
data.n_p_eval = std::max(1, n_p_eval);
|
| 1932 |
data.n_eval = std::max(1, n_eval);
|
|
|
|
| 1933 |
|
| 1934 |
return data;
|
| 1935 |
}
|
|
@@ -1938,6 +2004,7 @@ void llama_context::perf_reset() {
|
|
| 1938 |
t_start_us = ggml_time_us();
|
| 1939 |
t_eval_us = n_eval = 0;
|
| 1940 |
t_p_eval_us = n_p_eval = 0;
|
|
|
|
| 1941 |
}
|
| 1942 |
|
| 1943 |
//
|
|
@@ -2028,7 +2095,7 @@ void llama_context::opt_epoch_iter(
|
|
| 2028 |
batch.logits [pos_batch] = true;
|
| 2029 |
}
|
| 2030 |
|
| 2031 |
-
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
|
| 2032 |
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
| 2033 |
return;
|
| 2034 |
}
|
|
@@ -2064,8 +2131,13 @@ void llama_context::opt_epoch_iter(
|
|
| 2064 |
break;
|
| 2065 |
}
|
| 2066 |
|
| 2067 |
-
auto *
|
| 2068 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2069 |
|
| 2070 |
struct ggml_context * ctx_compute_opt;
|
| 2071 |
{
|
|
@@ -2187,6 +2259,7 @@ llama_context_params llama_context_default_params() {
|
|
| 2187 |
/*.no_perf =*/ true,
|
| 2188 |
/*.op_offload =*/ true,
|
| 2189 |
/*.swa_full =*/ true,
|
|
|
|
| 2190 |
};
|
| 2191 |
|
| 2192 |
return result;
|
|
@@ -2807,6 +2880,7 @@ void llama_perf_context_print(const llama_context * ctx) {
|
|
| 2807 |
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
| 2808 |
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
|
| 2809 |
LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
|
|
|
|
| 2810 |
}
|
| 2811 |
|
| 2812 |
void llama_perf_context_reset(llama_context * ctx) {
|
|
|
|
|  98 |   LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
|  99 |   cparams.n_batch = GGML_KQ_MASK_PAD;
| 100 | }
| 101 | cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
| 102 |
| 103 | cparams.op_offload = params.op_offload;
| 104 | + cparams.kv_unified = params.kv_unified;
| 105 | +
| 106 | + {
| 107 | +     const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
| 108 | +     supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : false;
| 109 | +
| 110 | +     if (!supports_set_rows && !cparams.kv_unified) {
| 111 | +         LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
| 112 | +         cparams.kv_unified = true;
| 113 | +     }
| 114 | + }
| 115 |
| 116 | const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
| 117 |
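The hunk above gates an experimental code path behind an environment variable and silently falls back to a safe configuration when the feature is unavailable. A minimal standalone sketch of the same pattern (names and messages here are illustrative, not part of llama.cpp):

#include <cstdlib>
#include <cstdio>

// Read a boolean feature flag from the environment; unset means "off".
static bool env_flag(const char * name) {
    const char * v = std::getenv(name);
    return v && std::atoi(v) != 0;
}

int main() {
    const bool supports_set_rows = env_flag("LLAMA_SET_ROWS");
    bool kv_unified = false; // requested configuration

    // If the requested configuration depends on the missing feature,
    // fall back now instead of failing later at graph-build time.
    if (!supports_set_rows && !kv_unified) {
        std::fprintf(stderr, "forcing unified KV cache\n");
        kv_unified = true;
    }

    std::printf("kv_unified = %s\n", kv_unified ? "true" : "false");
    return 0;
}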
| 122 | LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
| 123 | LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
| 124 | LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
| 125 | + LLAMA_LOG_INFO("%s: kv_unified    = %s\n",   __func__, cparams.kv_unified ? "true" : "false");
| 126 | LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
| 127 | LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
| 128 |
| 238 |
| 239 | LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
| 240 |
| 241 | + gf_res_prev.reset(new llm_graph_result(max_nodes));
| 242 | + gf_res_reserve.reset(new llm_graph_result(max_nodes));
| 243 |
| 244 | // TODO: move these checks to ggml_backend_sched
| 245 | // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
| 278 |
| 279 | // reserve worst-case graph
| 280 | if (!hparams.vocab_only && memory) {
| 281 | +   const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
| 282 |     const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
| 283 |
| 284 |     LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
| 298 |
| 299 | cross.v_embd.clear();
| 300 |
| 301 | + // reserve pp (prompt processing) graph first so that buffers are only allocated once
| 302 | {
| 303 |     auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
| 304 |     if (!gf) {
| 309 |     n_nodes_pp = ggml_graph_n_nodes(gf);
| 310 | }
| 311 |
| 312 | + // reserve with tg (token generation) graph to get the number of splits and nodes
| 313 | {
| 314 | +   auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
| 315 |     if (!gf) {
| 316 |         throw std::runtime_error("failed to allocate compute tg buffers");
| 317 |     }
| 322 |
| 323 | // reserve again with pp graph to avoid ggml-alloc reallocations during inference
| 324 | {
| 325 | +   // TODO: not sure if the following graph would be the worst case for multi-stream KV caches:
| 326 | +   //
| 327 | +   //   auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
| 328 | +   //
| 329 |     auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
| 330 |     if (!gf) {
| 331 |         throw std::runtime_error("failed to allocate compute pp buffers");
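The three reservations above warm the scheduler in a deliberate order: prompt processing (pp) first so the largest buffers are allocated once, token generation (tg) to record its split/node counts, then pp again so that the allocator state matches the shape used during inference. A quick check of how the worst-case shapes are derived, with all values invented for illustration:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_ctx      = 4096;
    const uint32_t n_ubatch   = 512;
    const uint32_t n_seq_max  = 4;
    const bool     kv_unified = false;

    // a unified KV cache keeps all sequences in one stream, so n_seqs
    // collapses to 1; otherwise each sequence gets its own stream
    const uint32_t n_seqs   = kv_unified ? 1 : n_seq_max;
    const uint32_t n_tokens = std::min(n_ctx, n_ubatch);

    // pp (prompt processing) worst case: a full micro-batch
    std::printf("pp reserve: n_tokens = %u, n_seqs = %u\n", n_tokens, n_seqs);
    // tg (token generation) worst case: one token per sequence
    std::printf("tg reserve: n_tokens = %u, n_seqs = %u\n", n_seqs, n_seqs);
    return 0;
}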
| 403 |   return sched.get();
| 404 | }
| 405 |
| 406 | uint32_t llama_context::n_ctx() const {
| 407 |     return cparams.n_ctx;
| 408 | }
| 474 |   }
| 475 | }
| 476 |
| 477 | + // reset the previous graph result to make sure that it won't be reused
| 478 | + // TODO: change the mctx->apply() to return information if a graph reserve is needed
| 479 | + //       reset the graph result only if the memory module did reset the scheduler
| 480 | + gf_res_prev->reset();
| 481 | +
| 482 | if (!mctx->apply()) {
| 483 |     LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
| 484 | }
| 491 |   throw std::runtime_error("failed to initialize memory context");
| 492 | }
| 493 |
| 494 | + const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
| 495 |   const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
| 496 |
| 497 |   auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
| 508 | }
| 509 |
| 510 | float * llama_context::get_logits() {
| 511 | +   output_reorder();
| 512 | +
| 513 |     return logits;
| 514 | }
| 515 |
| 516 | float * llama_context::get_logits_ith(int32_t i) {
| 517 |     int64_t j = -1;
| 518 |
| 519 | +   output_reorder();
| 520 | +
| 521 |     try {
| 522 |         if (logits == nullptr) {
| 523 |             throw std::runtime_error("no logits");
| 554 | }
| 555 |
| 556 | float * llama_context::get_embeddings() {
| 557 | +   output_reorder();
| 558 | +
| 559 |     return embd;
| 560 | }
| 561 |
| 562 | float * llama_context::get_embeddings_ith(int32_t i) {
| 563 |     int64_t j = -1;
| 564 |
| 565 | +   output_reorder();
| 566 | +
| 567 |     try {
| 568 |         if (embd == nullptr) {
| 569 |             throw std::runtime_error("no embeddings");
| 702 |   return cvec.apply(model, data, len, n_embd, il_start, il_end);
| 703 | }
| 704 |
| 705 | + llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
| 706 |     if (mctx && !mctx->apply()) {
| 707 |         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
| 708 |         ret = GGML_STATUS_FAILED;
| 709 |         return nullptr;
| 710 |     }
| 711 |
| 712 | +   auto * res = gf_res_prev.get();
| 713 | +   auto * gf  = res->get_gf();
| 714 |
| 715 | +   // the new graph parameters
| 716 | +   // in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters
| 717 | +   const auto gparams = graph_params(res, ubatch, mctx, gtype);
| 718 |
| 719 | +   if (res->can_reuse(gparams)) {
| 720 | +       //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
| 721 |
| 722 | +       n_reused++;
| 723 | +   } else {
| 724 | +       res->reset();
| 725 | +
| 726 | +       ggml_backend_sched_reset(sched.get());
| 727 | +       ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
| 728 | +
| 729 | +       //const auto t_start_us = ggml_time_us();
| 730 | +
| 731 | +       gf = model.build_graph(gparams);
| 732 | +
| 733 | +       //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
| 734 | +
| 735 | +       if (!gf) {
| 736 | +           LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
| 737 | +           ret = GGML_STATUS_FAILED;
| 738 | +           return nullptr;
| 739 | +       }
| 740 | +
| 741 | +       if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
| 742 | +           LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
| 743 | +           ret = GGML_STATUS_ALLOC_FAILED;
| 744 | +           return nullptr;
| 745 | +       }
| 746 |     }
| 747 |
| 748 | +   // set the input data for the input tensors
| 749 | +   {
| 750 | +       //const auto t_start_us = ggml_time_us();
| 751 | +
| 752 | +       res->set_inputs(&ubatch);
| 753 | +
| 754 | +       //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
| 755 | +   }
| 756 |
| 757 | +   const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
| 758 |     if (status != GGML_STATUS_SUCCESS) {
| 759 |         LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
| 760 |         ret = status;
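The reuse path above is essentially memoization of the last built graph: when the parameters that uniquely determine the topology match, only the input data is refreshed and the rebuild plus reallocation is skipped. A hedged, self-contained sketch of the same caching pattern (types are simplified stand-ins, not the actual llama.cpp API):

#include <cstdio>

// Simplified stand-ins for the real graph types.
struct Params {
    int n_tokens;
    int n_kv;
    bool operator==(const Params & o) const {
        return n_tokens == o.n_tokens && n_kv == o.n_kv;
    }
};

struct Graph {
    Params params;   // parameters the topology was built for
    bool   built = false;

    bool can_reuse(const Params & p) const { return built && params == p; }

    void build(const Params & p) {
        params = p;
        built  = true;
        std::puts("rebuilding graph");
    }
};

int main() {
    Graph g;
    int n_reused = 0;

    const Params steps[] = { {1, 32}, {1, 32}, {8, 64} };
    for (const auto & p : steps) {
        if (g.can_reuse(p)) {
            n_reused++;  // same topology: just set new inputs
        } else {
            g.build(p);  // topology changed: rebuild and reallocate
        }
        // set_inputs(...) and compute(...) would happen here either way
    }

    std::printf("graphs reused: %d\n", n_reused); // 1
    return 0;
}

During steady-state token generation the shapes repeat, so the reuse counter (n_reused in the real code) can grow with almost every decode step.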
| 776 |
| 777 | const auto & hparams = model.hparams;
| 778 |
| 779 | + const int64_t n_embd  = hparams.n_embd;
| 780 | + const int32_t n_vocab = model.vocab.n_tokens();
| 781 |
| 782 | // note: during encode, we always pass the full sequence starting from pos = 0
| 783 | + if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
| 784 |     LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
| 785 |     return -1;
| 786 | }
| 787 |
| 788 | const uint32_t n_tokens = balloc->get_n_tokens();
| 789 |
| 790 | + // [TAG_NO_CACHE_PAD]
| 791 | + // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
| 792 | const llama_ubatch ubatch = balloc->split_simple(n_tokens);
| 793 |
| 794 | // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
| 815 |
| 816 | n_outputs = n_tokens;
| 817 |
| 818 | const auto causal_attn_org = cparams.causal_attn;
| 819 |
| 820 | // always use non-causal attention for encoder graphs
| 823 | cparams.causal_attn = false;
| 824 |
| 825 | ggml_status status;
| 826 | + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
| 827 |
| 828 | cparams.causal_attn = causal_attn_org;
| 829 |
| 836 |     }
| 837 | }
| 838 |
| 839 | + auto * t_logits = res->get_logits();
| 840 |   auto * t_embd   = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
| 841 |
| 842 | + // extract logits
| 843 | + if (logits && t_logits) {
| 844 | +     ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
| 845 | +     GGML_ASSERT(backend_res != nullptr);
| 846 | +     GGML_ASSERT(logits != nullptr);
| 847 | +
| 848 | +     ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
| 849 | + }
| 850 | +
| 851 |   // extract embeddings
| 852 | + if (embd && t_embd) {
| 853 |       ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
| 854 |       GGML_ASSERT(backend_embd != nullptr);
| 855 |
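The encoder path saves causal_attn, flips it off, and restores it after process_ubatch() returns. An RAII scope guard is an illustrative alternative way to express the same save/restore (this is not what llama.cpp does; it is a sketch of the idea):

#include <utility>

// Restores the original value when the scope ends, even on early return.
template <typename T>
struct value_guard {
    T & ref;
    T   saved;
    value_guard(T & r, T tmp) : ref(r), saved(r) { ref = std::move(tmp); }
    ~value_guard() { ref = std::move(saved); }
};

struct cparams_t { bool causal_attn = true; };

int main() {
    cparams_t cparams;
    {
        // always use non-causal attention for encoder graphs
        value_guard<bool> g(cparams.causal_attn, false);
        // ... process_ubatch(..., LLM_GRAPH_TYPE_ENCODER, ...) would run here
    } // causal_attn restored here
    return cparams.causal_attn ? 0 : 1;
}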
| 899 |     }
| 900 | }
| 901 |
| 902 | + if (!supports_set_rows) {
| 903 | +     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
| 904 | +     // overlap with device computation.
| 905 | +     ggml_backend_sched_reset(sched.get());
| 906 | + }
| 907 |
| 908 | // TODO: hacky solution
| 909 | if (model.arch == LLM_ARCH_T5 && t_embd) {
| 956 | // when computing embeddings, all tokens are output
| 957 | const bool output_all = cparams.embeddings;
| 958 |
| 959 | + if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
| 960 |     LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
| 961 |     return -1;
| 962 | }
| 984 |
| 985 | // TODO: this clear of the buffer can easily be forgotten - need something better
| 986 | embd_seq.clear();
| 987 | + output_swaps.clear();
| 988 |
| 989 | bool did_optimize = false;
| 990 |
| 1063 |     n_outputs = n_outputs_new;
| 1064 | }
| 1065 |
| 1066 | ggml_status status;
| 1067 | + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
| 1068 |
| 1069 | if (!res) {
| 1070 |     // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
| 1205 |
// note: this is mostly relevant for recurrent models atm
|
| 1206 |
if (!sorted_output) {
|
|
|
|
|
|
|
|
|
|
| 1207 |
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
| 1208 |
|
| 1209 |
// TODO: is there something more efficient which also minimizes swaps?
|
|
|
|
| 1219 |
continue;
|
| 1220 |
}
|
| 1221 |
std::swap(out_ids[i], out_ids[j_min]);
|
| 1222 |
+
|
| 1223 |
+
// remember the swaps and apply them lazily upon logits/embeddings access
|
| 1224 |
+
output_swaps.push_back({ i, j_min });
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1225 |
}
|
| 1226 |
|
| 1227 |
std::fill(output_ids.begin(), output_ids.end(), -1);
|
|
|
|
| 1235 |
// wait for the computation to finish (automatically done when obtaining the model output)
|
| 1236 |
//synchronize();
|
| 1237 |
|
| 1238 |
+
if (!supports_set_rows) {
|
| 1239 |
+
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
| 1240 |
+
// overlap with device computation.
|
| 1241 |
+
ggml_backend_sched_reset(sched.get());
|
| 1242 |
+
}
|
| 1243 |
|
| 1244 |
return 0;
|
| 1245 |
}
|
|
|
|
| 1318 |   return n_outputs_max;
| 1319 | }
| 1320 |
| 1321 | + void llama_context::output_reorder() {
| 1322 | +     const uint32_t n_vocab = model.vocab.n_tokens();
| 1323 | +     const uint64_t n_embd  = model.hparams.n_embd;
| 1324 | +
| 1325 | +     for (uint32_t s = 0; s < output_swaps.size(); ++s) {
| 1326 | +         const uint32_t i0 = output_swaps[s].i0;
| 1327 | +         const uint32_t i1 = output_swaps[s].i1;
| 1328 | +
| 1329 | +         if (logits_size > 0) {
| 1330 | +             for (uint32_t k = 0; k < n_vocab; k++) {
| 1331 | +                 std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
| 1332 | +             }
| 1333 | +         }
| 1334 | +
| 1335 | +         if (embd_size > 0) {
| 1336 | +             for (uint32_t k = 0; k < n_embd; k++) {
| 1337 | +                 std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
| 1338 | +             }
| 1339 | +         }
| 1340 | +     }
| 1341 | +
| 1342 | +     output_swaps.clear();
| 1343 | + }
| 1344 | +
| 1345 | //
| 1346 | // graph
| 1347 | //
| 1348 |
| 1349 | + uint32_t llama_context::graph_max_nodes() const {
| 1350 | +     return std::max<uint32_t>(1024u, 8u*model.n_tensors());
| 1351 | }
| 1352 |
| 1353 | + llm_graph_result * llama_context::get_gf_res_reserve() const {
| 1354 | +     return static_cast<llm_graph_result *>(gf_res_reserve.get());
| 1355 | }
| 1356 |
| 1357 | ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
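output_reorder() defers the row swaps recorded during decode until the user actually reads logits or embeddings, so the backend sync is never blocked on CPU-side shuffling. A small standalone illustration of applying recorded swaps lazily (buffer contents and sizes are made up):

#include <algorithm>
#include <cstdio>
#include <vector>

struct swap_info { size_t i0, i1; };

int main() {
    const size_t n_vocab = 4;
    // two output rows of logits, stored row-major
    std::vector<float> logits = { 0, 0, 0, 0,  1, 1, 1, 1 };

    // swaps recorded while sorting outputs back to batch order
    std::vector<swap_info> output_swaps = { {0, 1} };

    // applied lazily on first access to the outputs
    for (const auto & s : output_swaps) {
        for (size_t k = 0; k < n_vocab; ++k) {
            std::swap(logits[s.i0*n_vocab + k], logits[s.i1*n_vocab + k]);
        }
    }
    output_swaps.clear(); // applying twice would undo the reorder

    std::printf("row0[0] = %.0f\n", logits[0]); // now 1
    return 0;
}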
| 1364 |   LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
| 1365 | }
| 1366 |
| 1367 | + ggml_backend_sched_reset(sched.get());
| 1368 | +
| 1369 | + // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that
| 1370 | + gf_res_prev->reset();
| 1371 | +
| 1372 |   // store the n_outputs as it is, and restore it afterwards
| 1373 |   // TODO: not sure if needed, might simplify in the future by removing this
| 1374 |   const auto save_n_outputs = this->n_outputs;
| 1378 |   llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
| 1379 |   llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
| 1380 |
| 1381 | + auto * res = gf_res_reserve.get();
| 1382 |
| 1383 | + const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
| 1384 |
| 1385 | + res->reset();
| 1386 |
| 1387 | + auto * gf = model.build_graph(gparams);
| 1388 | +
| 1389 | + this->n_outputs = save_n_outputs;
| 1390 |
| 1391 |   // initialize scheduler with the specified graph
| 1392 |   if (!ggml_backend_sched_reserve(sched.get(), gf)) {
| 1397 |   return gf;
| 1398 | }
| 1399 |
| 1400 | + llm_graph_params llama_context::graph_params(
| 1401 | +                   llm_graph_result * res,
| 1402 | +               const llama_ubatch & ubatch,
| 1403 | +     const llama_memory_context_i * mctx,
| 1404 | +                     llm_graph_type   gtype) const {
| 1405 | +     return {
| 1406 | +         /*.arch        =*/ model.arch,
| 1407 | +         /*.hparams     =*/ model.hparams,
| 1408 | +         /*.cparams     =*/ cparams,
| 1409 | +         /*.ubatch      =*/ ubatch,
| 1410 | +         /*.gtype       =*/ gtype,
| 1411 | +         /*.sched       =*/ sched.get(),
| 1412 | +         /*.backend_cpu =*/ backend_cpu,
| 1413 | +         /*.cvec        =*/ &cvec,
| 1414 | +         /*.loras       =*/ &loras,
| 1415 | +         /*.mctx        =*/ mctx,
| 1416 | +         /*.cross       =*/ &cross,
| 1417 | +         /*.n_outputs   =*/ n_outputs,
| 1418 | +         /*.cb          =*/ graph_get_cb(),
| 1419 | +         /*.res         =*/ res,
| 1420 | +     };
| 1421 | }
| 1422 |
| 1423 | ggml_status llama_context::graph_compute(
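graph_params() returns a braced aggregate whose members are matched by position; the /*.field =*/ comments document which member each value initializes. A trimmed illustration of the idiom (the struct here is invented, not llm_graph_params):

#include <cstdio>

// If members are ever reordered, the annotated initializer list stops
// matching and the mismatch is easy to spot in review.
struct params_t {
    int          n_outputs;
    const char * name;
};

int main() {
    const params_t p = {
        /*.n_outputs =*/ 1,
        /*.name      =*/ "default",
    };
    std::printf("%s: n_outputs = %d\n", p.name, p.n_outputs);
    return 0;
}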
| 1995 |   data.t_eval_ms = 1e-3 * t_eval_us;
| 1996 |   data.n_p_eval  = std::max(1, n_p_eval);
| 1997 |   data.n_eval    = std::max(1, n_eval);
| 1998 | + data.n_reused  = std::max(0, n_reused);
| 1999 |
| 2000 |   return data;
| 2001 | }
| 2004 |   t_start_us  = ggml_time_us();
| 2005 |   t_eval_us   = n_eval = 0;
| 2006 |   t_p_eval_us = n_p_eval = 0;
| 2007 | + n_reused = 0;
| 2008 | }
| 2009 |
| 2010 | //
| 2095 |       batch.logits  [pos_batch] = true;
| 2096 |   }
| 2097 |
| 2098 | + if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
| 2099 |       LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
| 2100 |       return;
| 2101 |   }
| 2131 |       break;
| 2132 |   }
| 2133 |
| 2134 | + auto * res = gf_res_prev.get();
| 2135 | +
| 2136 | + const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
| 2137 | +
| 2138 | + res->reset();
| 2139 | +
| 2140 | + auto * gf = model.build_graph(gparams);
| 2141 |
| 2142 |   struct ggml_context * ctx_compute_opt;
| 2143 |   {
| 2259 |   /*.no_perf    =*/ true,
| 2260 |   /*.op_offload =*/ true,
| 2261 |   /*.swa_full   =*/ true,
| 2262 | + /*.kv_unified =*/ false,
| 2263 | };
| 2264 |
| 2265 | return result;
| 2880 | LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
| 2881 |         __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
| 2882 | LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
| 2883 | + LLAMA_LOG_INFO("%s:    graphs reused = %10d\n", __func__, data.n_reused);
| 2884 | }
| 2885 |
| 2886 | void llama_perf_context_reset(llama_context * ctx) {
examples/talk-llama/llama-context.h
CHANGED
@@ -35,8 +35,6 @@ struct llama_context {

| 35 |
| 36 |   ggml_backend_sched_t get_sched() const;
| 37 |
| 38 | - ggml_context * get_ctx_compute() const;
| 39 | -
| 40 |   uint32_t n_ctx() const;
| 41 |   uint32_t n_ctx_per_seq() const;
| 42 |   uint32_t n_batch() const;

@@ -96,7 +94,7 @@ struct llama_context {

|  96 |   // if memory_context is provided, it will be applied first to the context's memory
|  97 |   // ret contains the status of the graph computation
|  98 |   // returns nullptr only if ret != GGML_STATUS_SUCCESS
|  99 | - (removed line, truncated in the extracted diff: the old process_ubatch declaration)
| 100 |           const llama_ubatch & ubatch,
| 101 |               llm_graph_type gtype,
| 102 |       llama_memory_context_i * mctx,

@@ -183,15 +181,17 @@ private:

| 183 |   // Returns max number of outputs for which space was reserved.
| 184 |   uint32_t output_reserve(int32_t n_outputs);
| 185 |
| 186 |   //
| 187 |   // graph
| 188 |   //
| 189 |
| 190 | public:
| 191 | - (removed line, truncated in the extracted diff)
| 192 |
| 193 | - //
| 194 | - (removed line, truncated in the extracted diff)
| 195 |
| 196 |   // returns the result of ggml_backend_sched_graph_compute_async execution
| 197 |   ggml_status graph_compute(ggml_cgraph * gf, bool batched);

@@ -200,12 +200,11 @@ private:

| 200 |   ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
| 201 |
| 202 | private:
| 203-207 | - (removed lines, truncated in the extracted diff: the old graph_build declaration, replaced by graph_params below)
| 208 | -       const llama_memory_context_i * mctx);
| 209 |
| 210 |   llm_graph_cb graph_get_cb() const;
| 211 |

@@ -253,13 +252,18 @@ private:

| 253 |
| 254 |   std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
| 255 |
| 256 |   ggml_backend_sched_ptr sched;
| 257 |
| 258 |   ggml_backend_t backend_cpu = nullptr;
| 259 |   std::vector<ggml_backend_ptr> backends;
| 260 |
| 261 | - ggml_context_ptr ctx_compute;
| 262 | -
| 263 |   // training
| 264 |   ggml_opt_context_t opt_ctx = nullptr;
| 265 |

@@ -275,14 +279,18 @@ private:

| 275 |   std::vector<ggml_backend_t> backend_ptrs;
| 276 |   std::vector<ggml_backend_buffer_type_t> backend_buft;
| 277 |
| 278 | - (removed line, truncated in the extracted diff)
| 279 | - (removed line, truncated in the extracted diff)
| 280 |
| 281 |   // host buffer for the model output (logits and embeddings)
| 282 |   ggml_backend_buffer_ptr buf_output;
| 283 |
| 284 |   bool has_evaluated_once = false;
| 285 |

@@ -294,4 +302,6 @@ private:

| 294 |
| 295 |   mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
| 296 |   mutable int32_t n_eval   = 0; // number of eval calls
| 297 | };
| 35 |
| 36 |   ggml_backend_sched_t get_sched() const;
| 37 |
| 38 |   uint32_t n_ctx() const;
| 39 |   uint32_t n_ctx_per_seq() const;
| 40 |   uint32_t n_batch() const;
| 94 |   // if memory_context is provided, it will be applied first to the context's memory
| 95 |   // ret contains the status of the graph computation
| 96 |   // returns nullptr only if ret != GGML_STATUS_SUCCESS
| 97 | + llm_graph_result * process_ubatch(
| 98 |           const llama_ubatch & ubatch,
| 99 |               llm_graph_type gtype,
| 100 |       llama_memory_context_i * mctx,
| 181 |   // Returns max number of outputs for which space was reserved.
| 182 |   uint32_t output_reserve(int32_t n_outputs);
| 183 |
| 184 | + void output_reorder();
| 185 | +
| 186 |   //
| 187 |   // graph
| 188 |   //
| 189 |
| 190 | public:
| 191 | + uint32_t graph_max_nodes() const;
| 192 |
| 193 | + // can reuse the llm_graph_result instance of the context (for example to update a memory module)
| 194 | + llm_graph_result * get_gf_res_reserve() const;
| 195 |
| 196 |   // returns the result of ggml_backend_sched_graph_compute_async execution
| 197 |   ggml_status graph_compute(ggml_cgraph * gf, bool batched);
| 200 |   ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
| 201 |
| 202 | private:
| 203 | + llm_graph_params graph_params(
| 204 | +                   llm_graph_result * res,
| 205 | +               const llama_ubatch & ubatch,
| 206 | +     const llama_memory_context_i * mctx,
| 207 | +                     llm_graph_type   gtype) const;
| 208 |
| 209 |   llm_graph_cb graph_get_cb() const;
| 210 |
| 252 |
| 253 |   std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
| 254 |
| 255 | + struct swap_info {
| 256 | +     uint32_t i0;
| 257 | +     uint32_t i1;
| 258 | + };
| 259 | +
| 260 | + std::vector<swap_info> output_swaps;
| 261 | +
| 262 |   ggml_backend_sched_ptr sched;
| 263 |
| 264 |   ggml_backend_t backend_cpu = nullptr;
| 265 |   std::vector<ggml_backend_ptr> backends;
| 266 |
| 267 |   // training
| 268 |   ggml_opt_context_t opt_ctx = nullptr;
| 269 |
| 279 |   std::vector<ggml_backend_t> backend_ptrs;
| 280 |   std::vector<ggml_backend_buffer_type_t> backend_buft;
| 281 |
| 282 | + llm_graph_result_ptr gf_res_prev;
| 283 | + llm_graph_result_ptr gf_res_reserve;
| 284 |
| 285 |   // host buffer for the model output (logits and embeddings)
| 286 |   ggml_backend_buffer_ptr buf_output;
| 287 |
| 288 |   bool has_evaluated_once = false;
| 289 |
| 290 | + // env: LLAMA_SET_ROWS (temporary)
| 291 | + // ref: https://github.com/ggml-org/llama.cpp/pull/14285
| 292 | + bool supports_set_rows = false;
| 293 | +
| 294 |   // perf
| 295 |   mutable int64_t t_start_us = 0;
| 296 |   mutable int64_t t_load_us  = 0;
| 302 |
| 303 |   mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
| 304 |   mutable int32_t n_eval   = 0; // number of eval calls
| 305 | +
| 306 | + mutable int32_t n_reused = 0; // number of times the previous graph was reused
| 307 | };
examples/talk-llama/llama-cparams.h
CHANGED
@@ -11,8 +11,8 @@ struct llama_cparams {

| 11 |   uint32_t n_batch;
| 12 |   uint32_t n_ubatch;
| 13 |   uint32_t n_seq_max;
| 14 | - (removed line, truncated in the extracted diff: the previous n_threads declaration)
| 15 | - (removed line, truncated in the extracted diff: the previous n_threads_batch declaration)
| 16 |
| 17 |   float rope_freq_base;
| 18 |   float rope_freq_scale;

@@ -33,6 +33,7 @@ struct llama_cparams {

| 33 |   bool no_perf;
| 34 |   bool warmup;
| 35 |   bool op_offload;
| 36 |
| 37 |   enum llama_pooling_type pooling_type;
| 38 |
| 11 |   uint32_t n_batch;
| 12 |   uint32_t n_ubatch;
| 13 |   uint32_t n_seq_max;
| 14 | + int32_t n_threads;       // number of threads to use for generation
| 15 | + int32_t n_threads_batch; // number of threads to use for batch processing
| 16 |
| 17 |   float rope_freq_base;
| 18 |   float rope_freq_scale;
| 33 |   bool no_perf;
| 34 |   bool warmup;
| 35 |   bool op_offload;
| 36 | + bool kv_unified;
| 37 |
| 38 |   enum llama_pooling_type pooling_type;
| 39 |
examples/talk-llama/llama-graph.cpp
CHANGED
@@ -28,6 +28,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {

| 28 |     }
| 29 | }
| 30 |
| 31 | void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
| 32 |     if (ubatch->pos && pos) {
| 33 |         const int64_t n_tokens = ubatch->n_tokens;

@@ -50,6 +59,14 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {

| 50 |     }
| 51 | }
| 52 |
| 53 | void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
| 54 |     if (ubatch->pos && attn_scale) {
| 55 |         const int64_t n_tokens = ubatch->n_tokens;

@@ -71,7 +88,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {

| 71 |     const int64_t n_tokens = ubatch->n_tokens;
| 72 |
| 73 |     GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
| 74 | -   GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
| 75 |
| 76 |     int32_t * data = (int32_t *) pos_bucket->data;
| 77 |

@@ -118,6 +135,14 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {

| 118 |     }
| 119 | }
| 120 |
| 121 | void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
| 122 |     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
| 123 |         const int64_t n_tokens = ubatch->n_tokens;

@@ -287,6 +312,24 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {

| 287 |     mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
| 288 | }
| 289 |
| 290 | void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
| 291 |     mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
| 292 |     mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);

@@ -299,6 +342,30 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch

| 299 |     mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
| 300 | }
| 301 |
| 302 | void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
| 303 |     GGML_ASSERT(cross_kq_mask);
| 304 |

@@ -306,7 +373,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {

| 306 |     const int64_t n_tokens = ubatch->n_tokens;
| 307 |
| 308 |     GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
| 309 | -   GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
| 310 |
| 311 |     float * data = (float *) cross_kq_mask->data;
| 312 |

@@ -340,6 +407,91 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {

| 340 |     inp_rs->set_input(ubatch);
| 341 | }
| 342 |
| 343 | //
| 344 | // llm_graph_context
| 345 | //

@@ -374,7 +526,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :

| 374 |     n_ctx_orig   (cparams.n_ctx_orig_yarn),
| 375 |     pooling_type (cparams.pooling_type),
| 376 |     rope_type    (hparams.rope_type),
| 377 | -   ctx0         (params.ctx),
| 378 |     sched        (params.sched),
| 379 |     backend_cpu  (params.backend_cpu),
| 380 |     cvec         (params.cvec),

@@ -382,7 +533,10 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :

| 382 |     mctx    (params.mctx),
| 383 |     cross   (params.cross),
| 384 |     cb_func (params.cb),
| 385 | -   res     ( (rest of the removed initializer is truncated in the extracted diff)
| 386 | }
| 387 |
| 388 | void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {

@@ -753,20 +907,28 @@ ggml_tensor * llm_graph_context::build_moe_ffn(

| 753 |     cb(cur, "ffn_moe_weighted", il);
| 754 | }
| 755 |
| 756 |   // aggregate experts
| 757-764 | - (removed lines, truncated in the extracted diff: the old per-expert aggregation loop, ending with "} else {")
| 765 | -     moe_out = ggml_add(ctx0, moe_out, cur_expert);
| 766 | - }
| 767 | }
| 768 |
| 769 | - if (n_expert_used == 1) {
| 770 |     // avoid returning a non-contiguous tensor
| 771 |     moe_out = ggml_cont(ctx0, moe_out);
| 772 | }

@@ -972,7 +1134,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b)

| 972 | }
| 973 |
| 974 | ggml_tensor * llm_graph_context::build_attn_mha(
| 975 | -   ggml_cgraph * gf,
| 976 |     ggml_tensor * q,
| 977 |     ggml_tensor * k,
| 978 |     ggml_tensor * v,

@@ -982,13 +1143,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(

|  982 |     float kq_scale) const {
|  983 |     const bool v_trans = v->nb[1] > v->nb[2];
|  984 |
|  985 |     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|  986 |     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
|  987 |     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|  988 |
|  989 | -   const auto (rest of line truncated in the extracted diff)
|  990 | -   const auto n_head = q->ne[2];
|  991 | -   const auto n_kv   = k->ne[1];
|  992 |
|  993 |     ggml_tensor * cur;

@@ -1030,7 +1194,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(

| 1030 | #endif
| 1031 |     }
| 1032 |
| 1033 | -   cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]* (rest of line truncated in the extracted diff)
| 1034 | } else {
| 1035 |     ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

@@ -1075,7 +1239,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(

| 1075 |
| 1076 |     cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
| 1077 |
| 1078 | -   (removed line, truncated in the extracted diff)
| 1079 |
| 1080 |     if (!cparams.offload_kqv) {
| 1081 |         // all nodes between the KV store and the attention output are run on the CPU

@@ -1102,7 +1267,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const

| 1102 |
| 1103 | ggml_tensor * llm_graph_context::build_attn(
| 1104 |     llm_graph_input_attn_no_cache * inp,
| 1105 | -   ggml_cgraph * gf,
| 1106 |     ggml_tensor * wo,
| 1107 |     ggml_tensor * wo_b,
| 1108 |     ggml_tensor * q_cur,

@@ -1122,11 +1286,15 @@ ggml_tensor * llm_graph_context::build_attn(

| 1122 |
| 1123 |     const auto & kq_mask = inp->get_kq_mask();
| 1124 |
| 1125 |     ggml_tensor * q = q_cur;
| 1126 |     ggml_tensor * k = k_cur;
| 1127 |     ggml_tensor * v = v_cur;
| 1128 |
| 1129 | -   ggml_tensor * cur = build_attn_mha( (rest of line truncated in the extracted diff)
| 1130 |     cb(cur, "kqv_out", il);
| 1131 |
| 1132 |     if (wo) {

@@ -1156,13 +1324,14 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(

| 1156 | {
| 1157 |     GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
| 1158 |
| 1159 | -   const auto n_kv (rest of line truncated in the extracted diff)
| 1160 |     const auto n_tokens = ubatch.n_tokens;
| 1161 |
| 1162 |     inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
| 1163 |     inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
| 1164 |
| 1165 | -   inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, (rest of line truncated)
| 1166 |     ggml_set_input(inp->self_kq_mask);
| 1167 |
| 1168 |     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;

@@ -1181,7 +1350,6 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()

| 1181 |
| 1182 | ggml_tensor * llm_graph_context::build_attn(
| 1183 |     llm_graph_input_attn_kv_unified * inp,
| 1184 | -   ggml_cgraph * gf,
| 1185 |     ggml_tensor * wo,
| 1186 |     ggml_tensor * wo_b,
| 1187 |     ggml_tensor * q_cur,

@@ -1214,7 +1382,7 @@ ggml_tensor * llm_graph_context::build_attn(

| 1214 |     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
| 1215 |     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
| 1216 |
| 1217 | -   ggml_tensor * cur = build_attn_mha( (rest of line truncated in the extracted diff)
| 1218 |     cb(cur, "kqv_out", il);
| 1219 |
| 1220 |     if (wo) {

@@ -1234,7 +1402,6 @@ ggml_tensor * llm_graph_context::build_attn(

| 1234 |
| 1235 | ggml_tensor * llm_graph_context::build_attn(
| 1236 |     llm_graph_input_attn_kv_unified_iswa * inp,
| 1237 | -   ggml_cgraph * gf,
| 1238 |     ggml_tensor * wo,
| 1239 |     ggml_tensor * wo_b,
| 1240 |     ggml_tensor * q_cur,

@@ -1281,7 +1448,7 @@ ggml_tensor * llm_graph_context::build_attn(

| 1281 |     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
| 1282 |     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
| 1283 |
| 1284 | -   ggml_tensor * cur = build_attn_mha( (rest of line truncated in the extracted diff)
| 1285 |     cb(cur, "kqv_out", il);
| 1286 |
| 1287 |     if (wo) {

@@ -1314,7 +1481,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {

| 1314 |
| 1315 | ggml_tensor * llm_graph_context::build_attn(
| 1316 |     llm_graph_input_attn_cross * inp,
| 1317 | -   ggml_cgraph * gf,
| 1318 |     ggml_tensor * wo,
| 1319 |     ggml_tensor * wo_b,
| 1320 |     ggml_tensor * q_cur,

@@ -1336,7 +1502,7 @@ ggml_tensor * llm_graph_context::build_attn(

| 1336 |     ggml_tensor * k = k_cur;
| 1337 |     ggml_tensor * v = v_cur;
| 1338 |
| 1339 | -   ggml_tensor * cur = build_attn_mha( (rest of line truncated in the extracted diff)
| 1340 |     cb(cur, "kqv_out", il);
| 1341 |
| 1342 |     if (wo) {

@@ -1362,13 +1528,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {

| 1362 |
| 1363 |     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
| 1364 |
| 1365 |     {
| 1366 |         const auto n_kv = mctx_cur->get_base()->get_n_kv();
| 1367 |
| 1368 |         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
| 1369 |         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
| 1370 |
| 1371 | -       inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, (rest of line truncated)
| 1372 |         ggml_set_input(inp->self_kq_mask);
| 1373 |
| 1374 |         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;

@@ -1382,7 +1550,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {

| 1382 |     inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
| 1383 |     inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
| 1384 |
| 1385 | -   inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, (rest of line truncated)
| 1386 |     ggml_set_input(inp->self_kq_mask_swa);
| 1387 |
| 1388 |     inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;

@@ -1392,7 +1560,6 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {

| 1392 | }
| 1393 |
| 1394 | ggml_tensor * llm_graph_context::build_rs(
| 1395 | -   ggml_cgraph * gf,
| 1396 |     ggml_tensor * s,
| 1397 |     ggml_tensor * state_copy,
| 1398 |     int32_t state_size,

@@ -1450,21 +1617,19 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {

| 1450 |
| 1451 | ggml_tensor * llm_graph_context::build_rs(
| 1452 |     llm_graph_input_rs * inp,
| 1453 | -   ggml_cgraph * gf,
| 1454 |     ggml_tensor * s,
| 1455 |     int32_t state_size,
| 1456 |     int32_t n_seqs,
| 1457 |     const llm_graph_get_rows_fn & get_state_rows) const {
| 1458 |     const auto * kv_state = inp->mctx;
| 1459 |
| 1460 | -   return build_rs( (rest of line truncated in the extracted diff)
| 1461 | }
| 1462 |
| 1463 | ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
| 1464 |     llm_graph_input_rs * inp,
| 1465 | -   ggml_cgraph * gf,
| 1466 |     const llama_ubatch & ubatch,
| 1467 | -   (removed line, truncated in the extracted diff)
| 1468 |     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
| 1469 |
| 1470 |     const auto token_shift_count = hparams.token_shift_count;

@@ -1474,7 +1639,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(

| 1474 |     ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
| 1475 |
| 1476 |     ggml_tensor * token_shift = build_rs(
| 1477 | -       inp, (rest of line truncated in the extracted diff)
| 1478 |         hparams.n_embd_r(), n_seqs);
| 1479 |
| 1480 |     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);

@@ -1514,7 +1679,6 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {

| 1514 | }
| 1515 |
| 1516 | void llm_graph_context::build_pooling(
| 1517 | -   ggml_cgraph * gf,
| 1518 |     ggml_tensor * cls,
| 1519 |     ggml_tensor * cls_b,
| 1520 |     ggml_tensor * cls_out,
| 28 |     }
| 29 | }
| 30 |
| 31 | + bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
| 32 | +     bool res = true;
| 33 | +
| 34 | +     res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
| 35 | +     res &= (!embd   && !params.ubatch.embd)  || (embd   && embd->ne[0]   == params.ubatch.n_tokens);
| 36 | +
| 37 | +     return res;
| 38 | + }
| 39 | +
| 40 | void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
| 41 |     if (ubatch->pos && pos) {
| 42 |         const int64_t n_tokens = ubatch->n_tokens;
| 59 |     }
| 60 | }
| 61 |
| 62 | + bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
| 63 | +     bool res = true;
| 64 | +
| 65 | +     res &= pos->ne[0] == params.ubatch.n_tokens;
| 66 | +
| 67 | +     return res;
| 68 | + }
| 69 | +
| 70 | void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
| 71 |     if (ubatch->pos && attn_scale) {
| 72 |         const int64_t n_tokens = ubatch->n_tokens;
| 88 |     const int64_t n_tokens = ubatch->n_tokens;
| 89 |
| 90 |     GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
| 91 | +   GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
| 92 |
| 93 |     int32_t * data = (int32_t *) pos_bucket->data;
| 94 |
| 135 |     }
| 136 | }
| 137 |
| 138 | + bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
| 139 | +     bool res = true;
| 140 | +
| 141 | +     res &= n_outputs == params.n_outputs;
| 142 | +
| 143 | +     return res;
| 144 | + }
| 145 | +
| 146 | void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
| 147 |     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
| 148 |         const int64_t n_tokens = ubatch->n_tokens;
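Each input implements can_reuse() as a conjunction of shape checks; accumulating into a bool with `res &=` keeps every condition on its own line, which makes individual checks easy to comment out (as the TODO lines above do). A trimmed illustration of the idiom with invented types:

#include <cstdint>
#include <cstdio>

struct tensor_t { int64_t ne[4]; };
struct params_t { int64_t n_tokens; };

// Illustrative shape-compatibility check in the style of the
// llm_graph_input_*::can_reuse() methods above.
static bool can_reuse_pos(const tensor_t * pos, const params_t & params) {
    bool res = true;

    res &= pos != nullptr;
    res &= pos && pos->ne[0] == params.n_tokens; // same batch width as before

    return res;
}

int main() {
    tensor_t pos = { { 8, 1, 1, 1 } };
    std::printf("can_reuse = %d\n", can_reuse_pos(&pos, { 8 }));  // 1
    std::printf("can_reuse = %d\n", can_reuse_pos(&pos, { 16 })); // 0
    return 0;
}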
| 312 |     mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
| 313 | }
| 314 |
| 315 | + bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
| 316 | +     const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
| 317 | +
| 318 | +     this->mctx = mctx;
| 319 | +
| 320 | +     bool res = true;
| 321 | +
| 322 | +     res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
| 323 | +     //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
| 324 | +
| 325 | +     res &= self_kq_mask->ne[0] == mctx->get_n_kv();
| 326 | +     res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
| 327 | +
| 328 | +     res &= mctx->get_supports_set_rows(); // TODO: tmp
| 329 | +
| 330 | +     return res;
| 331 | + }
| 332 | +
| 333 | void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
| 334 |     mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
| 335 |     mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
| 342 |     mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
| 343 | }
| 344 |
| 345 | + bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
| 346 | +     const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
| 347 | +
| 348 | +     this->mctx = mctx;
| 349 | +
| 350 | +     bool res = true;
| 351 | +
| 352 | +     res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
| 353 | +     //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
| 354 | +
| 355 | +     res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
| 356 | +     //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
| 357 | +
| 358 | +     res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
| 359 | +     res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
| 360 | +
| 361 | +     res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
| 362 | +     res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
| 363 | +
| 364 | +     res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp
| 365 | +
| 366 | +     return res;
| 367 | + }
| 368 | +
| 369 | void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
| 370 |     GGML_ASSERT(cross_kq_mask);
| 371 |
| 373 |     const int64_t n_tokens = ubatch->n_tokens;
| 374 |
| 375 |     GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
| 376 | +   GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
| 377 |
| 378 |     float * data = (float *) cross_kq_mask->data;
| 379 |
| 407 |     inp_rs->set_input(ubatch);
| 408 | }
| 409 |
| 410 | + //
| 411 | + // llm_graph_result
| 412 | + //
| 413 | +
| 414 | + llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
| 415 | +     reset();
| 416 | +
| 417 | +     const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
| 418 | +     debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
| 419 | + }
| 420 | +
| 421 | + int64_t llm_graph_result::get_max_nodes() const {
| 422 | +     return max_nodes;
| 423 | + }
| 424 | +
| 425 | + void llm_graph_result::reset() {
| 426 | +     t_tokens      = nullptr;
| 427 | +     t_logits      = nullptr;
| 428 | +     t_embd        = nullptr;
| 429 | +     t_embd_pooled = nullptr;
| 430 | +
| 431 | +     params = {};
| 432 | +
| 433 | +     inputs.clear();
| 434 | +
| 435 | +     buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
| 436 | +
| 437 | +     ggml_init_params params = {
| 438 | +         /*.mem_size   =*/ buf_compute_meta.size(),
| 439 | +         /*.mem_buffer =*/ buf_compute_meta.data(),
| 440 | +         /*.no_alloc   =*/ true,
| 441 | +     };
| 442 | +
| 443 | +     ctx_compute.reset(ggml_init(params));
| 444 | +
| 445 | +     gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
| 446 | + }
| 447 | +
| 448 | + void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
| 449 | +     for (auto & input : inputs) {
| 450 | +         input->set_input(ubatch);
| 451 | +     }
| 452 | + }
| 453 | +
| 454 | + bool llm_graph_result::can_reuse(const llm_graph_params & params) {
| 455 | +     if (!this->params.allow_reuse(params)) {
| 456 | +         if (debug > 1) {
| 457 | +             LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
| 458 | +         }
| 459 | +
| 460 | +         return false;
| 461 | +     }
| 462 | +
| 463 | +     if (debug > 1) {
| 464 | +         LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
| 465 | +     }
| 466 | +
| 467 | +     bool res = true;
| 468 | +
| 469 | +     for (auto & input : inputs) {
| 470 | +         const bool cur = input->can_reuse(params);
| 471 | +
| 472 | +         if (debug > 1) {
| 473 | +             LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
| 474 | +         }
| 475 | +
| 476 | +         res = res && cur;
| 477 | +     }
| 478 | +
| 479 | +     if (debug > 0) {
| 480 | +         LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
| 481 | +     }
| 482 | +
| 483 | +     return res;
| 484 | + }
| 485 | +
| 486 | + llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
| 487 | +     inputs.emplace_back(std::move(input));
| 488 | +     return inputs.back().get();
| 489 | + }
| 490 | +
| 491 | + void llm_graph_result::set_params(const llm_graph_params & params) {
| 492 | +     this->params = params;
| 493 | + }
| 494 | +
| 495 | //
| 496 | // llm_graph_context
| 497 | //
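llm_graph_result::reset() rebuilds a no-alloc ggml context whose metadata buffer is sized from the node budget, so graph construction never allocates tensor data. A hedged sketch of the sizing arithmetic, with the overhead helpers replaced by invented constants (the real code calls ggml_tensor_overhead() and ggml_graph_overhead_custom()):

#include <cstddef>
#include <cstdio>
#include <vector>

// Rough stand-ins for ggml's overhead helpers; the numbers are made up.
static size_t tensor_overhead()            { return 368; }
static size_t graph_overhead(size_t nodes) { return 64 + nodes * 2 * sizeof(void *); }

int main() {
    const size_t max_nodes = 8192;

    // metadata-only arena: one tensor header per node plus the graph itself
    std::vector<char> buf_compute_meta(
        tensor_overhead() * max_nodes + graph_overhead(max_nodes));

    std::printf("meta buffer: %zu bytes for %zu nodes\n",
                buf_compute_meta.size(), max_nodes);
    return 0;
}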
| 526 |     n_ctx_orig   (cparams.n_ctx_orig_yarn),
| 527 |     pooling_type (cparams.pooling_type),
| 528 |     rope_type    (hparams.rope_type),
| 529 |     sched        (params.sched),
| 530 |     backend_cpu  (params.backend_cpu),
| 531 |     cvec         (params.cvec),
| 533 |     mctx         (params.mctx),
| 534 |     cross        (params.cross),
| 535 |     cb_func      (params.cb),
| 536 | +   res          (params.res),
| 537 | +   ctx0         (res->get_ctx()),
| 538 | +   gf           (res->get_gf()) {
| 539 | +   res->set_params(params);
| 540 | }
| 541 |
| 542 | void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
| 907 |     cb(cur, "ffn_moe_weighted", il);
| 908 | }
| 909 |
| 910 | + ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
| 911 | +
| 912 | + assert(n_expert_used > 0);
| 913 | +
| 914 | + // order the views before the adds
| 915 | + for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
| 916 | +     cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
| 917 | +
| 918 | +     ggml_build_forward_expand(gf, cur_experts[i]);
| 919 | + }
| 920 | +
| 921 |   // aggregate experts
| 922 | + // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
| 923 | + //       to avoid potentially a large number of add nodes during warmup
| 924 | + //       ref: https://github.com/ggml-org/llama.cpp/pull/14753
| 925 | + ggml_tensor * moe_out = cur_experts[0];
| 926 |
| 927 | + for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
| 928 | +     moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
| 929 |   }
| 930 |
| 931 | + if (hparams.n_expert_used == 1) {
| 932 |       // avoid returning a non-contiguous tensor
| 933 |       moe_out = ggml_cont(ctx0, moe_out);
| 934 |   }
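Using hparams.n_expert_used (a model constant) rather than a per-call expert count keeps the number of add nodes, and therefore the graph topology, identical between warmup and normal decoding, which is what allows a cached graph to stay valid. A tiny numeric illustration of the fixed-count aggregation (values are made up):

#include <cstdio>

int main() {
    const int n_expert_used = 4;  // fixed by hparams, not per call
    const float expert_out[4] = { 0.1f, 0.2f, 0.3f, 0.4f };

    float moe_out = expert_out[0];
    for (int i = 1; i < n_expert_used; ++i) {
        moe_out += expert_out[i];  // one add node per expert, every time
    }

    std::printf("moe_out = %.2f\n", moe_out); // 1.00
    return 0;
}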
| 1134 | }
| 1135 |
| 1136 | ggml_tensor * llm_graph_context::build_attn_mha(
| 1137 |     ggml_tensor * q,
| 1138 |     ggml_tensor * k,
| 1139 |     ggml_tensor * v,
| 1143 |     float kq_scale) const {
| 1144 |     const bool v_trans = v->nb[1] > v->nb[2];
| 1145 |
| 1146 | +   // split the batch into streams if needed
| 1147 | +   const auto n_stream = k->ne[3];
| 1148 | +
| 1149 | +   q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
| 1150 | +
| 1151 |     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
| 1152 |     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
| 1153 |     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
| 1154 |
| 1155 | +   const auto n_kv = k->ne[1];
| 1156 |
| 1157 |     ggml_tensor * cur;
| 1194 | #endif
| 1195 |     }
| 1196 |
| 1197 | +   cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
| 1198 | } else {
| 1199 |     ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
| 1239 |
| 1240 |     cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
| 1241 |
| 1242 | +   // recombine streams
| 1243 | +   cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
| 1244 |
| 1245 |     if (!cparams.offload_kqv) {
| 1246 |         // all nodes between the KV store and the attention output are run on the CPU
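The stream split reshapes Q so that tokens belonging to different sequences land in separate 4th-dimension streams matching the KV tensor's layout (n_stream = k->ne[3]), and the final reshape merges the streams back into a flat batch. A worked shape example, with all dimensions invented for illustration:

#include <cstdint>
#include <cstdio>

int main() {
    // Q before the split: [d_head, n_head, n_tokens_total, 1]
    const int64_t ne[4] = { 128, 32, 8, 1 };

    const int64_t n_stream = 4; // e.g. 4 sequences, non-unified KV cache

    // split the batch into streams: tokens per stream = ne[2]/n_stream
    const int64_t split[4] = { ne[0], ne[1], ne[2]/n_stream, n_stream };

    std::printf("q: [%lld, %lld, %lld, %lld] -> [%lld, %lld, %lld, %lld]\n",
                (long long)ne[0], (long long)ne[1], (long long)ne[2], (long long)ne[3],
                (long long)split[0], (long long)split[1], (long long)split[2], (long long)split[3]);
    return 0;
}

Note that this assumes n_tokens_total is a multiple of n_stream, which the batch allocator arranges when sequences are split into equal streams.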
| 1267 |
| 1268 | ggml_tensor * llm_graph_context::build_attn(
| 1269 |     llm_graph_input_attn_no_cache * inp,
| 1270 |     ggml_tensor * wo,
| 1271 |     ggml_tensor * wo_b,
| 1272 |     ggml_tensor * q_cur,
| 1286 |
| 1287 |     const auto & kq_mask = inp->get_kq_mask();
| 1288 |
| 1289 | +   // [TAG_NO_CACHE_PAD]
| 1290 | +   // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
| 1291 | +   assert(!ubatch.equal_seqs());
| 1292 | +
| 1293 |     ggml_tensor * q = q_cur;
| 1294 |     ggml_tensor * k = k_cur;
| 1295 |     ggml_tensor * v = v_cur;
| 1296 |
| 1297 | +   ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
| 1298 |     cb(cur, "kqv_out", il);
| 1299 |
| 1300 |     if (wo) {
| 1324 | {
| 1325 |     GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
| 1326 |
| 1327 | +   const auto n_kv     = mctx_cur->get_n_kv();
| 1328 |     const auto n_tokens = ubatch.n_tokens;
| 1329 | +   const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
| 1330 |
| 1331 |     inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
| 1332 |     inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
| 1333 |
| 1334 | +   inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
| 1335 |     ggml_set_input(inp->self_kq_mask);
| 1336 |
| 1337 |     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
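The mask's second dimension holds the per-stream token rows, rounded up to a multiple of GGML_KQ_MASK_PAD. A quick check of the arithmetic (GGML_PAD rounds up; the padding constant below is an illustrative value, not necessarily the one ggml defines):

#include <cstdio>

// Same rounding as ggml's GGML_PAD macro.
#define PAD(x, n) (((x) + (n) - 1) / (n) * (n))

int main() {
    const int GGML_KQ_MASK_PAD = 64; // illustrative value
    const int n_tokens = 9;
    const int n_stream = 3;

    // per-stream token rows, padded up for the mask tensor
    const int rows = PAD(n_tokens / n_stream, GGML_KQ_MASK_PAD);
    std::printf("mask rows per stream: %d\n", rows); // 64
    return 0;
}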
|
| 1351 |
ggml_tensor * llm_graph_context::build_attn(
|
| 1352 |
llm_graph_input_attn_kv_unified * inp,
|
|
|
|
| 1353 |
ggml_tensor * wo,
|
| 1354 |
ggml_tensor * wo_b,
|
| 1355 |
ggml_tensor * q_cur,
|
|
|
|
| 1382 |
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
| 1383 |
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
| 1384 |
|
| 1385 |
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
| 1386 |
cb(cur, "kqv_out", il);
|
| 1387 |
|
| 1388 |
if (wo) {
|
|
|
|
| 1402 |
|
| 1403 |
ggml_tensor * llm_graph_context::build_attn(
|
| 1404 |
llm_graph_input_attn_kv_unified_iswa * inp,
|
|
|
|
| 1405 |
ggml_tensor * wo,
|
| 1406 |
ggml_tensor * wo_b,
|
| 1407 |
ggml_tensor * q_cur,
|
|
|
|
| 1448 |
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
| 1449 |
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
| 1450 |
|
| 1451 |
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
| 1452 |
cb(cur, "kqv_out", il);
|
| 1453 |
|
| 1454 |
if (wo) {
|
|
|
|
| 1481 |
|
| 1482 |
ggml_tensor * llm_graph_context::build_attn(
|
| 1483 |
llm_graph_input_attn_cross * inp,
|
|
|
|
| 1484 |
ggml_tensor * wo,
|
| 1485 |
ggml_tensor * wo_b,
|
| 1486 |
ggml_tensor * q_cur,
|
|
|
|
| 1502 |
ggml_tensor * k = k_cur;
|
| 1503 |
ggml_tensor * v = v_cur;
|
| 1504 |
|
| 1505 |
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
| 1506 |
cb(cur, "kqv_out", il);
|
| 1507 |
|
| 1508 |
if (wo) {
|
|
|
|
| 1528 |
|
| 1529 |
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
|
| 1530 |
|
| 1531 |
+
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
| 1532 |
+
|
| 1533 |
{
|
| 1534 |
const auto n_kv = mctx_cur->get_base()->get_n_kv();
|
| 1535 |
|
| 1536 |
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
| 1537 |
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
| 1538 |
|
| 1539 |
+
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
| 1540 |
ggml_set_input(inp->self_kq_mask);
|
| 1541 |
|
| 1542 |
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
|
|
|
| 1550 |
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
| 1551 |
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
| 1552 |
|
| 1553 |
+
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
| 1554 |
ggml_set_input(inp->self_kq_mask_swa);
|
| 1555 |
|
| 1556 |
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
|
|
|
| 1560 |
}
|
| 1561 |
|
| 1562 |
ggml_tensor * llm_graph_context::build_rs(
|
|
|
|
| 1563 |
ggml_tensor * s,
|
| 1564 |
ggml_tensor * state_copy,
|
| 1565 |
int32_t state_size,
|
|
|
|
| 1617 |
|
| 1618 |
ggml_tensor * llm_graph_context::build_rs(
|
| 1619 |
llm_graph_input_rs * inp,
|
|
|
|
| 1620 |
ggml_tensor * s,
|
| 1621 |
int32_t state_size,
|
| 1622 |
int32_t n_seqs,
|
| 1623 |
const llm_graph_get_rows_fn & get_state_rows) const {
|
| 1624 |
const auto * kv_state = inp->mctx;
|
| 1625 |
|
| 1626 |
+
return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
|
| 1627 |
}
|
| 1628 |
|
| 1629 |
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
| 1630 |
llm_graph_input_rs * inp,
|
|
|
|
| 1631 |
const llama_ubatch & ubatch,
|
| 1632 |
+
int il) const {
|
| 1633 |
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
| 1634 |
|
| 1635 |
const auto token_shift_count = hparams.token_shift_count;
|
|
|
|
| 1639 |
ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
|
| 1640 |
|
| 1641 |
ggml_tensor * token_shift = build_rs(
|
| 1642 |
+
inp, token_shift_all,
|
| 1643 |
hparams.n_embd_r(), n_seqs);
|
| 1644 |
|
| 1645 |
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
|
|
|
|
| 1679 |
}
|
| 1680 |
|
| 1681 |
void llm_graph_context::build_pooling(
|
|
|
|
| 1682 |
ggml_tensor * cls,
|
| 1683 |
ggml_tensor * cls_b,
|
| 1684 |
ggml_tensor * cls_out,
|
examples/talk-llama/llama-graph.h
CHANGED
@@ -1,6 +1,7 @@

|  1 | #pragma once
|  2 |
|  3 | #include "llama-arch.h"
|  4 | #include "llama-hparams.h"
|  5 | #include "llama-adapter.h"
|  6 |

@@ -14,7 +15,6 @@ struct ggml_cgraph;

| 14 | struct ggml_context;
| 15 | struct ggml_tensor;
| 16 |
| 17 | - struct llama_ubatch;
| 18 | struct llama_cparams;
| 19 |
| 20 | struct llama_memory_context_i;

@@ -69,6 +69,8 @@ struct llama_cross {

| 69 |     std::vector<std::set<llama_seq_id>> seq_ids_enc;
| 70 | };
| 71 |
| 72 | //
| 73 | // llm_graph_input
| 74 | //

@@ -78,11 +80,19 @@ public:

| 78 |     virtual ~llm_graph_input_i() = default;
| 79 |
| 80 |     virtual void set_input(const llama_ubatch * ubatch) = 0;
| 81 | };
| 82 |
| 83 | using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
| 84 |
| 85 | - (removed line, truncated in the extracted diff)
| 86 | class llm_graph_input_embd : public llm_graph_input_i {
| 87 | public:
| 88 |     llm_graph_input_embd() = default;
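The added lines elided from the hunk above give each graph input an optional can_reuse() hook next to set_input(). A hedged sketch of the resulting interface shape (names simplified; not a verbatim copy of the real header):

// Simplified shape of the input interface after this change: inputs that
// cannot prove compatibility simply keep the default and force a rebuild.
struct llm_graph_params_sketch { int n_tokens; }; // stand-in for llm_graph_params

class graph_input_sketch {
public:
    virtual ~graph_input_sketch() = default;

    virtual void set_input(const void * ubatch) = 0;

    // return true if this input can be reused as-is for the new params;
    // by default, assume it cannot
    virtual bool can_reuse(const llm_graph_params_sketch & params) {
        (void) params;
        return false;
    }
};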
@@ -90,6 +100,8 @@ public:

|  90 |
|  91 |     void set_input(const llama_ubatch * ubatch) override;
|  92 |
|  93 |     ggml_tensor * tokens = nullptr; // I32 [n_batch]
|  94 |     ggml_tensor * embd   = nullptr; // F32 [n_embd, n_batch]
|  95 | };

@@ -101,6 +113,8 @@ public:

| 101 |
| 102 |     void set_input(const llama_ubatch * ubatch) override;
| 103 |
| 104 |     ggml_tensor * pos = nullptr; // I32 [n_batch]
| 105 |
| 106 |     const uint32_t n_pos_per_embd = 1;

@@ -154,17 +168,19 @@ public:

| 154 |     llm_graph_input_out_ids(
| 155 |             const llama_hparams & hparams,
| 156 |             const llama_cparams & cparams,
| 157 | -           (removed line, truncated in the extracted diff)
| 158 |     virtual ~llm_graph_input_out_ids() = default;
| 159 |
| 160 |     void set_input(const llama_ubatch * ubatch) override;
| 161 |
| 162 |     ggml_tensor * out_ids; // I32 [n_outputs]
| 163 |
| 164 |     const llama_hparams & hparams;
| 165 |     const llama_cparams & cparams;
| 166 |
| 167 | -   const (rest of line truncated in the extracted diff)
| 168 | };
| 169 |
| 170 | class llm_graph_input_mean : public llm_graph_input_i {

@@ -249,16 +265,18 @@ public:

| 249 |
| 250 |     void set_input(const llama_ubatch * ubatch) override;
| 251 |
| 252 |     ggml_tensor * get_k_idxs() const { return self_k_idxs; }
| 253 |     ggml_tensor * get_v_idxs() const { return self_v_idxs; }
| 254 |
| 255 |     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
| 256 |
| 257 |     ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
| 258 | -   ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
| 259 |
| 260 | -   ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch, 1, (rest of line truncated)
| 261 | -   ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch, 1, (rest of line truncated)
| 262 |
| 263 |     const llama_hparams & hparams;
| 264 |     const llama_cparams & cparams;

@@ -280,6 +298,8 @@ public:

| 280 |
| 281 |     void set_input(const llama_ubatch * ubatch) override;
| 282 |
| 283 |     ggml_tensor * get_k_idxs()     const { return self_k_idxs; }
| 284 |     ggml_tensor * get_v_idxs()     const { return self_v_idxs; }
| 285 |     ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }

@@ -289,14 +309,14 @@ public:

| 289 |     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
| 290 |
| 291 |     ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
| 292 | -   ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch]
| 293 |     ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
| 294 | -   ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
| 295 |
| 296 | -   ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch, 1, (rest of line truncated)
| 297 | -   ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch, 1, (rest of line truncated)
| 298 | -   ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch, 1, (rest of line truncated)
| 299 | -   ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch, 1, (rest of line truncated)
| 300 |
| 301 |     const llama_hparams & hparams;
| 302 |     const llama_cparams & cparams;

@@ -351,40 +371,108 @@ public:

| 351 | // along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
| 352 | // these are used by the llama_context to extract the relevant data, based on the compute parameters
| 353 |
| 354-380 | - (removed lines, mostly truncated in the extracted diff: the old llm_graph_result_i interface and the virtual llm_graph_result class body; visible fragments include "virtual ~llm_graph_result_i() = default;", "virtual ggml_tensor * get_logits() = 0;", "virtual ggml_tensor * get_embd() = 0;", "virtual ggml_tensor * get_embd_pooled() = 0;", "public:", "virtual ~llm_graph_result() = default;", "ggml_tensor * get_embd() override { return t_embd; }", and "ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }")
| 381 |   }
| 382 | - }
| 383 |
| 384-386 | - (removed lines, truncated in the extracted diff)
| 387 | }
| 388 |
| 389 | // important graph nodes
|
| 390 |
ggml_tensor * t_tokens = nullptr;
|
|
@@ -393,36 +481,31 @@ public:
|
|
| 393 |
ggml_tensor * t_embd_pooled = nullptr;
|
| 394 |
|
| 395 |
std::vector<llm_graph_input_ptr> inputs;
|
| 396 |
-
};
|
| 397 |
|
| 398 |
-
|
| 399 |
-
// llm_graph_context
|
| 400 |
-
//
|
| 401 |
|
| 402 |
-
//
|
| 403 |
-
|
| 404 |
|
| 405 |
-
|
| 406 |
-
ggml_context * ctx;
|
| 407 |
|
| 408 |
-
|
| 409 |
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
|
|
|
|
|
|
| 413 |
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
const llama_adapter_cvec * cvec;
|
| 418 |
-
const llama_adapter_loras * loras;
|
| 419 |
-
const llama_memory_context_i * mctx;
|
| 420 |
-
const llama_cross * cross;
|
| 421 |
|
| 422 |
-
|
| 423 |
|
| 424 |
-
|
| 425 |
-
|
|
|
|
| 426 |
|
| 427 |
// used in build_rs to properly order writes and avoid unnecessary copies
|
| 428 |
using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
|
|
@@ -463,8 +546,6 @@ struct llm_graph_context {
|
|
| 463 |
const enum llama_pooling_type pooling_type;
|
| 464 |
const enum llama_rope_type rope_type;
|
| 465 |
|
| 466 |
-
ggml_context * ctx0 = nullptr;
|
| 467 |
-
|
| 468 |
ggml_backend_sched_t sched;
|
| 469 |
|
| 470 |
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
|
@@ -476,7 +557,10 @@ struct llm_graph_context {
|
|
| 476 |
|
| 477 |
const llm_graph_cb & cb_func;
|
| 478 |
|
| 479 |
-
|
|
|
|
|
|
|
|
|
|
| 480 |
|
| 481 |
llm_graph_context(const llm_graph_params & params);
|
| 482 |
virtual ~llm_graph_context() = default;
|
|
@@ -562,7 +646,6 @@ struct llm_graph_context {
|
|
| 562 |
//
|
| 563 |
|
| 564 |
ggml_tensor * build_attn_mha(
|
| 565 |
-
ggml_cgraph * gf,
|
| 566 |
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
|
| 567 |
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
|
| 568 |
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
|
|
@@ -575,7 +658,6 @@ struct llm_graph_context {
|
|
| 575 |
|
| 576 |
ggml_tensor * build_attn(
|
| 577 |
llm_graph_input_attn_no_cache * inp,
|
| 578 |
-
ggml_cgraph * gf,
|
| 579 |
ggml_tensor * wo,
|
| 580 |
ggml_tensor * wo_b,
|
| 581 |
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
|
@@ -590,7 +672,6 @@ struct llm_graph_context {
|
|
| 590 |
|
| 591 |
ggml_tensor * build_attn(
|
| 592 |
llm_graph_input_attn_kv_unified * inp,
|
| 593 |
-
ggml_cgraph * gf,
|
| 594 |
ggml_tensor * wo,
|
| 595 |
ggml_tensor * wo_b,
|
| 596 |
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
|
@@ -606,7 +687,6 @@ struct llm_graph_context {
|
|
| 606 |
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
|
| 607 |
ggml_tensor * build_attn(
|
| 608 |
llm_graph_input_attn_kv_unified_iswa * inp,
|
| 609 |
-
ggml_cgraph * gf,
|
| 610 |
ggml_tensor * wo,
|
| 611 |
ggml_tensor * wo_b,
|
| 612 |
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
|
@@ -621,7 +701,6 @@ struct llm_graph_context {
|
|
| 621 |
|
| 622 |
ggml_tensor * build_attn(
|
| 623 |
llm_graph_input_attn_cross * inp,
|
| 624 |
-
ggml_cgraph * gf,
|
| 625 |
ggml_tensor * wo,
|
| 626 |
ggml_tensor * wo_b,
|
| 627 |
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
|
@@ -643,7 +722,6 @@ struct llm_graph_context {
|
|
| 643 |
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
|
| 644 |
// `llama_memory_recurrent`
|
| 645 |
ggml_tensor * build_rs(
|
| 646 |
-
ggml_cgraph * gf,
|
| 647 |
ggml_tensor * s,
|
| 648 |
ggml_tensor * state_copy,
|
| 649 |
int32_t state_size,
|
|
@@ -658,7 +736,6 @@ struct llm_graph_context {
|
|
| 658 |
|
| 659 |
ggml_tensor * build_rs(
|
| 660 |
llm_graph_input_rs * inp,
|
| 661 |
-
ggml_cgraph * gf,
|
| 662 |
ggml_tensor * s,
|
| 663 |
int32_t state_size,
|
| 664 |
int32_t n_seqs,
|
|
@@ -666,9 +743,8 @@ struct llm_graph_context {
|
|
| 666 |
|
| 667 |
ggml_tensor * build_rwkv_token_shift_load(
|
| 668 |
llm_graph_input_rs * inp,
|
| 669 |
-
ggml_cgraph * gf,
|
| 670 |
const llama_ubatch & ubatch,
|
| 671 |
-
|
| 672 |
|
| 673 |
ggml_tensor * build_rwkv_token_shift_store(
|
| 674 |
ggml_tensor * token_shift,
|
|
@@ -685,7 +761,6 @@ struct llm_graph_context {
|
|
| 685 |
//
|
| 686 |
|
| 687 |
void build_pooling(
|
| 688 |
-
ggml_cgraph * gf,
|
| 689 |
ggml_tensor * cls,
|
| 690 |
ggml_tensor * cls_b,
|
| 691 |
ggml_tensor * cls_out,
|
|
|
|
| 1 |
#pragma once
|
| 2 |
|
| 3 |
#include "llama-arch.h"
|
| 4 |
+
#include "llama-batch.h"
|
| 5 |
#include "llama-hparams.h"
|
| 6 |
#include "llama-adapter.h"
|
| 7 |
|
|
|
|
| 15 |
struct ggml_context;
|
| 16 |
struct ggml_tensor;
|
| 17 |
|
|
|
|
| 18 |
struct llama_cparams;
|
| 19 |
|
| 20 |
struct llama_memory_context_i;
|
|
|
|
| 69 |
std::vector<std::set<llama_seq_id>> seq_ids_enc;
|
| 70 |
};
|
| 71 |
|
| 72 |
+
struct llm_graph_params;
|
| 73 |
+
|
| 74 |
//
|
| 75 |
// llm_graph_input
|
| 76 |
//
|
|
|
|
| 80 |
virtual ~llm_graph_input_i() = default;
|
| 81 |
|
| 82 |
virtual void set_input(const llama_ubatch * ubatch) = 0;
|
| 83 |
+
|
| 84 |
+
// return true if the resulting input tensors using the provided graph parameters would be
|
| 85 |
+
// the same as the previous input tensors that we have currently stored in the object
|
| 86 |
+
virtual bool can_reuse(const llm_graph_params & params) {
|
| 87 |
+
// returning false here by default will prevent from reusing the graph if the check
|
| 88 |
+
// for the input type has not been implemented yet
|
| 89 |
+
GGML_UNUSED(params);
|
| 90 |
+
return false;
|
| 91 |
+
}
|
| 92 |
};
|
| 93 |
|
| 94 |
using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
|
| 95 |
|
|
|
|
| 96 |
class llm_graph_input_embd : public llm_graph_input_i {
|
| 97 |
public:
|
| 98 |
llm_graph_input_embd() = default;
|
|
|
|
| 100 |
|
| 101 |
void set_input(const llama_ubatch * ubatch) override;
|
| 102 |
|
| 103 |
+
bool can_reuse(const llm_graph_params & params) override;
|
| 104 |
+
|
| 105 |
ggml_tensor * tokens = nullptr; // I32 [n_batch]
|
| 106 |
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
|
| 107 |
};
|
|
|
|
| 113 |
|
| 114 |
void set_input(const llama_ubatch * ubatch) override;
|
| 115 |
|
| 116 |
+
bool can_reuse(const llm_graph_params & params) override;
|
| 117 |
+
|
| 118 |
ggml_tensor * pos = nullptr; // I32 [n_batch]
|
| 119 |
|
| 120 |
const uint32_t n_pos_per_embd = 1;
|
|
|
|
| 168 |
llm_graph_input_out_ids(
|
| 169 |
const llama_hparams & hparams,
|
| 170 |
const llama_cparams & cparams,
|
| 171 |
+
uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
|
| 172 |
virtual ~llm_graph_input_out_ids() = default;
|
| 173 |
|
| 174 |
void set_input(const llama_ubatch * ubatch) override;
|
| 175 |
|
| 176 |
+
bool can_reuse(const llm_graph_params & params) override;
|
| 177 |
+
|
| 178 |
ggml_tensor * out_ids; // I32 [n_outputs]
|
| 179 |
|
| 180 |
const llama_hparams & hparams;
|
| 181 |
const llama_cparams & cparams;
|
| 182 |
|
| 183 |
+
const uint32_t n_outputs;
|
| 184 |
};
|
| 185 |
|
| 186 |
class llm_graph_input_mean : public llm_graph_input_i {
|
|
|
|
| 265 |
|
| 266 |
void set_input(const llama_ubatch * ubatch) override;
|
| 267 |
|
| 268 |
+
bool can_reuse(const llm_graph_params & params) override;
|
| 269 |
+
|
| 270 |
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
|
| 271 |
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
|
| 272 |
|
| 273 |
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
| 274 |
|
| 275 |
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
| 276 |
+
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
| 277 |
|
| 278 |
+
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
| 279 |
+
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
| 280 |
|
| 281 |
const llama_hparams & hparams;
|
| 282 |
const llama_cparams & cparams;
|
|
|
|
| 298 |
|
| 299 |
void set_input(const llama_ubatch * ubatch) override;
|
| 300 |
|
| 301 |
+
bool can_reuse(const llm_graph_params & params) override;
|
| 302 |
+
|
| 303 |
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
|
| 304 |
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
|
| 305 |
ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
|
|
|
|
| 309 |
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
| 310 |
|
| 311 |
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
| 312 |
+
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
| 313 |
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
|
| 314 |
+
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
| 315 |
|
| 316 |
+
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
| 317 |
+
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
| 318 |
+
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
| 319 |
+
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
| 320 |
|
| 321 |
const llama_hparams & hparams;
|
| 322 |
const llama_cparams & cparams;
|
|
|
|
| 371 |
// along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
|
| 372 |
// these are used by the llama_context to extact the relevant data, based on the compute parameters
|
| 373 |
|
| 374 |
+
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
|
| 375 |
+
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
|
|
|
|
| 376 |
|
| 377 |
+
class llm_graph_result;
|
|
|
|
|
|
|
|
|
|
| 378 |
|
| 379 |
+
struct llm_graph_params {
|
| 380 |
+
llm_arch arch = LLM_ARCH_UNKNOWN;
|
| 381 |
|
| 382 |
+
llama_hparams hparams;
|
| 383 |
+
llama_cparams cparams;
|
| 384 |
|
| 385 |
+
llama_ubatch ubatch; // note: intentionally make a copy
|
| 386 |
|
| 387 |
+
llm_graph_type gtype;
|
|
|
|
|
|
|
| 388 |
|
| 389 |
+
ggml_backend_sched_t sched;
|
| 390 |
+
ggml_backend_t backend_cpu;
|
|
|
|
|
|
|
| 391 |
|
| 392 |
+
const llama_adapter_cvec * cvec;
|
| 393 |
+
const llama_adapter_loras * loras;
|
| 394 |
+
const llama_memory_context_i * mctx;
|
| 395 |
+
const llama_cross * cross;
|
| 396 |
+
|
| 397 |
+
uint32_t n_outputs;
|
| 398 |
+
|
| 399 |
+
llm_graph_cb cb;
|
| 400 |
+
|
| 401 |
+
llm_graph_result * res;
|
| 402 |
+
|
| 403 |
+
// return true if the "other" params would result in a graph with the same topology as with the current params
|
| 404 |
+
// having the same topology allows us to reuse the graph in some cases
|
| 405 |
+
bool allow_reuse(const llm_graph_params & other) const {
|
| 406 |
+
// first check the ubatch
|
| 407 |
+
bool can_reuse_ubatch =
|
| 408 |
+
ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
|
| 409 |
+
ubatch.n_tokens == other.ubatch.n_tokens &&
|
| 410 |
+
ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
|
| 411 |
+
ubatch.n_seqs == other.ubatch.n_seqs &&
|
| 412 |
+
ubatch.n_seqs_unq == other.ubatch.n_seqs_unq &&
|
| 413 |
+
(
|
| 414 |
+
(!ubatch.token && !other.ubatch.token) ||
|
| 415 |
+
(!ubatch.embd && !other.ubatch.embd)
|
| 416 |
+
);
|
| 417 |
+
|
| 418 |
+
if (can_reuse_ubatch && !ubatch.equal_seqs()) {
|
| 419 |
+
if (!ubatch.data) {
|
| 420 |
+
// if the old ubatch does not own it's data, then we cannot guarantee that it is still alive, and
|
| 421 |
+
// therefore we cannot perform the sequence id check. normally should never happen
|
| 422 |
+
can_reuse_ubatch = false;
|
| 423 |
+
} else {
|
| 424 |
+
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
| 425 |
+
can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
|
| 426 |
+
}
|
| 427 |
+
}
|
| 428 |
}
|
|
|
|
| 429 |
|
| 430 |
+
if (!can_reuse_ubatch) {
|
| 431 |
+
return false;
|
| 432 |
+
}
|
| 433 |
+
|
| 434 |
+
return
|
| 435 |
+
cparams.embeddings == other.cparams.embeddings &&
|
| 436 |
+
cparams.causal_attn == other.cparams.causal_attn &&
|
| 437 |
+
arch == other.arch &&
|
| 438 |
+
gtype == other.gtype &&
|
| 439 |
+
cvec == other.cvec &&
|
| 440 |
+
loras == other.loras &&
|
| 441 |
+
cross == other.cross &&
|
| 442 |
+
n_outputs == other.n_outputs;
|
| 443 |
}
|
| 444 |
+
};
|
| 445 |
+
|
| 446 |
+
class llm_graph_result {
|
| 447 |
+
public:
|
| 448 |
+
llm_graph_result(int64_t max_nodes);
|
| 449 |
+
|
| 450 |
+
virtual ~llm_graph_result() = default;
|
| 451 |
+
|
| 452 |
+
ggml_tensor * get_tokens() const { return t_tokens; }
|
| 453 |
+
ggml_tensor * get_logits() const { return t_logits; }
|
| 454 |
+
ggml_tensor * get_embd() const { return t_embd; }
|
| 455 |
+
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
|
| 456 |
+
|
| 457 |
+
ggml_cgraph * get_gf() const { return gf; }
|
| 458 |
+
ggml_context * get_ctx() const { return ctx_compute.get(); }
|
| 459 |
+
|
| 460 |
+
int64_t get_max_nodes() const;
|
| 461 |
+
|
| 462 |
+
void reset();
|
| 463 |
+
|
| 464 |
+
void set_inputs(const llama_ubatch * ubatch);
|
| 465 |
+
|
| 466 |
+
// try to update the existing graph result using the new graph parameters in order to reuse it
|
| 467 |
+
// this can only be done if we determine that the resulting graph using the new graph parameters
|
| 468 |
+
// would be identical to the existing graph. in that case, we simply have to update the memory
|
| 469 |
+
// contexts of the input tensors of the graph and we can reuse it for another computation
|
| 470 |
+
// return true if the graph was updated and can be reused
|
| 471 |
+
bool can_reuse(const llm_graph_params & params);
|
| 472 |
+
|
| 473 |
+
llm_graph_input_i * add_input(llm_graph_input_ptr input);
|
| 474 |
+
|
| 475 |
+
void set_params(const llm_graph_params & params);
|
| 476 |
|
| 477 |
// important graph nodes
|
| 478 |
ggml_tensor * t_tokens = nullptr;
|
|
|
|
| 481 |
ggml_tensor * t_embd_pooled = nullptr;
|
| 482 |
|
| 483 |
std::vector<llm_graph_input_ptr> inputs;
|
|
|
|
| 484 |
|
| 485 |
+
ggml_context_ptr ctx_compute;
|
|
|
|
|
|
|
| 486 |
|
| 487 |
+
// memory buffers used to evaluate the model
|
| 488 |
+
std::vector<uint8_t> buf_compute_meta;
|
| 489 |
|
| 490 |
+
ggml_cgraph * gf;
|
|
|
|
| 491 |
|
| 492 |
+
int64_t max_nodes;
|
| 493 |
|
| 494 |
+
private:
|
| 495 |
+
// keep a copy of the previous graph parameters
|
| 496 |
+
// we will use this to determine whether the graph can be reused by comparing them with the new parameters
|
| 497 |
+
// note: these are updated after constructing the new graph
|
| 498 |
+
llm_graph_params params;
|
| 499 |
|
| 500 |
+
// env: LLAMA_GRAPH_RESULT_DEBUG
|
| 501 |
+
int debug = 0;
|
| 502 |
+
};
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
|
| 504 |
+
using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
|
| 505 |
|
| 506 |
+
//
|
| 507 |
+
// llm_graph_context
|
| 508 |
+
//
|
| 509 |
|
| 510 |
// used in build_rs to properly order writes and avoid unnecessary copies
|
| 511 |
using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
|
|
|
|
| 546 |
const enum llama_pooling_type pooling_type;
|
| 547 |
const enum llama_rope_type rope_type;
|
| 548 |
|
|
|
|
|
|
|
| 549 |
ggml_backend_sched_t sched;
|
| 550 |
|
| 551 |
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
|
|
|
| 557 |
|
| 558 |
const llm_graph_cb & cb_func;
|
| 559 |
|
| 560 |
+
llm_graph_result * res;
|
| 561 |
+
|
| 562 |
+
ggml_context * ctx0 = nullptr;
|
| 563 |
+
ggml_cgraph * gf = nullptr;
|
| 564 |
|
| 565 |
llm_graph_context(const llm_graph_params & params);
|
| 566 |
virtual ~llm_graph_context() = default;
|
|
|
|
| 646 |
//
|
| 647 |
|
| 648 |
ggml_tensor * build_attn_mha(
|
|
|
|
| 649 |
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
|
| 650 |
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
|
| 651 |
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
|
|
|
|
| 658 |
|
| 659 |
ggml_tensor * build_attn(
|
| 660 |
llm_graph_input_attn_no_cache * inp,
|
|
|
|
| 661 |
ggml_tensor * wo,
|
| 662 |
ggml_tensor * wo_b,
|
| 663 |
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
|
|
|
| 672 |
|
| 673 |
ggml_tensor * build_attn(
|
| 674 |
llm_graph_input_attn_kv_unified * inp,
|
|
|
|
| 675 |
ggml_tensor * wo,
|
| 676 |
ggml_tensor * wo_b,
|
| 677 |
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
|
|
|
| 687 |
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
|
| 688 |
ggml_tensor * build_attn(
|
| 689 |
llm_graph_input_attn_kv_unified_iswa * inp,
|
|
|
|
| 690 |
ggml_tensor * wo,
|
| 691 |
ggml_tensor * wo_b,
|
| 692 |
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
|
|
|
| 701 |
|
| 702 |
ggml_tensor * build_attn(
|
| 703 |
llm_graph_input_attn_cross * inp,
|
|
|
|
| 704 |
ggml_tensor * wo,
|
| 705 |
ggml_tensor * wo_b,
|
| 706 |
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
|
|
|
| 722 |
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
|
| 723 |
// `llama_memory_recurrent`
|
| 724 |
ggml_tensor * build_rs(
|
|
|
|
| 725 |
ggml_tensor * s,
|
| 726 |
ggml_tensor * state_copy,
|
| 727 |
int32_t state_size,
|
|
|
|
| 736 |
|
| 737 |
ggml_tensor * build_rs(
|
| 738 |
llm_graph_input_rs * inp,
|
|
|
|
| 739 |
ggml_tensor * s,
|
| 740 |
int32_t state_size,
|
| 741 |
int32_t n_seqs,
|
|
|
|
| 743 |
|
| 744 |
ggml_tensor * build_rwkv_token_shift_load(
|
| 745 |
llm_graph_input_rs * inp,
|
|
|
|
| 746 |
const llama_ubatch & ubatch,
|
| 747 |
+
int il) const;
|
| 748 |
|
| 749 |
ggml_tensor * build_rwkv_token_shift_store(
|
| 750 |
ggml_tensor * token_shift,
|
|
|
|
| 761 |
//
|
| 762 |
|
| 763 |
void build_pooling(
|
|
|
|
| 764 |
ggml_tensor * cls,
|
| 765 |
ggml_tensor * cls_b,
|
| 766 |
ggml_tensor * cls_out,
|
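The `allow_reuse`/`can_reuse` pair above gates graph reuse on a field-by-field comparison of the graph parameters: if every topology-relevant field matches, the previously built graph can be recycled and only its input tensors need new data. A minimal standalone sketch of the same idea, with an invented `Params` struct standing in for the real `llm_graph_params`:

    #include <cstdint>

    // simplified stand-in for llm_graph_params; illustration only
    struct Params {
        uint32_t n_tokens  = 0;
        uint32_t n_outputs = 0;
        bool     causal    = true;
        const void * memory = nullptr; // identity of the memory context
    };

    // mirrors the shape of llm_graph_params::allow_reuse: two graphs can share
    // a topology only when every topology-relevant field compares equal
    static bool allow_reuse(const Params & a, const Params & b) {
        return a.n_tokens  == b.n_tokens  &&
               a.n_outputs == b.n_outputs &&
               a.causal    == b.causal    &&
               a.memory    == b.memory;
    }

    int main() {
        Params prev{512, 1, true, nullptr};
        Params next{512, 1, true, nullptr};

        // identical params -> skip the expensive graph rebuild
        const bool reuse = allow_reuse(prev, next);
        return reuse ? 0 : 1;
    }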
examples/talk-llama/llama-hparams.cpp
CHANGED
@@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }

+bool llama_hparams::is_n_embd_k_gqa_variable() const {
+    const uint32_t val = n_embd_k_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (val != n_embd_k_gqa(il)) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+bool llama_hparams::is_n_embd_v_gqa_variable() const {
+    const uint32_t val = n_embd_v_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (val != n_embd_v_gqa(il)) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+uint32_t llama_hparams::n_embd_k_gqa_max() const {
+    uint32_t val = n_embd_k_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        val = std::max(val, n_embd_k_gqa(il));
+    }
+
+    return val;
+}
+
+uint32_t llama_hparams::n_embd_v_gqa_max() const {
+    uint32_t val = n_embd_v_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        val = std::max(val, n_embd_v_gqa(il));
+    }
+
+    return val;
+}
+
 uint32_t llama_hparams::n_embd_r() const {
     if (wkv_head_size != 0) {
         // for RWKV models
examples/talk-llama/llama-hparams.h
CHANGED
@@ -6,7 +6,7 @@

 // bump if necessary
 #define LLAMA_MAX_LAYERS  512
-#define LLAMA_MAX_EXPERTS …
+#define LLAMA_MAX_EXPERTS 384 // Kimi-K2

 enum llama_expert_gating_func_type {
     LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,

@@ -98,7 +98,7 @@ struct llama_hparams {
     float rope_freq_scale_train;
     float rope_freq_scale_train_swa;
     uint32_t n_ctx_orig_yarn;
-    float rope_yarn_log_mul;
+    float rope_yarn_log_mul = 0.0f;

     std::array<int, 4> rope_sections;

@@ -191,6 +191,14 @@ struct llama_hparams {
     // dimension of value embeddings across all k-v heads
     uint32_t n_embd_v_gqa(uint32_t il = 0) const;

+    // true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
+    bool is_n_embd_k_gqa_variable() const;
+    bool is_n_embd_v_gqa_variable() const;
+
+    // return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
+    uint32_t n_embd_k_gqa_max() const;
+    uint32_t n_embd_v_gqa_max() const;
+
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
     uint32_t n_embd_r() const;
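The new helpers answer two questions the KV cache needs when layers disagree on head counts: do any layers differ at all, and what is the largest per-layer K/V dimension. A standalone sketch of the same scan, assuming a plain vector of per-layer dims in place of the real `n_embd_k_gqa(il)` calls:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    int main() {
        // hypothetical per-layer K dims; one layer deviates from the rest
        const std::vector<uint32_t> n_embd_k_gqa = {1024, 1024, 512, 1024};

        bool variable = false;
        uint32_t vmax = n_embd_k_gqa[0];
        for (uint32_t v : n_embd_k_gqa) {
            variable |= (v != n_embd_k_gqa[0]); // any mismatch -> variable
            vmax = std::max(vmax, v);           // buffer must fit the widest layer
        }

        // a cache storing all layers in one buffer has to be sized by the max
        return (variable && vmax == 1024) ? 0 : 1;
    }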
examples/talk-llama/llama-kv-cache-unified-iswa.cpp
CHANGED
@@ -18,16 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         bool v_trans,
         bool offload,
         bool swa_full,
+        bool unified,
         uint32_t kv_size,
         uint32_t n_seq_max,
         uint32_t n_ubatch,
-        uint32_t n_pad) : hparams(model.hparams) {
+        uint32_t n_pad) : hparams(model.hparams), unified(unified) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };

     const uint32_t size_base = kv_size;

-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));

     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {

@@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(

     kv_base = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, size_base, n_seq_max, n_pad,
+            v_trans, offload, unified, size_base, n_seq_max, n_pad,
             0, LLAMA_SWA_TYPE_NONE);

     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);

     kv_swa = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, size_swa, n_seq_max, n_pad,
+            v_trans, offload, unified, size_swa, n_seq_max, n_pad,
             hparams.n_swa, hparams.swa_type);
 }

@@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, …

     // first try simple split
     do {
+        if (!unified) {
+            // requires equal splits, so we skip the simple split
+            break;
+        }
+
         balloc.split_reset();

         std::vector<llama_ubatch> ubatches;

@@ -140,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, …

     std::vector<llama_ubatch> ubatches;
     while (true) {
-        auto ubatch = balloc.split_equal(n_ubatch, …
+        auto ubatch = balloc.split_equal(n_ubatch, !unified);

         if (ubatch.n_tokens == 0) {
             break;
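The SWA cache-size expression above is worth unpacking: the sliding window has to be paid for once per stream, so a unified cache reserves `n_swa*n_seq_max` cells while split (per-sequence) streams reserve only `n_swa`. A standalone sketch with invented example numbers; `pad_up` re-implements the rounding that GGML_PAD performs:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // round x up to a multiple of n (same result as GGML_PAD for ggml's
    // power-of-two pad sizes)
    static uint32_t pad_up(uint32_t x, uint32_t n) {
        return ((x + n - 1) / n) * n;
    }

    int main() {
        const uint32_t size_base = 8192; // base (non-SWA) cache size, in cells
        const uint32_t n_swa     = 1024; // sliding-window width
        const uint32_t n_seq_max = 4;
        const uint32_t n_ubatch  = 512;
        const uint32_t n_pad     = 32;
        const bool     unified   = true;

        // mirrors the sizing above: unified streams share one buffer, so the
        // window is reserved for every sequence; split streams pay it once
        const uint32_t size_swa = std::min(size_base,
                pad_up(n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));

        printf("size_swa = %u cells\n", size_swa); // 4608 for these numbers
        return 0;
    }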
examples/talk-llama/llama-kv-cache-unified-iswa.h
CHANGED
@@ -20,6 +20,7 @@ public:
         bool v_trans,
         bool offload,
         bool swa_full,
+        bool unified,
         uint32_t kv_size,
         uint32_t n_seq_max,
         uint32_t n_ubatch,

@@ -68,6 +69,8 @@ public:
 private:
     const llama_hparams & hparams;

+    const bool unified;
+
     std::unique_ptr<llama_kv_cache_unified> kv_base;
     std::unique_ptr<llama_kv_cache_unified> kv_swa;
 };
examples/talk-llama/llama-kv-cache-unified.cpp
CHANGED
@@ -23,13 +23,14 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_type type_v,
         bool      v_trans,
         bool      offload,
+        bool      unified,
         uint32_t  kv_size,
         uint32_t  n_seq_max,
         uint32_t  n_pad,
         uint32_t  n_swa,
         llama_swa_type swa_type) :
     model(model), hparams(model.hparams), v_trans(v_trans),
-    n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
+    …

     GGML_ASSERT(kv_size % n_pad == 0);

@@ -45,7 +46,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     auto it = ctx_map.find(buft);
     if (it == ctx_map.end()) {
         ggml_init_params params = {
-            /*.mem_size   =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
+            …
             /*.mem_buffer =*/ NULL,
             /*.no_alloc   =*/ true,
         };

@@ -64,9 +65,33 @@
         return it->second;
     };

-    …
+    …

     for (uint32_t il = 0; il < n_layer_cache; il++) {
         if (filter && !filter(il)) {

@@ -74,8 +99,9 @@
             continue;
         }

-        const uint32_t …
+        …

         const char * dev_name = "CPU";

@@ -98,14 +124,23 @@
         ggml_tensor * k;
         ggml_tensor * v;

-        k = …
-        v = …
+        …

         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);

+        …
         map_layer_ids[il] = layers.size();
-        …
     }

     // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]

@@ -148,8 +183,8 @@
     const size_t memory_size_k = size_k_bytes();
     const size_t memory_size_v = size_v_bytes();

-    LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-            (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
+    …
             ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
             ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
 }

@@ -158,7 +193,12 @@
     debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;

     const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
-    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+    …

     if (!supports_set_rows) {
         LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);

@@ -166,9 +206,10 @@
 }

 void llama_kv_cache_unified::clear(bool data) {
-    …
+    …

     if (data) {
         for (auto & buf : bufs) {

@@ -178,6 +219,11 @@
 }

 bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    …
     uint32_t new_head = cells.size();

     if (p0 < 0) {

@@ -224,30 +270,94 @@
 }

 void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-    …
         return;
     }
-    if (p1 < …
-    …
+    …
 }

 void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
+    …
     uint32_t new_head = cells.size();

     for (uint32_t i = 0; i < cells.size(); ++i) {

@@ -265,6 +375,11 @@
 }

 void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    …
     if (shift == 0) {
         return;
     }

@@ -304,6 +419,10 @@
 void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    …
     if (d == 1) {
         return;
     }

@@ -333,10 +452,18 @@
 llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
+    …
     return cells.seq_pos_min(seq_id);
 }

 llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
+    …
     return cells.seq_pos_max(seq_id);
 }

@@ -351,7 +478,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(

     std::vector<llama_ubatch> ubatches;
     while (true) {
-        auto ubatch = balloc.split_simple(n_ubatch);
+        …

         if (ubatch.n_tokens == 0) {
             break;

@@ -387,7 +514,10 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
     defrag_info dinfo;

     // see if we need to defrag
-    {
+    …
         bool do_defrag = optimize;

         const auto thold = lctx->get_cparams().defrag_thold;

@@ -411,22 +541,22 @@
         }
     }

-    return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
+    …
 }

 llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
     llama_kv_cache_unified::slot_info_vec_t res;

-    struct …
-        uint32_t head_old; // old position of the head, before placing the ubatch
+    …
         slot_info sinfo;   // slot info for the ubatch
+    …
     };

     // remember the old state of the cells so we can restore it in the end
-    std::vector< …
+    …

     bool success = true;

@@ -445,16 +575,35 @@
         res.push_back(sinfo_new);

         // store the old state of the cells in the recovery stack
-        …
+        …

         // now emplace the ubatch
         apply_ubatch(sinfo_new, ubatch);
     }

+    …
     // iterate backwards and restore the cells to their original state
     for (auto it = states.rbegin(); it != states.rend(); ++it) {
-        …
+        …
     }

     if (!success) {

@@ -464,11 +613,38 @@
     return res;
 }

-bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
+    …
     bool updated = false;

     auto * sched = lctx->get_sched();

+    …
     if (do_shift) {
         if (!get_can_shift()) {
             GGML_ABORT("The current KV cache / model configuration does not support K-shift");

@@ -480,14 +656,11 @@
         if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
             ggml_backend_sched_reset(sched);

-            auto * …
-            if (!res) {
-                LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
-                return updated;
-            }
+            …

             if (!ggml_backend_sched_alloc_graph(sched, gf)) {
                 LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
                 return updated;

@@ -503,12 +676,20 @@
         updated = true;
     }
-    …
 }

     if (!dinfo.empty()) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
+        …
         // apply moves:
         {
             const auto n_kv = dinfo.ids.size();

@@ -529,14 +710,11 @@
         ggml_backend_sched_reset(sched);

-        auto * …
-        if (!res) {
-            LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
-            return updated;
-        }
+        …

         if (!ggml_backend_sched_alloc_graph(sched, gf)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
             return updated;

@@ -556,23 +734,13 @@
 }

 llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
-    …
-    if (head_cur > cells.get_used() + 2*ubatch.n_tokens) {
-        head_cur = 0;
-    }
-
-    if (n_tokens > cells.size()) {
-        LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-        return { };
-    }
-
-    if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
+    …

         if ((debug == 2 && n_swa > 0) || debug > 2) {
             std::string ss;

@@ -629,86 +797,133 @@
         }
     }

-    uint32_t …
-    slot_info res …
+    …

             head_cur = 0;
-            continue;
         }

-        // - the cell is occupied only by one sequence:
-        //   - (disabled) mask causally, if the sequence is the same as the one we are inserting
-        //   - mask SWA, using current max pos for that sequence in the cache
-        // always insert in the cell with minimum pos
-        bool can_use = cells.is_empty(idx);

-        //if (cells.seq_has(idx, seq_id)) {
-        //    can_use = pos_cell >= pos;
-        //}
+        …

-        if ( …
-            idxs.push_back(idx);
-        } else {
             break;
         }
-    }

-    if (idxs.size() < n_tokens) {
-        res.clear();
-    }
+    …

     return res;
 }

@@ -717,41 +932,51 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
-    for ( …
+    …
         seq_pos_max_rm[s] = -1;
     }

-    assert(ubatch.n_tokens == sinfo. …
-    for (uint32_t …
-        assert(cells.seq_count(idx) == 1);
-        const …
-        const llama_pos pos = cells.pos_get(idx);
+    …

     // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
     // will be present in the cache. so we have to purge any position which is less than those we would overwrite
     // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-    for ( …
+    …
         if (seq_pos_max_rm[s] == -1) {
             continue;
         }

         if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
             LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
                     __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);

@@ -761,7 +986,11 @@
     }

     // move the head at the end of the slot
-    …
+    …
 }

 bool llama_kv_cache_unified::get_can_shift() const {

@@ -769,49 +998,91 @@
 }

 uint32_t llama_kv_cache_unified::get_size() const {
+    …
     return cells.size();
 }

 bool llama_kv_cache_unified::get_has_shift() const {
-    …
+    …
 }

 uint32_t llama_kv_cache_unified::get_n_kv() const {
-    …
+    …
 }

-ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
+    …
     const int32_t ikv = map_layer_ids.at(il);

     auto * k = layers[ikv].k;

+    …
             ggml_row_size(k->type, hparams.n_embd_head_k),
-            ggml_row_size(k->type, …
+    …
 }

-ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
+    …
     const int32_t ikv = map_layer_ids.at(il);

     auto * v = layers[ikv].v;

     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
-        return …
-            hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
-            ggml_row_size(v->type, hparams.n_embd_head_v),
-            ggml_row_size(v->type, …
+        …
     }

     // note: v->nb[1] > v->nb[2]
-    return …
-        n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
-        ggml_row_size(v->type, …
-        ggml_row_size(v->type, …
+    …
 }

@@ -825,12 +1096,18 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);

     if (k_idxs && supports_set_rows) {
+        …
         return ggml_set_rows(ctx, k, k_cur, k_idxs);
     }

     // TODO: fallback to old ggml_cpy() method for backwards compatibility
     // will be removed when ggml_set_rows() is adopted by all backends
+    …

     ggml_tensor * k_view = ggml_view_1d(ctx, k,
             n_tokens*n_embd_k_gqa,
             ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());

@@ -843,37 +1120,38 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
     auto * v = layers[ikv].v;

-    const int64_t n_embd_v_gqa = …
-    const int64_t n_tokens …
+    …

     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);

     if (v_idxs && supports_set_rows) {
         if (!v_trans) {
+            …
             return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }

-        // however, above we take advantage that a row of single element is always continuous regardless of the row stride
-        //v_cur = ggml_transpose(ctx, v_cur);
-        //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
-
-        // we broadcast the KV indices n_embd_v_gqa times
-        // v      [1, n_kv,     n_embd_v_gqa]
-        // v_cur  [1, n_tokens, n_embd_v_gqa]
-        // v_idxs [n_tokens, 1, 1]
+        …
         return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }

     // TODO: fallback to old ggml_cpy() method for backwards compatibility
     // will be removed when ggml_set_rows() is adopted by all backends
+    …

     ggml_tensor * v_view = nullptr;

     if (!v_trans) {
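Both cpy_k and cpy_v funnel into `ggml_set_rows(ctx, dst, src, idxs)`, which scatters the rows of `src` into `dst` at the row indices in `idxs`; this is what lets a ubatch be written into non-contiguous KV cells in a single graph op. A plain C++ sketch of the same row-scatter semantics (not the ggml API, and with invented toy sizes):

    #include <cstdint>
    #include <vector>

    // scatter: for each token i, write row src[i] into dst at row idxs[i]
    // (the KV cache relies on exactly this when storing K/V for a ubatch
    // into cells chosen by find_slot; the real op runs on device buffers)
    static void set_rows(std::vector<std::vector<float>> & dst,
                         const std::vector<std::vector<float>> & src,
                         const std::vector<int64_t> & idxs) {
        for (size_t i = 0; i < idxs.size(); ++i) {
            dst[idxs[i]] = src[i];
        }
    }

    int main() {
        std::vector<std::vector<float>> cache(8, std::vector<float>(4, 0.0f)); // 8 cells
        const std::vector<std::vector<float>> k_cur = {{1,1,1,1}, {2,2,2,2}};  // 2 tokens
        const std::vector<int64_t> k_idxs = {5, 2}; // non-contiguous target cells

        set_rows(cache, k_cur, k_idxs);
        return (cache[5][0] == 1.0f && cache[2][0] == 2.0f) ? 0 : 1;
    }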
@@ -904,7 +1182,13 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
 ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;

-    ggml_tensor * v_idxs …
+    …

     ggml_set_input(v_idxs);

@@ -917,12 +1201,17 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     }

     const uint32_t n_tokens = ubatch->n_tokens;
+    …

     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;

-    for ( …
-        …
+    …
     }
 }

@@ -932,12 +1221,48 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     }

     const uint32_t n_tokens = ubatch->n_tokens;
+    …

     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;

-    …
+    …
     }
 }

@@ -947,7 +1272,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;

-    const int64_t n_kv …
+    …

     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.

@@ -961,70 +1295,57 @@
     // xxxxx-----
     // xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
     for (uint32_t h = 0; h < 1; ++h) {
-        for (uint32_t …
+        …

-            float f = 0.0f;

             // mask the token if not the same sequence
-            …
             // mask future tokens
-            …
             // apply SWA if any
-            …
-            if (!masked && hparams.use_alibi) {
-                f = -std::abs(p0 - p1);
-            }

-            if (masked) {
-                f = -INFINITY;
-            }

-            data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
-        }
-    }

-    if (data) {
-        for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-            for (uint32_t j = 0; j < n_kv; ++j) {
-                data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
             }
         }
     }
 }
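Most of the body of set_input_kq_mask was lost in this render, but the surviving comments give its rules: 0.0f for visible cells and -INFINITY for masked ones, with different-sequence cells, future tokens, and the sliding window all folded into one value per cell, plus an optional ALiBi-style bias. A standalone sketch of one mask row under those rules; the exact SWA comparison used upstream may differ, and per-sequence masking is omitted for brevity:

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // builds one row of a KQ mask for a query token at position pos against
    // KV cells at positions cell_pos: mask future tokens, apply the sliding
    // window if any, otherwise optionally apply an ALiBi-style bias
    static std::vector<float> mask_row(int32_t pos, const std::vector<int32_t> & cell_pos,
                                       int32_t n_swa, bool use_alibi) {
        std::vector<float> row(cell_pos.size(), 0.0f);
        for (size_t j = 0; j < cell_pos.size(); ++j) {
            const int32_t p0 = cell_pos[j];

            bool masked = p0 > pos;                              // mask future tokens
            masked = masked || (n_swa > 0 && pos - p0 >= n_swa); // apply SWA if any

            if (masked) {
                row[j] = -INFINITY;
            } else if (use_alibi) {
                row[j] = -std::abs((float)(p0 - pos));           // f = -|p0 - p1|
            }
        }
        return row;
    }

    int main() {
        // cells at positions 0..5, query token at pos 4, window width 3
        const std::vector<int32_t> cell_pos = {0, 1, 2, 3, 4, 5};
        const auto row = mask_row(4, cell_pos, 3, false);

        // 0 and 1 fall out of the window, 5 is in the future, the rest are visible
        return (std::isinf(row[0]) && std::isinf(row[1]) &&
                !std::isinf(row[2]) && std::isinf(row[5])) ? 0 : 1;
    }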
-void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
-    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-
-    int32_t * data = (int32_t *) dst->data;
-
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
-    }
-}
-
 void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     const int64_t n_tokens = ubatch->n_tokens;

+    …
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing

     int32_t * data = (int32_t *) dst->data;

@@ -1129,7 +1450,7 @@ public:

     void set_input(const llama_ubatch * ubatch) override;

-    ggml_tensor * k_shift; // I32 [kv_size]
+    …

     const llama_kv_cache_unified * kv_self;
 };

@@ -1142,20 +1463,20 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
     }
 }

-llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
-    …
-        ggml_cgraph * gf) const {
-    auto res = std::make_unique<llm_graph_result>();
+    …

     const auto & n_embd_head_k = hparams.n_embd_head_k;
   //const auto & n_embd_head_v = hparams.n_embd_head_v;

     auto inp = std::make_unique<llm_graph_input_k_shift>(this);

-    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, …
+    …
     ggml_set_input(inp->k_shift);

+    …
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

@@ -1169,7 +1490,7 @@
     ggml_tensor * k =
         ggml_view_3d(ctx, layer.k,
-            n_embd_head_k, n_head_kv, …
+            …
             ggml_row_size(layer.k->type, n_embd_head_k),
             ggml_row_size(layer.k->type, n_embd_k_gqa),
             0);

@@ -1181,18 +1502,24 @@
     res->add_input(std::move(inp));

-    return …
+    …
 }

-    …
-    auto …
+    …

     const auto & ids = dinfo.ids;

+    …
 #if 0
     // CPU defrag
     //

@@ -1329,10 +1656,14 @@
     //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
 #endif

-    return …
+    …
 }

 llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
+    …
     const uint32_t n_layer = layers.size();

     const uint32_t n_kv = cells.used_max_p1();

@@ -1478,64 +1809,94 @@
 }

 void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
-    …
-    uint32_t cell_count = 0;

-    …
-    uint32_t cell_range_begin = cells.size();

-    …
+    …
         }
     }
-}

-    …
 }

 void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
-    …
-    io.read_to(&cell_count, sizeof(cell_count));

-    …
+    …
         }
-        throw std::runtime_error("failed to restore kv cache");
     }
 }

-void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const …
+    …
     for (uint32_t i = range.first; i < range.second; ++i) {
         std::vector<llama_seq_id> seq_ids;

@@ -1560,7 +1921,9 @@
     }
 }

-void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const …
+    …
     const uint32_t v_trans = this->v_trans ? 1 : 0;
     const uint32_t n_layer = layers.size();

@@ -1576,19 +1939,21 @@
     const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

+    …
     // Write key type
-    const int32_t k_type_i = (int32_t) …
+    …
     io.write(&k_type_i, sizeof(k_type_i));

     // Write row size of key
-    const uint64_t k_size_row = ggml_row_size( …
+    …
     io.write(&k_size_row, sizeof(k_size_row));

     // Read each range of cells of k_size length each into tmp_buf and write out
-    for (const auto & range : …
+    …
         const size_t range_size = range.second - range.first;
         const size_t buf_size = range_size * k_size_row;
-        io.write_tensor( …
+        …
     }
 }

@@ -1598,19 +1963,21 @@
     const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+    …
     // Write value type
-    const int32_t v_type_i = (int32_t) …
+    …
     io.write(&v_type_i, sizeof(v_type_i));

     // Write row size of value
-    const uint64_t v_size_row = ggml_row_size( …
+    …
     io.write(&v_size_row, sizeof(v_size_row));

     // Read each range of cells of v_size length each into tmp_buf and write out
-    for (const auto & range : …
+    …
         const size_t range_size = range.second - range.first;
         const size_t buf_size = range_size * v_size_row;
-        io.write_tensor( …
+        …
     }
 }
 } else {

@@ -1622,12 +1989,14 @@
     const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+    …
     // Write value type
-    const int32_t v_type_i = (int32_t) …
+    …
     io.write(&v_type_i, sizeof(v_type_i));

     // Write element size
-    const uint32_t v_size_el = ggml_type_size( …
+    …
     io.write(&v_size_el, sizeof(v_size_el));

     // Write GQA embedding size

@@ -1636,27 +2005,31 @@
     // For each row, we get the element values of each cell
     for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
         // Read each range of cells of v_size_el length each into tmp_buf and write out
|
| 1638 |
// Read each range of cells of v_size_el length each into tmp_buf and write out
|
| 1639 |
-
for (const auto & range :
|
| 1640 |
const size_t range_size = range.second - range.first;
|
| 1641 |
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
|
| 1642 |
const size_t buf_size = range_size * v_size_el;
|
| 1643 |
-
io.write_tensor(
|
| 1644 |
}
|
| 1645 |
}
|
| 1646 |
}
|
| 1647 |
}
|
| 1648 |
}
|
| 1649 |
|
| 1650 |
-
-bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
    if (dest_seq_id != -1) {
        // single sequence
-
        seq_rm(dest_seq_id, -1, -1);

        llama_batch_allocr balloc(hparams.n_pos_per_embd());

        llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);

        for (uint32_t i = 0; i < cell_count; ++i) {
            llama_pos pos;
            uint32_t n_seq_id;

@@ -1693,6 +2066,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
        // keep the head at the old position because we will read the KV data into it in state_read_data()
        head = head_cur;

        // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
        // Assume that this is one contiguous block of cells
        GGML_ASSERT(head_cur + cell_count <= cells.size());

@@ -1738,7 +2113,10 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
    return true;
}

-bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
    uint32_t v_trans;
    uint32_t n_layer;

@@ -1766,10 +2144,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

        // Read type of key
        int32_t k_type_i_ref;
        io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
-       const int32_t k_type_i = (int32_t)layer.k->type;
        if (k_type_i != k_type_i_ref) {
            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
            return false;

@@ -1778,7 +2158,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
        // Read row size of key
        uint64_t k_size_row_ref;
        io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-       const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
        if (k_size_row != k_size_row_ref) {
            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
            return false;

@@ -1786,7 +2166,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
        if (cell_count) {
            // Read and set the keys for the whole cell range
-           ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
        }
    }

@@ -1796,10 +2176,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

        // Read type of value
        int32_t v_type_i_ref;
        io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-       const int32_t v_type_i = (int32_t)layer.v->type;
        if (v_type_i != v_type_i_ref) {
            LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
            return false;

@@ -1808,7 +2190,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
        // Read row size of value
        uint64_t v_size_row_ref;
        io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
-       const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
        if (v_size_row != v_size_row_ref) {
            LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
            return false;

@@ -1816,7 +2198,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
        if (cell_count) {
            // Read and set the values for the whole cell range
-           ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
        }
    }
} else {

@@ -1826,10 +2208,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

        // Read type of value
        int32_t v_type_i_ref;
        io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-       const int32_t v_type_i = (int32_t)layer.v->type;
        if (v_type_i != v_type_i_ref) {
            LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
            return false;

@@ -1838,7 +2222,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
        // Read element size of value
        uint32_t v_size_el_ref;
        io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
-       const size_t v_size_el = ggml_type_size(layer.v->type);
        if (v_size_el != v_size_el_ref) {
            LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
            return false;

@@ -1856,7 +2240,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
        // For each row in the transposed matrix, read the values for the whole cell range
        for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
            const size_t dst_offset = (head + j * cells.size()) * v_size_el;
-           ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
        }
    }
}
@@ -1875,18 +2259,26 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
        llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
    n_kv = kv->get_size();

    // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
    sinfos.resize(1);
-   sinfos[0].idxs.resize(1);
-   sinfos[0].idxs[0] = 0;
}

llama_kv_cache_unified_context::llama_kv_cache_unified_context(
        llama_kv_cache_unified * kv,
        llama_context * lctx,
        bool do_shift,
-       defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
-   if (!do_shift && this->dinfo.empty()) {
        status = LLAMA_MEMORY_STATUS_NO_UPDATE;
    }
}

@@ -1914,7 +2306,7 @@ bool llama_kv_cache_unified_context::apply() {
    // no ubatches -> this is a KV cache update
    if (ubatches.empty()) {
-       kv->update(lctx, do_shift, dinfo);

        return true;
    }

@@ -1940,12 +2332,16 @@ uint32_t llama_kv_cache_unified_context::get_n_kv() const {
    return n_kv;
}

ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
-   return kv->get_k(ctx, il, n_kv);
}

ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
-   return kv->get_v(ctx, il, n_kv);
}

ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
@@ +23 @@
            ggml_type   type_v,
                 bool   v_trans,
                 bool   offload,
+                bool   unified,
             uint32_t   kv_size,
             uint32_t   n_seq_max,
             uint32_t   n_pad,
             uint32_t   n_swa,
       llama_swa_type   swa_type) :
    model(model), hparams(model.hparams), v_trans(v_trans),
+   n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {

    GGML_ASSERT(kv_size % n_pad == 0);

@@ +46 @@
    auto it = ctx_map.find(buft);
    if (it == ctx_map.end()) {
        ggml_init_params params = {
+           /*.mem_size   =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,
        };

@@ +65 @@
        return it->second;
    };

+   GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
+
+   v_heads.resize(n_stream);
+   for (uint32_t s = 0; s < n_stream; ++s) {
+       v_heads[s] = 0;
+   }
+
+   v_cells.resize(n_stream);
+   for (uint32_t s = 0; s < n_stream; ++s) {
+       v_cells[s].resize(kv_size);
+   }
+
+   // by default, all sequence ids are mapped to the 0th stream
+   seq_to_stream.resize(LLAMA_MAX_SEQ, 0);
+
+   if (n_stream > 1) {
+       seq_to_stream.resize(n_stream, 0);
+       for (uint32_t s = 0; s < n_stream; ++s) {
+           seq_to_stream[s] = s;
+       }
+   }
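The mapping above is the crux of the unified/split distinction: a unified cache puts every sequence on stream 0, while a split cache gives each sequence its own stream. A minimal standalone sketch of the same logic (the helper name `make_seq_to_stream` is hypothetical, not part of llama.cpp):

```cpp
#include <cstdint>
#include <vector>

// Standalone sketch of the seq_id -> stream mapping built above. With a
// unified cache (n_stream == 1) every sequence resolves to stream 0; with a
// split cache, sequence s owns stream s.
static std::vector<uint32_t> make_seq_to_stream(bool unified, uint32_t n_seq_max) {
    const uint32_t n_stream = unified ? 1 : n_seq_max;

    std::vector<uint32_t> seq_to_stream(n_seq_max, 0); // default: everything on stream 0

    if (n_stream > 1) {
        for (uint32_t s = 0; s < n_stream; ++s) {
            seq_to_stream[s] = s; // identity mapping: one stream per sequence
        }
    }

    return seq_to_stream;
}
```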
+
+   // [TAG_V_CACHE_VARIABLE]
+   if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
+       LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
+               __func__, hparams.n_embd_v_gqa_max());
+   }

    for (uint32_t il = 0; il < n_layer_cache; il++) {
        if (filter && !filter(il)) {
            continue;
        }

+       // [TAG_V_CACHE_VARIABLE]
+       const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+       const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();

        const char * dev_name = "CPU";

@@ +124 @@
        ggml_tensor * k;
        ggml_tensor * v;

+       k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
+       v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);

        ggml_format_name(k, "cache_k_l%d", il);
        ggml_format_name(v, "cache_v_l%d", il);

+       std::vector<ggml_tensor *> k_stream;
+       std::vector<ggml_tensor *> v_stream;
+
+       for (uint32_t s = 0; s < n_stream; ++s) {
+           k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
+           v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
+       }
+
        map_layer_ids[il] = layers.size();
+
+       layers.push_back({ il, k, v, k_stream, v_stream, });
    }

    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]

@@ +183 @@
    const size_t memory_size_k = size_k_bytes();
    const size_t memory_size_v = size_v_bytes();

+   LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+           (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
            ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
            ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}

@@ +193 @@
    debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;

    const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+   supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : 0;
+
+   if (!supports_set_rows) {
+       // ref: https://github.com/ggml-org/llama.cpp/pull/14363
+       GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support");
+   }

    if (!supports_set_rows) {
        LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);

@@ +206 @@
    }
}

void llama_kv_cache_unified::clear(bool data) {
+   for (uint32_t s = 0; s < n_stream; ++s) {
+       v_cells[s].reset();
+       v_heads[s] = 0;
+   }

    if (data) {
        for (auto & buf : bufs) {

@@ +219 @@
    }
}
| 222 |
+
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
|
| 223 |
+
|
| 224 |
+
auto & cells = v_cells[seq_to_stream[seq_id]];
|
| 225 |
+
auto & head = v_heads[seq_to_stream[seq_id]];
|
| 226 |
+
|
| 227 |
uint32_t new_head = cells.size();
|
| 228 |
|
| 229 |
if (p0 < 0) {
|
|
|
|
| 270 |
}
|
| 271 |
|
| 272 |
void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
| 273 |
+
GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
|
| 274 |
+
GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
|
| 275 |
+
|
| 276 |
+
const auto s0 = seq_to_stream[seq_id_src];
|
| 277 |
+
const auto s1 = seq_to_stream[seq_id_dst];
|
| 278 |
+
|
| 279 |
+
if (s0 == s1) {
|
| 280 |
+
// since both sequences are in the same stream, no data copy is necessary
|
| 281 |
+
// we just have to update the cells meta data
|
| 282 |
+
|
| 283 |
+
auto & cells = v_cells[s0];
|
| 284 |
+
|
| 285 |
+
if (seq_id_src == seq_id_dst) {
|
| 286 |
+
return;
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
if (p0 < 0) {
|
| 290 |
+
p0 = 0;
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
if (p1 < 0) {
|
| 294 |
+
p1 = std::numeric_limits<llama_pos>::max();
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 298 |
+
if (!cells.pos_in(i, p0, p1)) {
|
| 299 |
+
continue;
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
if (cells.seq_has(i, seq_id_src)) {
|
| 303 |
+
cells.seq_add(i, seq_id_dst);
|
| 304 |
+
}
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
return;
|
| 308 |
}
|
| 309 |
|
| 310 |
+
// cross-stream sequence copies require to copy the actual buffer data
|
| 311 |
+
|
| 312 |
+
bool is_full = true;
|
| 313 |
+
|
| 314 |
+
if (p0 > 0 && p0 + 1 < (int) get_size()) {
|
| 315 |
+
is_full = false;
|
| 316 |
}
|
| 317 |
|
| 318 |
+
if (p1 > 0 && p1 + 1 < (int) get_size()) {
|
| 319 |
+
is_full = false;
|
| 320 |
}
|
| 321 |
|
| 322 |
+
GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
|
| 323 |
+
|
| 324 |
+
// enqueue the copy operation - the buffer copy will be performed during the next update
|
| 325 |
+
sc_info.ssrc.push_back(s0);
|
| 326 |
+
sc_info.sdst.push_back(s1);
|
| 327 |
|
| 328 |
+
v_cells[s1].reset();
|
| 329 |
+
for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
|
| 330 |
+
if (v_cells[s0].seq_has(i, seq_id_src)) {
|
| 331 |
+
llama_pos pos = v_cells[s0].pos_get(i);
|
| 332 |
+
llama_pos shift = v_cells[s0].get_shift(i);
|
| 333 |
+
|
| 334 |
+
if (shift != 0) {
|
| 335 |
+
pos -= shift;
|
| 336 |
+
assert(pos >= 0);
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
v_cells[s1].pos_set(i, pos);
|
| 340 |
+
v_cells[s1].seq_add(i, seq_id_dst);
|
| 341 |
+
|
| 342 |
+
if (shift != 0) {
|
| 343 |
+
v_cells[s1].pos_add(i, shift);
|
| 344 |
+
}
|
| 345 |
}
|
| 346 |
}
|
| 347 |
+
|
| 348 |
+
v_heads[s1] = v_heads[s0];
|
| 349 |
+
|
| 350 |
+
//for (uint32_t s = 0; s < n_stream; ++s) {
|
| 351 |
+
// LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
|
| 352 |
+
//}
|
| 353 |
}
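seq_cp() thus distinguishes a cheap metadata-only copy (both sequences on the same stream) from a real buffer copy across streams, which is only queued here and executed later in update(). A sketch of the queue it relies on, with field names inferred from the calls above (the real definition lives in the llama.cpp headers):

```cpp
#include <cstdint>
#include <vector>

// Sketch of the copy queue that seq_cp() appends to. Each pending copy is a
// (source stream, destination stream) pair; update() drains the queue by
// copying the per-stream K/V tensor views.
struct stream_copy_info_sketch {
    std::vector<uint32_t> ssrc; // source stream of each pending copy
    std::vector<uint32_t> sdst; // destination stream of each pending copy

    bool empty() const {
        return ssrc.empty();
    }

    void enqueue(uint32_t src, uint32_t dst) {
        ssrc.push_back(src);
        sdst.push_back(dst);
    }
};
```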
@@ +355 @@
void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
+   GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+   auto & cells = v_cells[seq_to_stream[seq_id]];
+   auto & head  = v_heads[seq_to_stream[seq_id]];
+
    uint32_t new_head = cells.size();

    for (uint32_t i = 0; i < cells.size(); ++i) {

@@ +375 @@
    }

void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+   GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+   auto & cells = v_cells[seq_to_stream[seq_id]];
+   auto & head  = v_heads[seq_to_stream[seq_id]];
+
    if (shift == 0) {
        return;
    }

@@ +419 @@
    }

void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+   GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+   auto & cells = v_cells[seq_to_stream[seq_id]];
+
    if (d == 1) {
        return;
    }

@@ +452 @@
    }

llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
+   GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+   const auto & cells = v_cells[seq_to_stream[seq_id]];
+
    return cells.seq_pos_min(seq_id);
}

llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
+   GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+   const auto & cells = v_cells[seq_to_stream[seq_id]];
+
    return cells.seq_pos_max(seq_id);
}

@@ +478 @@
    std::vector<llama_ubatch> ubatches;
    while (true) {
+       auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);

        if (ubatch.n_tokens == 0) {
            break;

@@ +514 @@
    defrag_info dinfo;

    // see if we need to defrag
+   if (n_stream == 1) {
+       // note : for now do not consider defrag for n_stream > 1
+       const auto & cells = v_cells[seq_to_stream[0]];
+
        bool do_defrag = optimize;

        const auto thold = lctx->get_cparams().defrag_thold;

@@ +541 @@
        }
    }

+   return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
}

@@ +547 @@
llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
    llama_kv_cache_unified::slot_info_vec_t res;

+   struct state_t {
        slot_info sinfo; // slot info for the ubatch

+       std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
+
+       std::vector<llama_kv_cells_unified> v_cells; // copy of the old cells, before placing the ubatch
    };

    // remember the old state of the cells so we can restore it in the end
+   std::vector<state_t> states;

    bool success = true;

@@ +575 @@
        res.push_back(sinfo_new);

        // store the old state of the cells in the recovery stack
+       {
+           state_t state = { sinfo_new, v_heads, {} };
+
+           for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
+               auto & cells = v_cells[sinfo_new.strm[s]];
+
+               state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
+           }
+
+           states.push_back(std::move(state));
+       }

        // now emplace the ubatch
        apply_ubatch(sinfo_new, ubatch);
    }

+   GGML_ASSERT(!states.empty() || !success);
+
    // iterate backwards and restore the cells to their original state
    for (auto it = states.rbegin(); it != states.rend(); ++it) {
+       const auto & sinfo = it->sinfo;
+
+       for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+           auto & cells = v_cells[sinfo.strm[s]];
+           auto & head  = v_heads[sinfo.strm[s]];
+
+           cells.set(sinfo.idxs[s], it->v_cells[s]);
+           head = it->v_heads_old[s];
+       }
    }

    if (!success) {

@@ +613 @@
    return res;
}
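prepare() is effectively a dry run: it speculatively places every ubatch, records how to undo each placement, and always rolls everything back before returning, so the caller observes an unchanged cache. A compact sketch of that place-then-rollback pattern under assumed types (`Cache` and `try_place` are stand-ins; the real code snapshots only the touched cells via cells.cp()/cells.set(), not the whole cache):

```cpp
#include <vector>

// Sketch of the speculative placement pattern used by prepare(): mutate,
// remember how to undo, and always undo before returning.
template <typename Cache, typename Batch>
bool fits(Cache & cache, const std::vector<Batch> & batches) {
    std::vector<Cache> saved; // recovery stack
    bool success = true;

    for (const auto & b : batches) {
        saved.push_back(cache);    // remember the old state
        if (!cache.try_place(b)) { // emplace the ubatch
            success = false;
            break;
        }
    }

    // iterate backwards and restore the original state; the last assignment
    // re-installs the very first snapshot
    for (auto it = saved.rbegin(); it != saved.rend(); ++it) {
        cache = *it;
    }

    return success;
}
```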
@@ +616 @@
+bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
    bool updated = false;

    auto * sched = lctx->get_sched();

+   if (!sc_info.empty()) {
+       assert(n_stream > 1 && "stream copy should never happen with a single stream");
+
+       llama_synchronize(lctx);
+
+       const size_t n_copy = sc_info.ssrc.size();
+
+       for (size_t i = 0; i < n_copy; ++i) {
+           const auto ssrc = sc_info.ssrc[i];
+           const auto sdst = sc_info.sdst[i];
+
+           assert(ssrc < n_stream);
+           assert(sdst < n_stream);
+
+           LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);
+
+           assert(ssrc != sdst);
+
+           for (uint32_t il = 0; il < layers.size(); ++il) {
+               const auto & layer = layers[il];
+
+               ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
+               ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
+           }
+       }
+   }
+
    if (do_shift) {
        if (!get_can_shift()) {
            GGML_ABORT("The current KV cache / model configuration does not support K-shift");

@@ +656 @@
        if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
            ggml_backend_sched_reset(sched);

+           auto * res = lctx->get_gf_res_reserve();

+           res->reset();

+           auto * gf = build_graph_shift(res, lctx);
            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
                return updated;

@@ +676 @@
            updated = true;
        }

+       for (uint32_t s = 0; s < n_stream; ++s) {
+           auto & cells = v_cells[s];
+
+           cells.reset_shift();
+       }
    }

    if (!dinfo.empty()) {
        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

+       // note: for now do not consider defrag for n_stream > 1
+       auto & cells = v_cells[seq_to_stream[0]];
+       auto & head  = v_heads[seq_to_stream[0]];
+
        // apply moves:
        {
            const auto n_kv = dinfo.ids.size();

@@ +710 @@

        ggml_backend_sched_reset(sched);

+       auto * res = lctx->get_gf_res_reserve();

+       res->reset();

+       auto * gf = build_graph_defrag(res, lctx, dinfo);
        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
            LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
            return updated;

@@ +734 @@
}
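The shift handled above is pure bookkeeping until this point: seq_add() only adjusts the stored positions and accumulates a per-cell delta, and the K-shift graph later re-applies RoPE to the cached K vectors by that delta. A sketch of the assumed per-cell accounting (simplified field names, not the llama_kv_cells_unified API):

```cpp
#include <cassert>

// Sketch of the deferred-shift bookkeeping behind the K-shift: position edits
// are recorded immediately, while the cached K vectors are fixed up later in
// one ggml graph, after which the pending delta is cleared.
struct cell_sketch {
    int pos   = 0; // logical position of the cached token
    int shift = 0; // pending rotation still to be applied to its K vector
};

void seq_add_sketch(cell_sketch & c, int delta) {
    c.pos   += delta;
    c.shift += delta; // deferred: consumed by the K-shift graph
}

void after_k_shift_graph(cell_sketch & c) {
    // the graph re-rotated K by c.shift, so nothing is pending anymore
    c.shift = 0;
    assert(c.pos >= 0);
}
```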
@@ +736 @@
llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
+   if (debug > 0) {
+       const auto & cells = v_cells[seq_to_stream[1]];

+       const uint32_t head_cur = v_heads[1];

+       LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
+               __func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);

@@ +745 @@
        if ((debug == 2 && n_swa > 0) || debug > 2) {
            std::string ss;

@@ +797 @@
        }
    }

+   uint32_t n_tokens = ubatch.n_tokens;
+   uint32_t n_seqs   = 1;
+
+   if (n_stream > 1) {
+       GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);

+       n_seqs   = ubatch.n_seqs_unq;
+       n_tokens = n_tokens / n_seqs;
+   }

+   slot_info res = {
+       /*.s0   =*/ LLAMA_MAX_SEQ,
+       /*.s1   =*/ 0,
+       /*.strm =*/ { },
+       /*.idxs =*/ { },
+   };

+   res.resize(n_seqs);

+   for (uint32_t s = 0; s < n_seqs; ++s) {
+       const auto seq_id = ubatch.seq_id_unq[s];

+       if (n_stream > 1) {
+           GGML_ASSERT(ubatch.n_seq_id[s*n_tokens]    == 1);
+           GGML_ASSERT(ubatch.seq_id  [s*n_tokens][0] == seq_id);
+       }
+
+       res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
+       res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);
+
+       res.strm[s] = seq_to_stream[seq_id];
+       res.idxs[s].reserve(n_tokens);
+
+       const auto & cells = v_cells[seq_to_stream[seq_id]];
+
+       uint32_t head_cur = v_heads[seq_to_stream[seq_id]];
+
+       // if we have enough unused cells before the current head ->
+       //   better to start searching from the beginning of the cache, hoping to fill it
+       if (head_cur > cells.get_used() + 2*n_tokens) {
            head_cur = 0;
        }

+       if (n_tokens > cells.size()) {
+           LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
+           return { };
+       }
+
+       uint32_t n_tested = 0;
+
+       // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
+       // for non-continuous slots, we test the tokens one by one
+       const uint32_t n_test = cont ? n_tokens : 1;

+       while (true) {
+           if (head_cur + n_test > cells.size()) {
+               n_tested += cells.size() - head_cur;
+               head_cur = 0;
+               continue;
+           }

+           for (uint32_t i = 0; i < n_test; i++) {
+               const auto idx = head_cur;

+               head_cur++;
+               n_tested++;

+               //const llama_pos    pos    = ubatch.pos[i];
+               //const llama_seq_id seq_id = ubatch.seq_id[i][0];

+               // can we use this cell? either:
+               // - the cell is empty
+               // - the cell is occupied only by one sequence:
+               //   - (disabled) mask causally, if the sequence is the same as the one we are inserting
+               //   - mask SWA, using current max pos for that sequence in the cache
+               //     always insert in the cell with minimum pos
+               bool can_use = cells.is_empty(idx);

+               if (!can_use && cells.seq_count(idx) == 1) {
+                   const llama_pos pos_cell = cells.pos_get(idx);
+
+                   // (disabled) causal mask
+                   // note: it's better to purge any "future" tokens beforehand
+                   //if (cells.seq_has(idx, seq_id)) {
+                   //    can_use = pos_cell >= pos;
+                   //}
+
+                   if (!can_use) {
+                       const llama_seq_id seq_id_cell = cells.seq_get(idx);
+
+                       // SWA mask
+                       if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
+                           can_use = true;
+                       }
                    }
                }

+               if (can_use) {
+                   res.idxs[s].push_back(idx);
+               } else {
+                   if (cont) {
+                       break;
+                   }
+               }
+           }

+           if (res.idxs[s].size() == n_tokens) {
                break;
            }

+           if (cont) {
+               res.idxs[s].clear();
+           }

+           if (n_tested >= cells.size()) {
+               //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+               return { };
+           }
        }

+       // we didn't find a suitable slot - return empty result
+       if (res.idxs[s].size() < n_tokens) {
            return { };
        }
    }

+   assert(res.s1 >= res.s0);

    return res;
}
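find_slot() scans the cell array as a ring buffer starting at the head, either requiring one contiguous run of usable cells (cont) or collecting them one by one, and gives up once every cell has been tested. A minimal standalone version of that search over a plain occupancy bitmap (the real code additionally admits occupied cells whose contents are masked out by SWA):

```cpp
#include <cstdint>
#include <vector>

// Minimal sketch of the ring-buffer slot search above, reduced to a plain
// free/used bitmap.
static std::vector<uint32_t> find_slot_sketch(const std::vector<bool> & used, uint32_t head, uint32_t n_tokens, bool cont) {
    const uint32_t size = (uint32_t) used.size();

    if (n_tokens > size) {
        return {}; // can never fit
    }

    std::vector<uint32_t> idxs;
    idxs.reserve(n_tokens);

    uint32_t n_tested = 0;
    const uint32_t n_test = cont ? n_tokens : 1; // contiguous run vs single cells

    while (true) {
        if (head + n_test > size) { // wrap around instead of spilling past the end
            n_tested += size - head;
            head = 0;
            continue;
        }

        for (uint32_t i = 0; i < n_test; i++) {
            const uint32_t idx = head++;
            n_tested++;

            if (!used[idx]) {
                idxs.push_back(idx);
            } else if (cont) {
                break; // a contiguous run is broken by any used cell
            }
        }

        if (idxs.size() == n_tokens) {
            return idxs;
        }

        if (cont) {
            idxs.clear(); // restart the run from the next head position
        }

        if (n_tested >= size) {
            return {}; // scanned the whole ring without success
        }
    }
}
```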
@@ +932 @@
    // keep track of the max sequence position that we would overwrite with this ubatch
    // for non-SWA cache, this would be always empty
    llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
+   for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
        seq_pos_max_rm[s] = -1;
    }

+   assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());

+   for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+       for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
+           const uint32_t i = s*sinfo.size() + ii;

+           auto & cells = v_cells[sinfo.strm[s]];

+           const auto idx = sinfo.idxs[s][ii];

+           if (!cells.is_empty(idx)) {
+               assert(cells.seq_count(idx) == 1);

+               const llama_seq_id seq_id = cells.seq_get(idx);
+               const llama_pos    pos    = cells.pos_get(idx);

+               seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
+               cells.rm(idx);
+           }

+           cells.pos_set(idx, ubatch.pos[i]);
+
+           for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
+               cells.seq_add(idx, ubatch.seq_id[i][s]);
+           }
        }
    }

    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
    //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
    // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
+   for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
        if (seq_pos_max_rm[s] == -1) {
            continue;
        }

+       GGML_ASSERT(s < seq_to_stream.size());
+
+       auto & cells = v_cells[seq_to_stream[s]];
+
        if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
            LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);

@@ +986 @@
    }

    // move the head at the end of the slot
+   for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+       auto & head = v_heads[sinfo.strm[s]];
+
+       head = sinfo.idxs[s].back() + 1;
+   }
}
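A worked example of the purge invariant enforced above: each sequence must keep one contiguous [pos_min, pos_max] range in the cache, so when cells holding some positions of a sequence are overwritten, everything at or below the largest overwritten position is dropped too.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Toy check of the contiguity rule: sequence had positions {3,4,5,6,7}
// cached; overwriting the cells that held 4 and 5 forces 3 out as well,
// leaving the still-contiguous {6,7}.
int main() {
    std::vector<int> seq_pos = {3, 4, 5, 6, 7}; // cached positions of one sequence

    const int seq_pos_max_rm = 5; // largest position among the overwritten cells

    // purge every position <= the largest overwritten one
    seq_pos.erase(std::remove_if(seq_pos.begin(), seq_pos.end(),
            [&](int p) { return p <= seq_pos_max_rm; }), seq_pos.end());

    assert(seq_pos.front() == 6 && seq_pos.back() == 7);
    return 0;
}
```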
@@ +996 @@
bool llama_kv_cache_unified::get_can_shift() const {

@@ +998 @@
}

uint32_t llama_kv_cache_unified::get_size() const {
+   const auto & cells = v_cells[seq_to_stream[0]];
+
    return cells.size();
}

+uint32_t llama_kv_cache_unified::get_n_stream() const {
+   return n_stream;
+}

bool llama_kv_cache_unified::get_has_shift() const {
+   bool result = false;
+
+   for (uint32_t s = 0; s < n_stream; ++s) {
+       result |= v_cells[s].get_has_shift();
+   }
+
+   return result;
}

uint32_t llama_kv_cache_unified::get_n_kv() const {
+   uint32_t result = 0;
+
+   for (uint32_t s = 0; s < n_stream; ++s) {
+       const auto & cells = v_cells[s];
+
+       result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
+   }
+
+   return result;
}
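get_n_kv() rounds the highest used cell index up to the padding granularity and clamps it to the cache size, taking the maximum across streams; for example, with n_pad = 32 and 70 used cells the attention operates over 96 cells. A check of the arithmetic under the GGML_PAD definition from ggml.h:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// GGML_PAD as defined in ggml.h: round x up to a multiple of n.
#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

int main() {
    const uint32_t size        = 4096; // cells.size()
    const uint32_t n_pad       = 32;
    const uint32_t used_max_p1 = 70;   // highest used cell index + 1

    // same expression as in get_n_kv() above, for a single stream
    const uint32_t n_kv = std::min(size, std::max(n_pad, GGML_PAD(used_max_p1, n_pad)));

    assert(n_kv == 96); // 70 rounded up to a multiple of 32
    return 0;
}
```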
@@ +1032 @@
+bool llama_kv_cache_unified::get_supports_set_rows() const {
+   return supports_set_rows;
+}

+ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
    const int32_t ikv = map_layer_ids.at(il);

    auto * k = layers[ikv].k;

+   const uint64_t kv_size      = get_size();
+   const uint64_t n_embd_k_gqa = k->ne[0];
+
+   assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
+
+   const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
+   return ggml_view_4d(ctx, k,
+           hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
            ggml_row_size(k->type, hparams.n_embd_head_k),
+           ggml_row_size(k->type, n_embd_k_gqa),
+           ggml_row_size(k->type, n_embd_k_gqa*kv_size),
+           ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
}
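get_k() exposes the backing [n_embd_k_gqa, kv_size, n_stream] K tensor as a 4-D view [head_dim, n_head_kv, n_kv, ns], with byte strides stepping within a cell, across cells, and across streams, and the data pointer offset by the first covered stream sinfo.s0. A small check of the stride arithmetic, assuming f32 cells for simplicity:

```cpp
#include <cassert>
#include <cstdint>

// Element (d, h, cell, stream) of the 4-D view must land on the same byte as
// element (h*head_dim + d, cell, s0 + stream) of the backing tensor.
int main() {
    const uint64_t head_dim  = 128;
    const uint64_t n_head_kv = 8;
    const uint64_t kv_size   = 4096;
    const uint64_t esz       = 4;   // sizeof(float)
    const uint64_t s0        = 2;   // first stream covered by the view

    const uint64_t n_embd_k_gqa = head_dim*n_head_kv;

    // strides used by the view: one head row, one cell, one stream
    const uint64_t nb1 = esz*head_dim;
    const uint64_t nb2 = esz*n_embd_k_gqa;
    const uint64_t nb3 = esz*n_embd_k_gqa*kv_size;

    const uint64_t d = 5, h = 3, cell = 100, stream = 1;

    const uint64_t view_off    = nb3*s0 + stream*nb3 + cell*nb2 + h*nb1 + d*esz;
    const uint64_t backing_off = esz*((s0 + stream)*n_embd_k_gqa*kv_size + cell*n_embd_k_gqa + (h*head_dim + d));

    assert(view_off == backing_off);
    return 0;
}
```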
@@ +1056 @@
+ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
    const int32_t ikv = map_layer_ids.at(il);

    auto * v = layers[ikv].v;

+   const uint64_t kv_size      = get_size();
+   const uint64_t n_embd_v_gqa = v->ne[0];
+
+   // [TAG_V_CACHE_VARIABLE]
+   assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
+
+   const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
    if (!v_trans) {
        // note: v->nb[1] <= v->nb[2]
+       return ggml_view_4d(ctx, v,
+               hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
+               ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
+               ggml_row_size(v->type, n_embd_v_gqa),             // v->nb[2]
+               ggml_row_size(v->type, n_embd_v_gqa*kv_size),     // v->nb[3]
+               ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
    }

    // note: v->nb[1] > v->nb[2]
+   return ggml_view_4d(ctx, v,
+           n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
+           ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
+           ggml_row_size(v->type, kv_size),                       // v->nb[2]
+           ggml_row_size(v->type, kv_size*n_embd_v_gqa),          // v->nb[3]
+           ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
}

ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {

@@ +1096 @@
    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);

    if (k_idxs && supports_set_rows) {
+       if (k->ne[2] > 1) {
+           k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+       }
+
        return ggml_set_rows(ctx, k, k_cur, k_idxs);
    }

    // TODO: fallback to old ggml_cpy() method for backwards compatibility
    //       will be removed when ggml_set_rows() is adopted by all backends

+   GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
+
    ggml_tensor * k_view = ggml_view_1d(ctx, k,
            n_tokens*n_embd_k_gqa,
            ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());

@@ +1120 @@

    auto * v = layers[ikv].v;

+   const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
+   const int64_t n_tokens     = v_cur->ne[2];

    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);

    if (v_idxs && supports_set_rows) {
        if (!v_trans) {
+           if (v->ne[2] > 1) {
+               v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+           }
+
            return ggml_set_rows(ctx, v, v_cur, v_idxs);
        }

+       // [TAG_V_CACHE_VARIABLE]
+       if (n_embd_v_gqa < v->ne[0]) {
+           v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+       }

+       // the row becomes a single element
+       ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);

+       v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);

        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
    }
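Before calling ggml_set_rows() the multi-stream cache is reshaped from [n_embd, kv_size, n_stream] to [n_embd, kv_size*n_stream], so a single global row index can address any stream; set_input_k_idxs() below offsets each cell index by strm*kv_size accordingly. A sketch of the index convention:

```cpp
#include <cassert>
#include <cstdint>

// After flattening [n_embd, kv_size, n_stream] -> [n_embd, kv_size*n_stream],
// the destination row of cell `idx` in stream `strm` is offset by a whole
// stream worth of rows.
static int64_t global_row(int64_t strm, int64_t idx, int64_t kv_size) {
    return strm*kv_size + idx;
}

int main() {
    const int64_t kv_size = 4096;

    assert(global_row(0,   7, kv_size) ==    7); // stream 0: index unchanged
    assert(global_row(2, 100, kv_size) == 8292); // stream 2: 2*4096 + 100
    return 0;
}
```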
@@ +1150 @@
    // TODO: fallback to old ggml_cpy() method for backwards compatibility
    //       will be removed when ggml_set_rows() is adopted by all backends

+   GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
+
    ggml_tensor * v_view = nullptr;

    if (!v_trans) {

@@ +1182 @@
ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
    const uint32_t n_tokens = ubatch.n_tokens;

+   ggml_tensor * v_idxs;
+
+   if (!v_trans) {
+       v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+   } else {
+       v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
+   }

    ggml_set_input(v_idxs);

@@ +1201 @@
    }

    const uint32_t n_tokens = ubatch->n_tokens;
+   GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());

    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
    int64_t * data = (int64_t *) dst->data;

+   for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+       const int64_t offs = sinfo.strm[s]*get_size();
+
+       for (uint32_t i = 0; i < sinfo.size(); ++i) {
+           data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+       }
    }
}

@@ +1221 @@
    }

    const uint32_t n_tokens = ubatch->n_tokens;
+   GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());

    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
    int64_t * data = (int64_t *) dst->data;

+   if (!v_trans) {
+       for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+           const int64_t offs = sinfo.strm[s]*get_size();
+
+           for (uint32_t i = 0; i < sinfo.size(); ++i) {
+               data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+           }
+       }
+   } else {
+       // note: the V cache is transposed when not using flash attention
+       const int64_t kv_size = get_size();
+
+       const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
+
+       for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+           const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa;
+
+           for (uint32_t i = 0; i < sinfo.size(); ++i) {
+               for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                   data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
+               }
+           }
+       }
+   }
+}
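With the transposed V cache, one logical row of ggml_set_rows() is a single element, so every token contributes n_embd_v_gqa scattered destinations, one per embedding component j, at j*kv_size + idx. A check mirroring the index computation in the loop above:

```cpp
#include <cassert>
#include <cstdint>

// Component j of the token stored in cell `idx` lives at row j*kv_size + idx
// of the flattened [1, n_embd_v_gqa*kv_size] view, plus the stream offset.
int main() {
    const int64_t kv_size      = 4096;
    const int64_t n_embd_v_gqa = 512;
    const int64_t strm         = 1;

    const int64_t offs = strm*kv_size*n_embd_v_gqa; // stream offset, as in set_input_v_idxs()

    const int64_t idx = 100; // destination cell
    const int64_t j   = 3;   // embedding component

    const int64_t dst = offs + j*kv_size + idx;

    assert(dst == 1*4096*512 + 3*4096 + 100); // == 2109540
    return 0;
}
```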
@@ +1255 @@
+void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
+   GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+
+   int32_t * data = (int32_t *) dst->data;
+
+   for (uint32_t s = 0; s < n_stream; ++s) {
+       const auto & cells = v_cells[s];
+
+       for (uint32_t i = 0; i < cells.size(); ++i) {
+           data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
+       }
    }
}

@@ +1272 @@
    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
    float * data = (float *) dst->data;

+   const int64_t n_kv     = dst->ne[0];
+   const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
+
+   GGML_ASSERT(n_tokens%n_stream == 0);
+
+   // n_tps == n_tokens_per_stream
+   const int64_t n_tps     = n_tokens/n_stream;
+   const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
+
+   std::fill(data, data + ggml_nelements(dst), -INFINITY);

    // Use only the previous KV cells of the correct sequence for each token of the ubatch.
    // It's assumed that if a token in the batch has multiple sequences, they are equivalent.

@@ +1295 @@
    //   xxxxx-----
    //   xxxxx-----
    // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
+   // TODO: optimize this section
    for (uint32_t h = 0; h < 1; ++h) {
+       for (uint32_t s = 0; s < n_stream; ++s) {
+           for (uint32_t ii = 0; ii < n_tps; ++ii) {
+               const uint32_t i = s*n_tps + ii;

+               const llama_seq_id seq_id = ubatch->seq_id[i][0];

+               const auto & cells = v_cells[seq_to_stream[seq_id]];

+               const llama_pos p1 = ubatch->pos[i];

+               const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
+
+               for (uint32_t j = 0; j < n_kv; ++j) {
+                   if (cells.is_empty(j)) {
+                       continue;
+                   }

                    // mask the token if not the same sequence
+                   if (!cells.seq_has(j, seq_id)) {
+                       continue;
+                   }
+
+                   const llama_pos p0 = cells.pos_get(j);

                    // mask future tokens
+                   if (causal_attn && p0 > p1) {
+                       continue;
+                   }

                    // apply SWA if any
+                   if (is_masked_swa(p0, p1)) {
+                       continue;
                    }

+                   data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                }
            }
        }
    }
}
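Per query token the mask starts at -INFINITY everywhere and a KV slot is only opened (0, or -|p0 - p1| for ALiBi) when the cell belongs to the same sequence, is not in the future, and survives the SWA check. A toy causal-only version of one mask row:

```cpp
#include <cmath>
#include <vector>

// Toy version of the mask fill above: one stream, causal attention, no SWA,
// no ALiBi. seq[j]/pos[j] describe cell j of the cache; p1 is the query
// token's position and seq_id its sequence (seq[j] == -1 for empty cells).
static std::vector<float> kq_mask_row(const std::vector<int> & seq, const std::vector<int> & pos, int seq_id, int p1) {
    std::vector<float> row(seq.size(), -INFINITY);

    for (size_t j = 0; j < seq.size(); ++j) {
        if (seq[j] != seq_id) continue; // other sequence (or empty) -> masked
        if (pos[j] > p1)      continue; // future token              -> masked

        row[j] = 0.0f;                  // attend
    }
    return row;
}
```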
@@ +1341 @@
void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
    const int64_t n_tokens = ubatch->n_tokens;

+   GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
+   const auto & cells = v_cells[0];
+
    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+   GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing

    int32_t * data = (int32_t *) dst->data;

@@ +1450 @@

    void set_input(const llama_ubatch * ubatch) override;

+   ggml_tensor * k_shift; // I32 [kv_size*n_stream]

    const llama_kv_cache_unified * kv_self;
};

@@ +1463 @@
    }
}

+ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
+   auto * ctx = res->get_ctx();
+   auto * gf  = res->get_gf();

    const auto & n_embd_head_k = hparams.n_embd_head_k;
    //const auto & n_embd_head_v = hparams.n_embd_head_v;

    auto inp = std::make_unique<llm_graph_input_k_shift>(this);

+   inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
    ggml_set_input(inp->k_shift);

+   const auto & cparams = lctx->get_cparams();
+
    for (const auto & layer : layers) {
        const uint32_t il = layer.il;

@@ +1490 @@
        ggml_tensor * k =
            ggml_view_3d(ctx, layer.k,
+               n_embd_head_k, n_head_kv, get_size()*n_stream,
                ggml_row_size(layer.k->type, n_embd_head_k),
                ggml_row_size(layer.k->type, n_embd_k_gqa),
                0);

@@ +1502 @@

    res->add_input(std::move(inp));

+   return gf;
}

+ggml_cgraph * llama_kv_cache_unified::build_graph_defrag(
+       llm_graph_result * res,
+       llama_context * lctx,
+       const defrag_info & dinfo) const {
+   auto * ctx = res->get_ctx();
+   auto * gf  = res->get_gf();
+
+   GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
+
+   const auto & cells = v_cells[0];

    const auto & ids = dinfo.ids;

+   const auto & cparams = lctx->get_cparams();
+
#if 0
    // CPU defrag
    //

@@ +1656 @@
    //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
#endif

+   return gf;
}

llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
+   GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
+
+   const auto & cells = v_cells[0];
+
    const uint32_t n_layer = layers.size();

    const uint32_t n_kv = cells.used_max_p1();

@@ +1809 @@
}

void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+   io.write(&n_stream, sizeof(n_stream));

+   for (uint32_t s = 0; s < n_stream; ++s) {
+       cell_ranges_t cr { s, {} };

+       uint32_t cell_count = 0;
+
+       const auto & cells = v_cells[s];
+
+       // Count the number of cells with the specified seq_id
+       // Find all the ranges of cells with this seq id (or all, when -1)
+       uint32_t cell_range_begin = cells.size();
+
+       for (uint32_t i = 0; i < cells.size(); ++i) {
+           if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
+               ++cell_count;
+               if (cell_range_begin == cells.size()) {
+                   cell_range_begin = i;
+               }
+           } else {
+               if (cell_range_begin != cells.size()) {
+                   cr.data.emplace_back(cell_range_begin, i);
+                   cell_range_begin = cells.size();
+               }
            }
        }

+       if (cell_range_begin != cells.size()) {
+           cr.data.emplace_back(cell_range_begin, cells.size());
+       }

+       // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
+       uint32_t cell_count_check = 0;
+       for (const auto & range : cr.data) {
+           cell_count_check += range.second - range.first;
+       }
+       GGML_ASSERT(cell_count == cell_count_check);

+       io.write(&cell_count, sizeof(cell_count));

+       // skip empty streams
+       if (cell_count == 0) {
+           continue;
+       }
+
+       state_write_meta(io, cr, seq_id);
+       state_write_data(io, cr);
+   }
}
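The writer serializes each stream as a list of [begin, end) runs of cells that hold the requested sequence; the scan above closes a run whenever a non-matching cell is hit. A standalone version of the same range collection over a boolean predicate:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Standalone version of the range scan in state_write(): collect maximal
// [begin, end) runs of indices where match[i] is true.
static std::vector<std::pair<uint32_t, uint32_t>> collect_ranges(const std::vector<bool> & match) {
    std::vector<std::pair<uint32_t, uint32_t>> ranges;

    const uint32_t n = (uint32_t) match.size();
    uint32_t begin = n; // n means "no open run"

    for (uint32_t i = 0; i < n; ++i) {
        if (match[i]) {
            if (begin == n) {
                begin = i; // open a new run
            }
        } else if (begin != n) {
            ranges.emplace_back(begin, i); // close the run, end exclusive
            begin = n;
        }
    }

    if (begin != n) {
        ranges.emplace_back(begin, n); // run extends to the end
    }

    return ranges;
}
```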
@@ +1862 @@
void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+   GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));

+   uint32_t n_stream_cur;
+   io.read_to(&n_stream_cur, sizeof(n_stream_cur));
+   if (n_stream_cur != n_stream) {
+       throw std::runtime_error("n_stream mismatch");
+   }
+
+   for (uint32_t s = 0; s < n_stream; ++s) {
+       uint32_t cell_count;
+       io.read_to(&cell_count, sizeof(cell_count));
+
+       if (cell_count == 0) {
+           continue;
+       }

+       const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
+
+       bool res = true;
+       res = res && state_read_meta(io, strm, cell_count, seq_id);
+       res = res && state_read_data(io, strm, cell_count);
+
+       if (!res) {
+           if (seq_id == -1) {
+               clear(true);
+           } else {
+               seq_rm(seq_id, -1, -1);
+           }
+           throw std::runtime_error("failed to restore kv cache");
        }
    }
}

+void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
+   const auto & cells = v_cells[cr.strm];
+
+   for (const auto & range : cr.data) {
        for (uint32_t i = range.first; i < range.second; ++i) {
            std::vector<llama_seq_id> seq_ids;

@@ +1921 @@
        }
    }

+void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
+   const auto & cells = v_cells[cr.strm];
+
    const uint32_t v_trans = this->v_trans ? 1 : 0;
    const uint32_t n_layer = layers.size();

@@ +1939 @@

        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

+       auto * k = layer.k_stream[cr.strm];
+
        // Write key type
+       const int32_t k_type_i = (int32_t) k->type;
        io.write(&k_type_i, sizeof(k_type_i));

        // Write row size of key
+       const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
        io.write(&k_size_row, sizeof(k_size_row));

        // Read each range of cells of k_size length each into tmp_buf and write out
+       for (const auto & range : cr.data) {
            const size_t range_size = range.second - range.first;
            const size_t buf_size = range_size * k_size_row;
+           io.write_tensor(k, range.first * k_size_row, buf_size);
        }
    }

@@ +1963 @@

        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+       auto * v = layer.v_stream[cr.strm];
+
        // Write value type
+       const int32_t v_type_i = (int32_t) v->type;
        io.write(&v_type_i, sizeof(v_type_i));

        // Write row size of value
+       const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
        io.write(&v_size_row, sizeof(v_size_row));

        // Read each range of cells of v_size length each into tmp_buf and write out
+       for (const auto & range : cr.data) {
            const size_t range_size = range.second - range.first;
            const size_t buf_size = range_size * v_size_row;
+           io.write_tensor(v, range.first * v_size_row, buf_size);
        }
    }
} else {

@@ +1989 @@

        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+       auto * v = layer.v_stream[cr.strm];
+
        // Write value type
+       const int32_t v_type_i = (int32_t) v->type;
        io.write(&v_type_i, sizeof(v_type_i));

        // Write element size
+       const uint32_t v_size_el = ggml_type_size(v->type);
        io.write(&v_size_el, sizeof(v_size_el));

        // Write GQA embedding size

@@ +2005 @@
        // For each row, we get the element values of each cell
        for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
            // Read each range of cells of v_size_el length each into tmp_buf and write out
+           for (const auto & range : cr.data) {
                const size_t range_size = range.second - range.first;
                const size_t src_offset = (range.first + j * kv_size) * v_size_el;
                const size_t buf_size = range_size * v_size_el;
+               io.write_tensor(v, src_offset, buf_size);
            }
        }
    }
}

+bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
+   auto & cells = v_cells[strm];
+   auto & head  = v_heads[strm];
+
    if (dest_seq_id != -1) {
        // single sequence
        seq_rm(dest_seq_id, -1, -1);

        llama_batch_allocr balloc(hparams.n_pos_per_embd());

        llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);

+       ubatch.seq_id_unq[0] = dest_seq_id;
+
        for (uint32_t i = 0; i < cell_count; ++i) {
            llama_pos pos;
            uint32_t n_seq_id;

@@ +2066 @@
        // keep the head at the old position because we will read the KV data into it in state_read_data()
        head = head_cur;

+       LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
+
        // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
        // Assume that this is one contiguous block of cells
        GGML_ASSERT(head_cur + cell_count <= cells.size());

@@ +2113 @@
    return true;
}

+bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
+   auto & cells = v_cells[strm];
+   auto & head  = v_heads[strm];
+
    uint32_t v_trans;
    uint32_t n_layer;

@@ +2144 @@

        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

+       auto * k = layer.k_stream[strm];
+
        // Read type of key
        int32_t k_type_i_ref;
        io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
+       const int32_t k_type_i = (int32_t) k->type;
        if (k_type_i != k_type_i_ref) {
            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
            return false;

@@ +2158 @@
        // Read row size of key
        uint64_t k_size_row_ref;
        io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
+       const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
        if (k_size_row != k_size_row_ref) {
            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
            return false;

@@ +2166 @@

        if (cell_count) {
            // Read and set the keys for the whole cell range
+           ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
        }
    }

@@ +2176 @@

        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+       auto * v = layer.v_stream[strm];
+
        // Read type of value
        int32_t v_type_i_ref;
        io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+       const int32_t v_type_i = (int32_t) v->type;
        if (v_type_i != v_type_i_ref) {
            LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
            return false;

@@ +2190 @@
        // Read row size of value
        uint64_t v_size_row_ref;
        io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
+       const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
        if (v_size_row != v_size_row_ref) {
            LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
            return false;

@@ +2198 @@

        if (cell_count) {
            // Read and set the values for the whole cell range
+           ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
        }
    }
} else {

@@ +2208 @@

        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+       auto * v = layer.v_stream[strm];
+
        // Read type of value
        int32_t v_type_i_ref;
        io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+       const int32_t v_type_i = (int32_t) v->type;
        if (v_type_i != v_type_i_ref) {
            LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
            return false;

@@ +2222 @@
        // Read element size of value
        uint32_t v_size_el_ref;
        io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
+       const size_t v_size_el = ggml_type_size(v->type);
|
| 2226 |
if (v_size_el != v_size_el_ref) {
|
| 2227 |
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
|
| 2228 |
return false;
|
|
|
|
| 2240 |
// For each row in the transposed matrix, read the values for the whole cell range
|
| 2241 |
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
| 2242 |
const size_t dst_offset = (head + j * cells.size()) * v_size_el;
|
| 2243 |
+
ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
| 2244 |
}
|
| 2245 |
}
|
| 2246 |
}
|
|
|
|
| 2259 |
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
|
| 2260 |
n_kv = kv->get_size();
|
| 2261 |
|
| 2262 |
+
const uint32_t n_stream = kv->get_n_stream();
|
| 2263 |
+
|
| 2264 |
// create a dummy slot info - the actual data is irrelevant. we just need to build the graph
|
| 2265 |
sinfos.resize(1);
|
| 2266 |
+
sinfos[0].s0 = 0;
|
| 2267 |
+
sinfos[0].s1 = n_stream - 1;
|
| 2268 |
+
sinfos[0].idxs.resize(n_stream);
|
| 2269 |
+
for (uint32_t s = 0; s < n_stream; ++s) {
|
| 2270 |
+
sinfos[0].strm.push_back(s);
|
| 2271 |
+
sinfos[0].idxs[s].resize(1, 0);
|
| 2272 |
+
}
|
| 2273 |
}
|
| 2274 |
|
| 2275 |
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
| 2276 |
llama_kv_cache_unified * kv,
|
| 2277 |
llama_context * lctx,
|
| 2278 |
bool do_shift,
|
| 2279 |
+
defrag_info dinfo,
|
| 2280 |
+
stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) {
|
| 2281 |
+
if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) {
|
| 2282 |
status = LLAMA_MEMORY_STATUS_NO_UPDATE;
|
| 2283 |
}
|
| 2284 |
}
|
|
|
|
| 2306 |
|
| 2307 |
// no ubatches -> this is a KV cache update
|
| 2308 |
if (ubatches.empty()) {
|
| 2309 |
+
kv->update(lctx, do_shift, dinfo, sc_info);
|
| 2310 |
|
| 2311 |
return true;
|
| 2312 |
}
|
|
|
|
| 2332 |
return n_kv;
|
| 2333 |
}
|
| 2334 |
|
| 2335 |
+
bool llama_kv_cache_unified_context::get_supports_set_rows() const {
|
| 2336 |
+
return kv->get_supports_set_rows();
|
| 2337 |
+
}
|
| 2338 |
+
|
| 2339 |
ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
|
| 2340 |
+
return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
|
| 2341 |
}
|
| 2342 |
|
| 2343 |
ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
|
| 2344 |
+
return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
|
| 2345 |
}
|
| 2346 |
|
| 2347 |
ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
|
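Aside: the transposed value path above addresses element (j, c) of the cache at flat index j * kv_size + c, which is where the `src_offset` arithmetic comes from. A standalone sketch of that arithmetic (toy sizes, not part of the diff):

```cpp
// Minimal sketch (toy values, not from the diff): how the transposed-V loop in
// state_write_data() addresses the cache. Element (row j, cell c) of the
// transposed value matrix lives at flat index j * kv_size + c.
#include <cstddef>
#include <cstdio>

int main() {
    const size_t kv_size   = 8; // hypothetical cache size in cells
    const size_t v_size_el = 2; // hypothetical element size in bytes
    // a cell range [2, 5) for row j = 3, mirroring one iteration of the loop
    const size_t j = 3, first = 2, last = 5;
    const size_t src_offset = (first + j * kv_size) * v_size_el;
    const size_t buf_size   = (last - first) * v_size_el;
    printf("copy %zu bytes starting at byte offset %zu\n", buf_size, src_offset);
}
```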
examples/talk-llama/llama-kv-cache-unified.h
CHANGED
@@ -35,16 +35,50 @@ public:
         std::vector<uint32_t> ids;
     };

+    struct stream_copy_info {
+        bool empty() const {
+            assert(ssrc.size() == sdst.size());
+            return ssrc.empty();
+        }
+
+        std::vector<uint32_t> ssrc;
+        std::vector<uint32_t> sdst;
+    };
+
     // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
     // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
     struct slot_info {
         // data for ggml_set_rows
         using idx_vec_t = std::vector<uint32_t>;

-        idx_vec_t idxs;
+        // number of streams: ns = s1 - s0 + 1
+        llama_seq_id s0;
+        llama_seq_id s1;
+
+        std::vector<llama_seq_id> strm; // [ns]
+        std::vector<idx_vec_t>    idxs; // [ns]

         uint32_t head() const {
-            return idxs.at(0);
+            GGML_ASSERT(idxs.size() == 1);
+            GGML_ASSERT(!idxs[0].empty());
+
+            return idxs[0][0];
+        }
+
+        void resize(size_t n) {
+            strm.resize(n);
+            idxs.resize(n);
+        }
+
+        size_t size() const {
+            GGML_ASSERT(idxs.size() == strm.size());
+            GGML_ASSERT(!idxs.empty());
+
+            return idxs[0].size();
+        }
+
+        size_t n_stream() const {
+            return strm.size();
         }

         bool empty() const {
@@ -54,9 +88,6 @@ public:
         void clear() {
             idxs.clear();
         }
-
-        // TODO: implement
-        //std::vector<idx_vec_t> seq_idxs;
     };

     using slot_info_vec_t = std::vector<slot_info>;
@@ -68,6 +99,7 @@ public:
             ggml_type type_v,
             bool v_trans,
             bool offload,
+            bool unified,
             uint32_t kv_size,
             uint32_t n_seq_max,
             uint32_t n_pad,
@@ -111,7 +143,8 @@ public:
     // llama_kv_cache_unified specific API
     //

-    uint32_t get_size() const;
+    uint32_t get_size    () const;
+    uint32_t get_n_stream() const;

     bool get_has_shift() const;
@@ -121,9 +154,12 @@ public:

     uint32_t get_n_kv() const;

+    // TODO: temporary
+    bool get_supports_set_rows() const;
+
     // get views of the current state of the cache
-    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
-    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;

     // store k_cur and v_cur in the cache based on the provided head location
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
@@ -137,7 +173,7 @@ public:
     // return empty vector on failure
     slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);

-    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
+    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);

     // find a slot of kv cells that can hold the ubatch
     // if cont == true, then the slot must be continuous
@@ -157,8 +193,9 @@ public:
     void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
     void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;

+    void set_input_k_shift(ggml_tensor * dst) const;
+
     void set_input_kq_mask  (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
-    void set_input_k_shift  (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

 private:
@@ -172,15 +209,15 @@ private:

         ggml_tensor * k;
         ggml_tensor * v;
+
+        std::vector<ggml_tensor *> k_stream;
+        std::vector<ggml_tensor *> v_stream;
     };

     bool v_trans = true; // the value tensor is transposed

-    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
-    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
-    uint32_t head = 0;
-
     const uint32_t n_seq_max = 1;
+    const uint32_t n_stream  = 1;

     // required padding
     const uint32_t n_pad = 1;
@@ -193,14 +230,24 @@ private:

     // env: LLAMA_SET_ROWS (temporary)
     // ref: https://github.com/ggml-org/llama.cpp/pull/14285
+    bool supports_set_rows = false;

     const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;

+    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
+    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
+    std::vector<uint32_t> v_heads;
+
+    std::vector<llama_kv_cells_unified> v_cells;
+
+    // maps from a sequence id to a stream id
+    std::vector<uint32_t> seq_to_stream;
+
+    // pending stream copies that will be applied during the next update
+    stream_copy_info sc_info;

     std::vector<kv_layer> layers;

@@ -226,29 +273,34 @@ private:
             float freq_base,
             float freq_scale) const;

-            ggml_cgraph * gf) const;
+    ggml_cgraph * build_graph_shift(
+            llm_graph_result * res,
+            llama_context * lctx) const;

-            ggml_cgraph * gf,
+    ggml_cgraph * build_graph_defrag(
+            llm_graph_result * res,
+            llama_context * lctx,
             const defrag_info & dinfo) const;

+    struct cell_ranges_t {
+        uint32_t strm;
+
+        std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
+    };
+
+    void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
+    void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
+
+    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
 };

 class llama_kv_cache_unified_context : public llama_memory_context_i {
 public:
     // some shorthands
-    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
-    using defrag_info     = llama_kv_cache_unified::defrag_info;
+    using slot_info_vec_t  = llama_kv_cache_unified::slot_info_vec_t;
+    using defrag_info      = llama_kv_cache_unified::defrag_info;
+    using stream_copy_info = llama_kv_cache_unified::stream_copy_info;

     // used for errors
     llama_kv_cache_unified_context(llama_memory_status status);
@@ -262,7 +314,8 @@ public:
             llama_kv_cache_unified * kv,
             llama_context * lctx,
             bool do_shift,
-            defrag_info dinfo);
+            defrag_info dinfo,
+            stream_copy_info sc_info);

     // used to create a batch procesing context from a batch
     llama_kv_cache_unified_context(
@@ -288,6 +341,9 @@ public:

     uint32_t get_n_kv() const;

+    // TODO: temporary
+    bool get_supports_set_rows() const;
+
     // get views of the current state of the cache
     ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
@@ -320,6 +376,8 @@ private:

     defrag_info dinfo;

+    stream_copy_info sc_info;
+
     //
     // batch processing context
     //
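For orientation, a rough standalone sketch of the multi-stream slot_info bookkeeping introduced above (hypothetical cell indices; GGML_ASSERT swapped for &lt;cassert&gt; so it compiles on its own):

```cpp
// Hypothetical sketch of the slot_info layout from the header above: each
// stream carries its own vector of cell indices, and all streams place the
// same number of tokens.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

using idx_vec_t = std::vector<uint32_t>;

int main() {
    // two streams (s0 = 0, s1 = 1), each placing 3 tokens into its own cells
    std::vector<int32_t>   strm = {0, 1};
    std::vector<idx_vec_t> idxs = {{4, 5, 6}, {10, 11, 12}};

    assert(idxs.size() == strm.size());
    printf("n_stream = %zu, tokens per stream = %zu\n", strm.size(), idxs[0].size());
    // head() is only meaningful in the single-stream case (idxs.size() == 1),
    // where it returns the first cell index, idxs[0][0]
}
```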
examples/talk-llama/llama-memory-hybrid.cpp
CHANGED
@@ -38,6 +38,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         type_v,
         v_trans,
         offload,
+        1,
         kv_size,
         n_seq_max,
         n_pad,

The added literal 1 fills the new `unified` constructor parameter that llama-kv-cache-unified.h now takes before kv_size, so the hybrid cache continues to use a single unified KV buffer.
examples/talk-llama/llama-memory-recurrent.cpp
CHANGED
@@ -446,7 +446,7 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
     // A slot should be always be contiguous.

     // can only process batches with an equal number of new tokens in each sequence
-    GGML_ASSERT(ubatch.equal_seqs);
+    GGML_ASSERT(ubatch.equal_seqs());

     int32_t min = size - 1;
     int32_t max = 0;
@@ -768,6 +768,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
+        // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+        if (r_l[il] == nullptr) continue;

         // Write key type
         const int32_t r_type_i = (int32_t)r_l[il]->type;
@@ -787,6 +789,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::

     if (!s_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+            if (s_l[il] == nullptr) continue;

             // Write value type
             const int32_t s_type_i = (int32_t)s_l[il]->type;
@@ -807,6 +811,9 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t mem_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+            if (s_l[il] == nullptr) continue;
+
             const uint32_t n_embd_s = hparams.n_embd_s();

             // Write value type
@@ -951,6 +958,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell

     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
+        // skip null layers
+        if (r_l[il] == nullptr) continue;

         // Read type of key
         int32_t r_type_i_ref;
@@ -978,11 +987,14 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell

     if (!s_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers
+            if (s_l[il] == nullptr) continue;

             // Read type of value
             int32_t s_type_i_ref;
             io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
             const int32_t s_type_i = (int32_t)s_l[il]->type;
+
             if (s_type_i != s_type_i_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
                 return false;
@@ -1005,6 +1017,9 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
     } else {
         // For each layer, read the values for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers
+            if (s_l[il] == nullptr) continue;
+
             const uint32_t n_embd_s = hparams.n_embd_s();

             // Read type of value
examples/talk-llama/llama-model.cpp
CHANGED

The diff for this file is too large to render. See raw diff.
examples/talk-llama/llama-model.h
CHANGED
@@ -99,8 +99,10 @@ enum llm_type {
     LLM_TYPE_17B_16E, // llama4 Scout
     LLM_TYPE_17B_128E, // llama4 Maverick
     LLM_TYPE_A13B,
+    LLM_TYPE_21B_A3B, // Ernie MoE small
     LLM_TYPE_30B_A3B,
     LLM_TYPE_235B_A22B,
+    LLM_TYPE_300B_A47B, // Ernie MoE big
     LLM_TYPE_E2B,
     LLM_TYPE_E4B,
 };
@@ -452,10 +454,7 @@ struct llama_model {
     llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;

     // TODO: move this to new llm_arch_model_i interface
-    llm_graph_result_ptr build_graph(
-            const llm_graph_params & params,
-            ggml_cgraph * gf,
-            llm_graph_type type) const;
+    ggml_cgraph * build_graph(const llm_graph_params & params) const;

 private:
     struct impl;
examples/talk-llama/llama-quant.cpp
CHANGED
@@ -884,8 +884,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
                 if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
                     if (qtype != new_type) {
                         LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
-                        new_type = qtype;
-                        break; // if two or more types are specified for the tensor, first match wins
+                        new_type = qtype; // if two or more types are specified for the same tensor, the last match wins
                     }
                 }
             }
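This flips the override semantics: with the `break` removed, a later matching pattern now replaces an earlier one. A minimal standalone sketch of the new behavior (the override patterns and tensor name here are hypothetical, and the ggml type is stood in by a plain string):

```cpp
// Sketch of the "last match wins" loop above, with hypothetical overrides.
#include <cstdio>
#include <regex>
#include <string>
#include <utility>
#include <vector>

int main() {
    const std::vector<std::pair<std::string, std::string>> overrides = {
        {"ffn_.*",   "q4_k"}, // matches first
        {"ffn_down", "q6_k"}, // also matches; wins now that the loop no longer breaks
    };
    const std::string tensor_name = "blk.0.ffn_down.weight";

    std::string new_type = "q4_0"; // default quantization type
    for (const auto & [tname, qtype] : overrides) {
        if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
            new_type = qtype; // last match wins
        }
    }
    printf("%s -> %s\n", tensor_name.c_str(), new_type.c_str()); // q6_k
}
```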
examples/talk-llama/llama-vocab.cpp
CHANGED
@@ -11,6 +11,7 @@
 #include <cassert>
 #include <cctype>
 #include <cfloat>
+#include <cmath>
 #include <cstdarg>
 #include <cstring>
 #include <forward_list>
@@ -404,6 +405,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                 "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
             };
             break;
+        case LLAMA_VOCAB_PRE_TYPE_KIMI_K2:
+            regex_exprs = {
+                // K2 trigger pattern - this will activate the custom K2 handler in unicode.cpp
+                // The custom handler implements all K2 patterns with proper Han character exclusion
+                "\\p{Han}+",
+            };
+            break;
         case LLAMA_VOCAB_PRE_TYPE_SUPERBPE:
             regex_exprs = {
                 "\\p{N}+",
@@ -1196,6 +1204,284 @@ private:
     const llm_tokenizer_rwkv & tokenizer;
 };

+struct llm_tokenizer_plamo2 : llm_tokenizer {
+    llm_tokenizer_plamo2(const llama_vocab & vocab) {
+        build(vocab);
+    }
+
+    void build(const llama_vocab & vocab) {
+        // Reset internal structures
+        tokens_.clear();
+        bytes_.assign(256, 0);
+        to_suffix_id_.clear();
+        table_.clear();
+
+        // Build token list and byte mapping
+        std::unordered_map<std::string, float> suffix_to_score;
+        std::unordered_map<std::string, llama_token> token_to_id;
+
+        for (size_t token_id = 0; token_id < vocab.n_tokens(); ++token_id) {
+            const auto & entry = vocab.get_token_data(token_id);
+            tokens_.push_back(entry.text);
+            token_to_id[entry.text] = static_cast<llama_token>(token_id);
+
+            // Handle byte tokens
+            if (vocab.is_byte(token_id)) {
+                if (entry.text.length() == 6 && entry.text.substr(0, 3) == "<0x" && entry.text.back() == '>') {
+                    std::string hex_str = entry.text.substr(3, 2);
+                    int byte_val = std::stoi(hex_str, nullptr, 16);
+                    bytes_[byte_val] = static_cast<llama_token>(token_id);
+                }
+                continue;
+            }
+
+            // Add token and all its suffixes to suffix_to_score
+            suffix_to_score[entry.text] = entry.score;
+
+            // Extract suffixes character by character (UTF-8 aware)
+            std::vector<uint32_t> cpts = unicode_cpts_from_utf8(entry.text);
+            for (size_t i = 1; i < cpts.size(); ++i) {
+                std::string suffix;
+                for (size_t j = i; j < cpts.size(); ++j) {
+                    suffix += unicode_cpt_to_utf8(cpts[j]);
+                }
+                if (suffix_to_score.find(suffix) == suffix_to_score.end()) {
+                    suffix_to_score[suffix] = std::numeric_limits<float>::quiet_NaN();
+                }
+            }
+        }
+
+        // Check that all byte tokens are set
+        for (int i = 0; i < 256; ++i) {
+            if (bytes_[i] == 0) {
+                throw std::runtime_error("Byte token for <0x" + std::to_string(i) + "> is not set");
+            }
+        }
+
+        // Build suffix list in lexicographical order of reversed strings
+        std::vector<std::string> suffixes;
+        for (const auto & pair : suffix_to_score) {
+            suffixes.push_back(pair.first);
+        }
+        suffixes.push_back("");  // Empty suffix
+
+        std::sort(suffixes.begin(), suffixes.end(), [](const std::string & a, const std::string & b) {
+            std::string rev_a(a.rbegin(), a.rend());
+            std::string rev_b(b.rbegin(), b.rend());
+            return rev_a < rev_b;
+        });
+
+        // Build suffix_to_id and to_suffix_id_
+        std::unordered_map<std::string, int32_t> suffix_to_id;
+        int32_t num_pieces = 0;
+
+        for (const auto & suffix : suffixes) {
+            suffix_to_id[suffix] = num_pieces;
+            if (!suffix.empty()) {
+                std::vector<uint32_t> cpts = unicode_cpts_from_utf8(suffix);
+
+                std::string remaining;
+                for (size_t i = 1; i < cpts.size(); ++i) {
+                    remaining += unicode_cpt_to_utf8(cpts[i]);
+                }
+
+                int64_t piece_code = (static_cast<int64_t>(cpts[0]) << 32) | suffix_to_id[remaining];
+                to_suffix_id_[piece_code] = num_pieces;
+
+                // Count number of pieces for this suffix
+                int32_t pieces_for_suffix = 1;  // sentinel row
+                for (int32_t piece_length = static_cast<int32_t>(cpts.size()); piece_length > 0; --piece_length) {
+                    std::string piece;
+                    for (int32_t i = 0; i < piece_length; ++i) {
+                        piece += unicode_cpt_to_utf8(cpts[i]);
+                    }
+                    if (suffix_to_score.find(piece) != suffix_to_score.end()) {
+                        pieces_for_suffix++;
+                    }
+                }
+                num_pieces += pieces_for_suffix;
+            } else {
+                num_pieces++;  // Empty suffix contributes one piece (sentinel row)
+            }
+        }
+
+        // Build flattened table
+        table_.resize(num_pieces, std::vector<int32_t>(4, 0));
+        int32_t table_idx = 0;
+
+        for (const auto & suffix : suffixes) {
+            // Add all prefixes of the suffix to the table (in decreasing order of length)
+            std::vector<uint32_t> cpts = unicode_cpts_from_utf8(suffix);
+            for (int32_t piece_length = static_cast<int32_t>(cpts.size()); piece_length > 0; --piece_length) {
+                std::string piece;
+                for (int32_t i = 0; i < piece_length; ++i) {
+                    piece += unicode_cpt_to_utf8(cpts[i]);
+                }
+
+                auto score_it = suffix_to_score.find(piece);
+                if (score_it == suffix_to_score.end()) {
+                    continue;
+                }
+
+                table_[table_idx][TABLE_PIECE_LENGTH] = piece_length;
+                auto token_it = token_to_id.find(piece);
+                table_[table_idx][TABLE_TOKEN_ID] = (token_it != token_to_id.end()) ? token_it->second : -1;
+
+                float score = score_it->second;
+                table_[table_idx][TABLE_SCORE] = std::isfinite(score) ?
+                    static_cast<int32_t>(std::round(score * 1e4)) : INVALID_SCORE;
+                table_[table_idx][TABLE_PIECE_ID] = suffix_to_id[piece];
+
+                table_idx++;
+            }
+
+            // Add sentinel row
+            table_[table_idx][TABLE_PIECE_LENGTH] = 1;
+            table_[table_idx][TABLE_TOKEN_ID]     = -1;
+            table_[table_idx][TABLE_SCORE]        = UNKNOWN_SCORE;
+            table_idx++;
+        }
+    }
+
+    std::vector<llama_token> encode(const std::string & text) const {
+        std::vector<uint32_t> unicode_data = unicode_cpts_from_utf8(text);
+        // Skip the first code point if it is a BOM (Byte Order Mark)
+        if (!unicode_data.empty() && unicode_data[0] == 0xFEFF) {
+            unicode_data.erase(unicode_data.begin());
+        }
+
+        if (unicode_data.empty()) {
+            return {};
+        }
+
+        const size_t data_len = unicode_data.size();
+
+        // Initialize scores array (dynamic programming)
+        std::vector<int64_t> scores(data_len + 1, static_cast<int64_t>(1) << 60);
+        scores[data_len] = 0;
+
+        // Path array to track best tokenization
+        std::vector<std::vector<int32_t>> path(data_len + 1, std::vector<int32_t>(3, 0));
+
+        int32_t suffix_id = 0;
+
+        // Process from end to beginning
+        for (int i = static_cast<int>(data_len) - 1; i >= 0; --i) {
+            uint32_t c = unicode_data[i];
+
+            // Find next suffix ID
+            for (size_t p = suffix_id; p < table_.size(); ++p) {
+                int64_t piece_code = (static_cast<int64_t>(c) << 32) | table_[p][TABLE_PIECE_ID];
+                auto it = to_suffix_id_.find(piece_code);
+                suffix_id = (it != to_suffix_id_.end()) ? it->second : 0;
+
+                if (suffix_id > 0 || table_[p][TABLE_SCORE] == UNKNOWN_SCORE) {
+                    break;
+                }
+            }
+
+            // Update best path
+            for (size_t p = suffix_id; p < table_.size(); ++p) {
+                int32_t score = table_[p][TABLE_SCORE];
+                if (score > INVALID_SCORE) {
+                    int32_t piece_length = table_[p][TABLE_PIECE_LENGTH];
+                    int64_t s = scores[i + piece_length] - score;
+
+                    if (s < scores[i]) {
+                        scores[i] = s;
+                        path[i][PATH_TOKEN_LENGTH] = piece_length;
+                        path[i][PATH_TOKEN_ID]     = table_[p][TABLE_TOKEN_ID];
+                        path[i][PATH_NUM_TOKENS]   = path[i + piece_length][PATH_NUM_TOKENS] + 1;
+
+                        if (score == UNKNOWN_SCORE) {
+                            // Add UTF-8 byte count
+                            path[i][PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000);
+                        }
+                    }
+                }
+
+                if (score == UNKNOWN_SCORE) {
+                    break;
+                }
+            }
+        }
+
+        // Decode the best path
+        std::vector<llama_token> token_ids;
+        token_ids.reserve(path[0][PATH_NUM_TOKENS]);
+
+        int pos = 0;
+        while (pos < static_cast<int>(data_len)) {
+            if (path[pos][PATH_TOKEN_ID] >= 0) {
+                token_ids.push_back(path[pos][PATH_TOKEN_ID]);
+            } else {
+                // Fall back to byte tokens
+                uint32_t c = unicode_data[pos];
+                int s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000);
+
+                for (int i = 0; i < s; ++i) {
+                    uint8_t b;
+                    if (s == 1) {
+                        b = c;
+                    } else {
+                        if (i == 0) {
+                            b = (0xF00 >> s) & 0xFF;
+                        } else {
+                            b = 0x80;
+                        }
+                    }
+                    token_ids.push_back(bytes_[b | ((c >> ((s - i - 1) * 6)) & 0x3F)]);
+                }
+            }
+
+            assert(path[pos][PATH_TOKEN_LENGTH] > 0);
+            pos += path[pos][PATH_TOKEN_LENGTH];
+        }
+
+        return token_ids;
+    }
+
+private:
+    // Constants for table structure
+    static constexpr int32_t TABLE_PIECE_LENGTH = 0;
+    static constexpr int32_t TABLE_TOKEN_ID     = 1;
+    static constexpr int32_t TABLE_SCORE        = 2;
+    static constexpr int32_t TABLE_PIECE_ID     = 3;
+
+    // Constants for path array
+    static constexpr int32_t PATH_TOKEN_LENGTH = 0;
+    static constexpr int32_t PATH_TOKEN_ID     = 1;
+    static constexpr int32_t PATH_NUM_TOKENS   = 2;
+
+    // Score constants
+    static constexpr int32_t INVALID_SCORE = -20000000;
+    static constexpr int32_t UNKNOWN_SCORE = -10000000;
+
+    // List of tokens in the vocabulary
+    std::vector<std::string> tokens_;
+
+    // Mapping from byte code point to token ID (for byte fallback)
+    std::vector<llama_token> bytes_;
+
+    // Mapping from piece code to suffix ID
+    std::unordered_map<int64_t, int32_t> to_suffix_id_;
+
+    // Flattened table representing the Trie structure
+    // Each row contains: [piece_length, token_id, score, piece_id]
+    std::vector<std::vector<int32_t>> table_;
+};
+
+struct llm_tokenizer_plamo2_session {
+    llm_tokenizer_plamo2_session(const llm_tokenizer_plamo2 & tokenizer) : tokenizer(tokenizer) {}
+
+    void tokenize(const std::string & text, std::vector<llama_token> & output) {
+        std::vector<llama_token> tokens = tokenizer.encode(text);
+        output.insert(output.end(), tokens.begin(), tokens.end());
+    }
+
+private:
+    const llm_tokenizer_plamo2 & tokenizer;
+};
+
 //
 // impl
 //
@@ -1499,6 +1785,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         special_unk_id = LLAMA_TOKEN_NULL;
         special_sep_id = LLAMA_TOKEN_NULL;
         special_pad_id = LLAMA_TOKEN_NULL;
+    } else if (tokenizer_model == "plamo2") {
+        type = LLAMA_VOCAB_TYPE_PLAMO2;
+
+        // PLaMo-2 default special tokens (these will be overridden by model config)
+        special_bos_id  = 1; // <|plamo:bos|>
+        special_eos_id  = 2; // <|plamo:eos|>
+        special_unk_id  = 0; // <|plamo:unk|>
+        special_sep_id  = LLAMA_TOKEN_NULL;
+        special_pad_id  = 3; // <|plamo:pad|>
+        special_mask_id = LLAMA_TOKEN_NULL;
     } else {
         throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
     }
@@ -1629,6 +1925,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             } else if (
                 tokenizer_pre == "exaone") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE;
+            } else if (
+                tokenizer_pre == "exaone4") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
             } else if (
                 tokenizer_pre == "chameleon") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON;
@@ -1665,6 +1964,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "hunyuan") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
                 clean_spaces = false;
+            } else if (
+                tokenizer_pre == "kimi-k2") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
+                clean_spaces = false;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
@@ -2145,13 +2448,14 @@ enum llama_vocab_type llama_vocab::impl::get_type() const {

 std::string llama_vocab::impl::type_name() const{
     switch (type) {
-        case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
-        case LLAMA_VOCAB_TYPE_SPM:  return "SPM";
-        case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
-        case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
-        case LLAMA_VOCAB_TYPE_UGM:  return "UGM";
-        case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
-        default:                    return "unknown";
+        case LLAMA_VOCAB_TYPE_NONE:   return "no vocab";
+        case LLAMA_VOCAB_TYPE_SPM:    return "SPM";
+        case LLAMA_VOCAB_TYPE_BPE:    return "BPE";
+        case LLAMA_VOCAB_TYPE_WPM:    return "WPM";
+        case LLAMA_VOCAB_TYPE_UGM:    return "UGM";
+        case LLAMA_VOCAB_TYPE_RWKV:   return "RWKV";
+        case LLAMA_VOCAB_TYPE_PLAMO2: return "PLaMo2";
+        default:                      return "unknown";
     }
 }
@@ -2234,6 +2538,9 @@ void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) {
         case LLAMA_VOCAB_TYPE_RWKV:
            tokenizer = std::make_unique<llm_tokenizer_rwkv>(vocab);
            break;
+        case LLAMA_VOCAB_TYPE_PLAMO2:
+            tokenizer = std::make_unique<llm_tokenizer_plamo2>(vocab);
+            break;
         default:
             GGML_ABORT("unsupported vocab type");
     }
@@ -2566,6 +2873,23 @@ std::vector<llama_token> llama_vocab::impl::tokenize(
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);

+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_PLAMO2:
+            {
+                llm_tokenizer_plamo2_session session(*static_cast<const llm_tokenizer_plamo2 *>(tokenizer.get()));
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
 #endif
@@ -2664,6 +2988,24 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t
                 memcpy(buf, result.data(), result.size());
                 return (int)result.size();
             }
+        case LLAMA_VOCAB_TYPE_PLAMO2: {
+                // PLaMo-2 uses similar token handling as BPE/SPM
+                if (vocab.is_byte(token)) {
+                    // Handle byte tokens like <0xXX>
+                    if (token_text.length() == 6 && token_text.substr(0, 3) == "<0x" && token_text.back() == '>') {
+                        int hex_val = std::stoi(token_text.substr(3, 2), nullptr, 16);
+                        if (length < 1) {
+                            return -1;
+                        }
+                        buf[0] = static_cast<char>(hex_val);
+                        return 1;
+                    }
+                }
+
+                // Normal token - just copy the text
+                std::string result = token_text;
+                return _try_copy(result.data(), result.size());
+            }
         default:
             GGML_ABORT("fatal error");
     }
@@ -2908,6 +3250,12 @@ llama_token llama_vocab::byte_to_token(uint8_t ch) const {
         case LLAMA_VOCAB_TYPE_BPE: {
             return pimpl->token_to_id.at(unicode_byte_to_utf8(ch));
         }
+        case LLAMA_VOCAB_TYPE_PLAMO2: {
+            // PLaMo-2 uses byte tokens in format <0xXX>
+            char hex_str[8];
+            snprintf(hex_str, sizeof(hex_str), "<0x%02X>", ch);
+            return pimpl->token_to_id.at(hex_str);
+        }
         default:
             GGML_ABORT("fatal error");
     }
@@ -3009,6 +3357,10 @@ llama_token llama_vocab::token_fim_sep() const {
     return pimpl->special_fim_sep_id;
 }

+llama_token llama_vocab::token_mask() const {
+    return pimpl->special_mask_id;
+}
+
 bool llama_vocab::get_add_space_prefix() const {
     return pimpl->add_space_prefix;
 }
@@ -3249,6 +3601,10 @@ llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) {
     return vocab->token_fim_sep();
 }

+llama_token llama_vocab_mask(const struct llama_vocab* vocab) {
+    return vocab->token_mask();
+}
+
 // deprecated
 const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) {
     return llama_vocab_get_text(vocab, token);
@@ -3385,4 +3741,3 @@ int32_t llama_detokenize(
         bool unparse_special) {
     return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
 }
-
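For orientation, the PLaMo-2 encoder above is a Viterbi-style dynamic program: scanning right to left, it sets scores[i] to the minimum over candidate pieces starting at i of scores[i + len] - score(piece). A toy standalone version of that recurrence (hypothetical vocabulary and scores; the real code walks a flattened suffix-trie table instead of a hash map, and falls back to byte tokens):

```cpp
// Toy version of the scoring recurrence in llm_tokenizer_plamo2::encode().
// Higher piece scores become lower costs, so the DP prefers longer,
// better-scored pieces; here it segments "abab" as "ab" + "ab".
#include <cstdint>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
    const std::string text = "abab";
    const std::unordered_map<std::string, int32_t> piece_score = {
        {"a", 10}, {"b", 10}, {"ab", 25}, {"ba", 5}, // hypothetical scores
    };

    const size_t n = text.size();
    std::vector<int64_t> scores(n + 1, int64_t(1) << 60);
    std::vector<size_t>  best_len(n + 1, 0);
    scores[n] = 0;

    for (size_t i = n; i-- > 0; ) {
        for (const auto & [piece, sc] : piece_score) {
            if (i + piece.size() <= n && text.compare(i, piece.size(), piece) == 0) {
                const int64_t s = scores[i + piece.size()] - sc; // higher score -> lower cost
                if (s < scores[i]) {
                    scores[i]   = s;
                    best_len[i] = piece.size();
                }
            }
        }
    }

    // decode the best segmentation
    for (size_t pos = 0; pos < n; pos += best_len[pos]) {
        printf("%s\n", text.substr(pos, best_len[pos]).c_str());
    }
}
```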
examples/talk-llama/llama-vocab.h
CHANGED
@@ -45,6 +45,7 @@ enum llama_vocab_pre_type {
     LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
     LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
     LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
+    LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37,
 };

 struct LLM_KV;
@@ -100,6 +101,7 @@ struct llama_vocab {
     llama_token token_sep() const;
     llama_token token_nl () const;
     llama_token token_pad() const;
+    llama_token token_mask() const;

     llama_token token_prefix() const;
     llama_token token_middle() const;
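A hedged usage sketch of the new mask-token accessor (the model path is a placeholder and error handling is minimal; `llama_vocab_mask` is the public wrapper declared in llama.h below):

```cpp
// Querying the mask token through the new API. Returns LLAMA_TOKEN_NULL (-1)
// when the loaded vocab defines no mask token.
#include "llama.h"
#include <cstdio>

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path
    if (!model) return 1;

    const llama_vocab * vocab = llama_model_get_vocab(model);
    const llama_token mask = llama_vocab_mask(vocab);
    if (mask == LLAMA_TOKEN_NULL) {
        printf("this tokenizer does not define a mask token\n");
    }

    llama_model_free(model);
}
```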
examples/talk-llama/llama.h
CHANGED
|
@@ -71,12 +71,13 @@ extern "C" {
|
|
| 71 |
typedef int32_t llama_seq_id;
|
| 72 |
|
| 73 |
enum llama_vocab_type {
|
| 74 |
-
LLAMA_VOCAB_TYPE_NONE
|
| 75 |
-
LLAMA_VOCAB_TYPE_SPM
|
| 76 |
-
LLAMA_VOCAB_TYPE_BPE
|
| 77 |
-
LLAMA_VOCAB_TYPE_WPM
|
| 78 |
-
LLAMA_VOCAB_TYPE_UGM
|
| 79 |
-
LLAMA_VOCAB_TYPE_RWKV
|
|
|
|
| 80 |
};
|
| 81 |
|
| 82 |
enum llama_rope_type {
|
|
@@ -334,6 +335,9 @@ extern "C" {
|
|
| 334 |
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
| 335 |
// NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
|
| 336 |
// ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
|
|
|
|
|
|
|
|
|
|
| 337 |
};
|
| 338 |
|
| 339 |
// model quantization parameters
|
|
@@ -724,7 +728,7 @@ extern "C" {
|
|
| 724 |
// - lazily on next llama_decode()
|
| 725 |
// p0 < 0 : [0, p1]
|
| 726 |
// p1 < 0 : [p0, inf)
|
| 727 |
-
DEPRECATED(void llama_kv_self_seq_div(
|
| 728 |
struct llama_context * ctx,
|
| 729 |
llama_seq_id seq_id,
|
| 730 |
llama_pos p0,
|
|
@@ -952,6 +956,7 @@ extern "C" {
|
|
| 952 |
// in the order they have appeared in the batch.
|
| 953 |
// Rows: number of tokens for which llama_batch.logits[i] != 0
|
| 954 |
// Cols: n_vocab
|
|
|
|
| 955 |
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
| 956 |
|
| 957 |
// Logits for the ith token. For positive indices, Equivalent to:
|
|
@@ -966,6 +971,7 @@ extern "C" {
|
|
| 966 |
// in the order they have appeared in the batch.
|
| 967 |
// shape: [n_outputs*n_embd]
|
| 968 |
// Otherwise, returns NULL.
|
|
|
|
| 969 |
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
| 970 |
|
| 971 |
// Get the embeddings for the ith token. For positive indices, Equivalent to:
|
|
@@ -1004,6 +1010,7 @@ extern "C" {
|
|
| 1004 |
LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
|
| 1005 |
LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
|
| 1006 |
LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
|
|
|
|
| 1007 |
|
| 1008 |
LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
|
| 1009 |
LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
|
|
@@ -1389,6 +1396,7 @@ extern "C" {
|
|
| 1389 |
|
| 1390 |
int32_t n_p_eval;
|
| 1391 |
int32_t n_eval;
|
|
|
|
| 1392 |
};
|
| 1393 |
|
| 1394 |
struct llama_perf_sampler_data {
|
|
|
|
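And a short sketch reading the new n_reused counter through the existing llama_perf_context() call; the print format is mine:

// Hedged sketch, not part of the diff: report the new graph-reuse counter.
#include "llama.h"
#include <cstdio>

void print_graph_reuse(llama_context * ctx) {
    const llama_perf_context_data data = llama_perf_context(ctx);
    printf("graphs reused: %d (prompt evals: %d, evals: %d)\n",
           data.n_reused, data.n_p_eval, data.n_eval);
}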
examples/talk-llama/unicode.cpp
CHANGED
@@ -557,6 +557,178 @@ static std::vector<size_t> unicode_regex_split_stl(const std::string & text, con
     return bpe_offsets;
 }
 
+// K2 system regex patterns (from tokenization_kimi.py):
+// [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+
+static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector<size_t> & offsets) {
+    std::vector<size_t> bpe_offsets;
+    bpe_offsets.reserve(offsets.size());
+
+    const auto cpts = unicode_cpts_from_utf8(text);
+
+    size_t start = 0;
+    for (auto offset : offsets) {
+        const size_t offset_ini = start;
+        const size_t offset_end = start + offset;
+        assert(offset_end <= cpts.size());
+        start = offset_end;
+
+        static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
+        auto _get_cpt = [&] (const size_t pos) -> uint32_t {
+            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
+        };
+
+        auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
+        };
+
+        size_t _prev_end = offset_ini;
+        auto _add_token = [&] (const size_t end) -> size_t {
+            assert(_prev_end <= end && end <= offset_end);
+            size_t len = end - _prev_end;
+            if (len > 0) {
+                bpe_offsets.push_back(len);
+            }
+            _prev_end = end;
+            return len;
+        };
+
+        for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
+            const uint32_t cpt = _get_cpt(pos);
+            const auto flags = _get_flags(pos);
+
+            // Pattern 1: [\p{Han}]+ (Chinese characters)
+            if (unicode_cpt_is_han(cpt)) {
+                while (unicode_cpt_is_han(_get_cpt(pos))) {
+                    pos++;
+                }
+                _add_token(pos);
+                continue;
+            }
+
+            // Pattern 2 & 3: Letter words excluding Han characters with optional contractions
+            // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?:'s|'t|'re|'ve|'m|'ll|'d)?
+            // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?:'s|'t|'re|'ve|'m|'ll|'d)?
+            // Check if current char is a letter OR if current char could be a leading char and next char is a letter
+            bool is_letter_pattern = (flags.is_letter && !unicode_cpt_is_han(cpt)) ||
+                                     (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number) &&
+                                      _get_flags(pos + 1).is_letter && !unicode_cpt_is_han(_get_cpt(pos + 1)));
+
+            if (is_letter_pattern) {
+                // Handle optional leading non-letter/non-number character
+                bool has_leading_char = false;
+                if (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number)) {
+                    has_leading_char = true;
+                    pos++;
+                }
+
+                // Match letter sequence (excluding Han characters)
+                bool has_letters = false;
+                while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
+                    has_letters = true;
+                    pos++;
+                }
+
+                // Only proceed if we found letters (after potentially skipping leading char)
+                if (has_letters || (!has_leading_char && _get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos)))) {
+                    if (!has_letters) pos++; // consume the first letter if we didn't already
+
+                    // Continue consuming letters
+                    while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
+                        pos++;
+                    }
+
+                    // Check for optional contractions (?:'s|'t|'re|'ve|'m|'ll|'d)
+                    if (_get_cpt(pos) == '\'' && pos + 1 < offset_end) {
+                        uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1));
+                        if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
+                            pos += 2;
+                        } else if (pos + 2 < offset_end) {
+                            uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2));
+                            if ((cpt_next == 'r' && cpt_next_next == 'e') ||
+                                (cpt_next == 'v' && cpt_next_next == 'e') ||
+                                (cpt_next == 'l' && cpt_next_next == 'l')) {
+                                pos += 3;
+                            }
+                        }
+                    }
+
+                    _add_token(pos);
+                    continue;
+                } else if (has_leading_char) {
+                    // We consumed a leading char but found no letters, backtrack
+                    pos--;
+                }
+            }
+
+            // Pattern 4: \p{N}{1,3} (numbers 1-3 digits)
+            if (flags.is_number) {
+                size_t ini = pos;
+                while (_get_flags(pos).is_number) {
+                    if (++pos - ini >= 3) {
+                        _add_token(pos);
+                        ini = pos;
+                    }
+                }
+                _add_token(pos);
+                continue;
+            }
+
+            // Pattern 5: ?[^\s\p{L}\p{N}]+[\r\n]* (optional space + non-word chars + optional newlines)
+            auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags);
+            if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
+                pos += (cpt == ' ');
+                while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
+                    flags2 = _get_flags(++pos);
+                }
+                // Match optional [\r\n]*
+                uint32_t cpt2 = _get_cpt(pos);
+                while (cpt2 == '\r' || cpt2 == '\n') {
+                    cpt2 = _get_cpt(++pos);
+                }
+                _add_token(pos);
+                continue;
+            }
+
+            // Count whitespace characters
+            size_t num_whitespaces = 0;
+            size_t last_end_r_or_n = 0;
+            while (_get_flags(pos + num_whitespaces).is_whitespace) {
+                uint32_t cpt2 = _get_cpt(pos + num_whitespaces);
+                if (cpt2 == '\r' || cpt2 == '\n') {
+                    last_end_r_or_n = pos + num_whitespaces + 1;
+                }
+                num_whitespaces++;
+            }
+
+            // Pattern 6: \s*[\r\n]+ (whitespace with newlines)
+            if (last_end_r_or_n > 0) {
+                pos = last_end_r_or_n;
+                _add_token(pos);
+                continue;
+            }
+
+            // Pattern 7: \s+(?!\S) (trailing whitespace)
+            if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) {
+                pos += num_whitespaces - 1;
+                _add_token(pos);
+                continue;
+            }
+
+            // Pattern 8: \s+ (general whitespace)
+            if (num_whitespaces > 0) {
+                pos += num_whitespaces;
+                _add_token(pos);
+                continue;
+            }
+
+            // No matches - consume single character
+            _add_token(++pos);
+        }
+    }
+
+    return bpe_offsets;
+}
+
 static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
     std::vector<size_t> bpe_offsets;
 
@@ -567,6 +739,9 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
                regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
 
         bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
+    } else if (regex_expr == "\\p{Han}+") {
+        // K2's first pattern - handle all K2 patterns together
+        bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);
     }
 
     return bpe_offsets;
@@ -672,6 +847,38 @@ uint32_t unicode_tolower(uint32_t cpt) {
     return cpt;  // Return the original code point if no lowercase mapping is found
 }
 
+bool unicode_cpt_is_han(uint32_t cpt) {
+    // Han character ranges (Chinese/CJK characters)
+    // CJK Unified Ideographs (most common)
+    if (cpt >= 0x4E00 && cpt <= 0x9FFF) return true;
+
+    // CJK Extension A
+    if (cpt >= 0x3400 && cpt <= 0x4DBF) return true;
+
+    // CJK Extension B
+    if (cpt >= 0x20000 && cpt <= 0x2A6DF) return true;
+
+    // CJK Extension C
+    if (cpt >= 0x2A700 && cpt <= 0x2B73F) return true;
+
+    // CJK Extension D
+    if (cpt >= 0x2B740 && cpt <= 0x2B81F) return true;
+
+    // CJK Extension E
+    if (cpt >= 0x2B820 && cpt <= 0x2CEAF) return true;
+
+    // CJK Extension F
+    if (cpt >= 0x2CEB0 && cpt <= 0x2EBEF) return true;
+
+    // CJK Compatibility Ideographs
+    if (cpt >= 0xF900 && cpt <= 0xFAFF) return true;
+
+    // CJK Compatibility Ideographs Supplement
+    if (cpt >= 0x2F800 && cpt <= 0x2FA1F) return true;
+
+    return false;
+}
+
 std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
     // unicode categories
     static const std::map<std::string, int> k_ucat_enum = {
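To make the splitter's behavior concrete, here is a hedged usage sketch; the sample text and expected pieces come from hand-tracing the patterns above, not from running the diff, and the routing of "\p{Han}+" through unicode_regex_split assumes the custom dispatch in the hunk above is reached first:

// Hedged sketch, not part of the diff: exercise the K2 splitter through
// unicode_regex_split, whose custom dispatch maps "\p{Han}+" to the
// unicode_regex_split_custom_kimi_k2 routine above.
#include "unicode.h"
#include <cstdio>

int main() {
    // Tracing the patterns by hand suggests pieces roughly like:
    //   "你好" (Pattern 1), "world's" (Pattern 2/3 with contraction),
    //   " " (Pattern 8), then "123" and "4" (Pattern 4, max 3 digits per piece)
    const std::string text = "你好world's 1234";

    for (const auto & piece : unicode_regex_split(text, { "\\p{Han}+" })) {
        printf("piece: '%s'\n", piece.c_str());
    }
    return 0;
}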
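A small sanity-check sketch for the new range classifier; the sample code points and expectations are mine, not from the diff:

// Hedged sketch, not part of the diff: spot-check unicode_cpt_is_han.
#include "unicode.h"
#include <cassert>

int main() {
    assert( unicode_cpt_is_han(0x4E2D)); // 中, CJK Unified Ideographs
    assert( unicode_cpt_is_han(0x3400)); // start of CJK Extension A
    assert( unicode_cpt_is_han(0xF900)); // CJK Compatibility Ideographs
    assert(!unicode_cpt_is_han(0x0041)); // 'A', Latin
    assert(!unicode_cpt_is_han(0x3042)); // あ, Hiragana: kana is not Han
    return 0;
}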
examples/talk-llama/unicode.h
CHANGED
@@ -63,4 +63,6 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8);
 
 uint32_t unicode_tolower(uint32_t cpt);
 
+bool unicode_cpt_is_han(uint32_t cpt);
+
 std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);