Spaces:
Running
Running
talk-llama : sync llama.cpp
Browse files
- examples/talk-llama/llama-arch.cpp +64 -0
- examples/talk-llama/llama-arch.h +3 -0
- examples/talk-llama/llama-batch.cpp +270 -19
- examples/talk-llama/llama-batch.h +36 -11
- examples/talk-llama/llama-chat.cpp +17 -0
- examples/talk-llama/llama-chat.h +1 -0
- examples/talk-llama/llama-context.cpp +79 -110
- examples/talk-llama/llama-context.h +8 -6
- examples/talk-llama/llama-cparams.cpp +1 -1
- examples/talk-llama/llama-cparams.h +1 -1
- examples/talk-llama/llama-graph.cpp +52 -60
- examples/talk-llama/llama-graph.h +7 -21
- examples/talk-llama/llama-kv-cache-recurrent.cpp +64 -81
- examples/talk-llama/llama-kv-cache-recurrent.h +7 -8
- examples/talk-llama/llama-kv-cache-unified-iswa.cpp +54 -21
- examples/talk-llama/llama-kv-cache-unified-iswa.h +1 -2
- examples/talk-llama/llama-kv-cache-unified.cpp +144 -79
- examples/talk-llama/llama-kv-cache-unified.h +3 -2
- examples/talk-llama/llama-kv-cells.h +8 -8
- examples/talk-llama/llama-memory.h +1 -2
- examples/talk-llama/llama-model.cpp +576 -29
- examples/talk-llama/llama-model.h +1 -0
- examples/talk-llama/llama-quant.cpp +2 -1
- examples/talk-llama/llama-vocab.cpp +25 -20
- examples/talk-llama/llama.cpp +11 -7
- examples/talk-llama/llama.h +10 -7
examples/talk-llama/llama-arch.cpp
CHANGED
|
@@ -20,6 +20,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|
| 20 |
{ LLM_ARCH_BERT, "bert" },
|
| 21 |
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
|
| 22 |
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
|
|
|
|
| 23 |
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
|
| 24 |
{ LLM_ARCH_BLOOM, "bloom" },
|
| 25 |
{ LLM_ARCH_STABLELM, "stablelm" },
|
|
@@ -72,6 +73,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|
| 72 |
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
|
| 73 |
{ LLM_ARCH_PLM, "plm" },
|
| 74 |
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
|
|
|
|
|
|
|
| 75 |
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
| 76 |
};
|
| 77 |
|
|
@@ -243,6 +246,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|
| 243 |
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
| 244 |
},
|
| 245 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
{
|
| 247 |
LLM_ARCH_LLAMA4,
|
| 248 |
{
|
|
@@ -494,6 +515,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|
| 494 |
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
| 495 |
},
|
| 496 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 497 |
{
|
| 498 |
LLM_ARCH_JINA_BERT_V2,
|
| 499 |
{
|
|
@@ -1555,6 +1591,34 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|
| 1555 |
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
| 1556 |
},
|
| 1557 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1558 |
{
|
| 1559 |
LLM_ARCH_UNKNOWN,
|
| 1560 |
{
|
|
|
|
| 20 |
{ LLM_ARCH_BERT, "bert" },
|
| 21 |
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
|
| 22 |
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
|
| 23 |
+
{ LLM_ARCH_NEO_BERT, "neo-bert" },
|
| 24 |
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
|
| 25 |
{ LLM_ARCH_BLOOM, "bloom" },
|
| 26 |
{ LLM_ARCH_STABLELM, "stablelm" },
|
|
|
|
| 73 |
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
|
| 74 |
{ LLM_ARCH_PLM, "plm" },
|
| 75 |
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
|
| 76 |
+
{ LLM_ARCH_DOTS1, "dots1" },
|
| 77 |
+
{ LLM_ARCH_ARCEE, "arcee" },
|
| 78 |
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
| 79 |
};
|
| 80 |
|
|
|
|
| 246 |
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
| 247 |
},
|
| 248 |
},
|
| 249 |
+
{
|
| 250 |
+
LLM_ARCH_ARCEE,
|
| 251 |
+
{
|
| 252 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 253 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 254 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 255 |
+
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
| 256 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 257 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 258 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 259 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 260 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 261 |
+
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
| 262 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 263 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 264 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 265 |
+
},
|
| 266 |
+
},
|
| 267 |
{
|
| 268 |
LLM_ARCH_LLAMA4,
|
| 269 |
{
|
|
|
|
| 515 |
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
| 516 |
},
|
| 517 |
},
|
| 518 |
+
{
|
| 519 |
+
LLM_ARCH_NEO_BERT,
|
| 520 |
+
{
|
| 521 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 522 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 523 |
+
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
| 524 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 525 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 526 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 527 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 528 |
+
{ LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
|
| 529 |
+
{ LLM_TENSOR_CLS, "cls" },
|
| 530 |
+
{ LLM_TENSOR_CLS_OUT, "cls.output" },
|
| 531 |
+
},
|
| 532 |
+
},
|
| 533 |
{
|
| 534 |
LLM_ARCH_JINA_BERT_V2,
|
| 535 |
{
|
|
|
|
| 1591 |
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
| 1592 |
},
|
| 1593 |
},
|
| 1594 |
+
{
|
| 1595 |
+
LLM_ARCH_DOTS1,
|
| 1596 |
+
{
|
| 1597 |
+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
| 1598 |
+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
| 1599 |
+
{ LLM_TENSOR_OUTPUT, "output" },
|
| 1600 |
+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
| 1601 |
+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
| 1602 |
+
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
| 1603 |
+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
| 1604 |
+
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
| 1605 |
+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
| 1606 |
+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
| 1607 |
+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
| 1608 |
+
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
| 1609 |
+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
| 1610 |
+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
| 1611 |
+
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
| 1612 |
+
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
| 1613 |
+
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
| 1614 |
+
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
| 1615 |
+
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
|
| 1616 |
+
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
| 1617 |
+
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
| 1618 |
+
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
| 1619 |
+
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
|
| 1620 |
+
}
|
| 1621 |
+
},
|
| 1622 |
{
|
| 1623 |
LLM_ARCH_UNKNOWN,
|
| 1624 |
{
|
examples/talk-llama/llama-arch.h
CHANGED
|
@@ -24,6 +24,7 @@ enum llm_arch {
|
|
| 24 |
LLM_ARCH_BERT,
|
| 25 |
LLM_ARCH_NOMIC_BERT,
|
| 26 |
LLM_ARCH_NOMIC_BERT_MOE,
|
|
|
|
| 27 |
LLM_ARCH_JINA_BERT_V2,
|
| 28 |
LLM_ARCH_BLOOM,
|
| 29 |
LLM_ARCH_STABLELM,
|
|
@@ -76,6 +77,8 @@ enum llm_arch {
|
|
| 76 |
LLM_ARCH_WAVTOKENIZER_DEC,
|
| 77 |
LLM_ARCH_PLM,
|
| 78 |
LLM_ARCH_BAILINGMOE,
|
|
|
|
|
|
|
| 79 |
LLM_ARCH_UNKNOWN,
|
| 80 |
};
|
| 81 |
|
|
|
|
| 24 |
LLM_ARCH_BERT,
|
| 25 |
LLM_ARCH_NOMIC_BERT,
|
| 26 |
LLM_ARCH_NOMIC_BERT_MOE,
|
| 27 |
+
LLM_ARCH_NEO_BERT,
|
| 28 |
LLM_ARCH_JINA_BERT_V2,
|
| 29 |
LLM_ARCH_BLOOM,
|
| 30 |
LLM_ARCH_STABLELM,
|
|
|
|
| 77 |
LLM_ARCH_WAVTOKENIZER_DEC,
|
| 78 |
LLM_ARCH_PLM,
|
| 79 |
LLM_ARCH_BAILINGMOE,
|
| 80 |
+
LLM_ARCH_DOTS1,
|
| 81 |
+
LLM_ARCH_ARCEE,
|
| 82 |
LLM_ARCH_UNKNOWN,
|
| 83 |
};
|
| 84 |
|
examples/talk-llama/llama-batch.cpp
CHANGED
|
@@ -1,8 +1,14 @@
|
|
| 1 |
#include "llama-batch.h"
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
#include <cassert>
|
| 4 |
#include <cstring>
|
| 5 |
#include <algorithm>
|
|
|
|
| 6 |
|
| 7 |
llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
|
| 8 |
// clear empty sequences
|
|
@@ -105,12 +111,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
|
|
| 105 |
ubatch.seq_id = batch->seq_id + seq.offset;
|
| 106 |
}
|
| 107 |
}
|
| 108 |
-
if (
|
| 109 |
-
for (size_t i = 0; i < length; ++i) {
|
| 110 |
-
ubatch.output[ubatch.n_tokens + i] = 1;
|
| 111 |
-
out_ids.push_back(ids[seq.offset + i]);
|
| 112 |
-
}
|
| 113 |
-
} else if (batch->logits) {
|
| 114 |
if (ubatch.equal_seqs) {
|
| 115 |
for (size_t i = 0; i < length; ++i) {
|
| 116 |
size_t id = ids[seq.offset + i];
|
|
@@ -197,11 +198,10 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
|
|
| 197 |
return ubatch;
|
| 198 |
}
|
| 199 |
|
| 200 |
-
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split
|
| 201 |
GGML_ASSERT(batch.n_tokens >= 0);
|
| 202 |
this->batch = &batch;
|
| 203 |
this->n_embd = n_embd;
|
| 204 |
-
this->logits_all = logits_all;
|
| 205 |
|
| 206 |
n_tokens = batch.n_tokens;
|
| 207 |
ids.resize(n_tokens);
|
|
@@ -285,17 +285,56 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
|
|
| 285 |
);
|
| 286 |
}
|
| 287 |
|
| 288 |
-
llama_batch_allocr::llama_batch_allocr(
|
| 289 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
GGML_ASSERT(batch.n_tokens > 0);
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
}
|
| 297 |
-
batch.pos = pos.data();
|
| 298 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
if (!batch.n_seq_id) {
|
| 300 |
n_seq_id.resize(batch.n_tokens);
|
| 301 |
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
|
@@ -303,6 +342,7 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
|
|
| 303 |
}
|
| 304 |
batch.n_seq_id = n_seq_id.data();
|
| 305 |
}
|
|
|
|
| 306 |
if (!batch.seq_id) {
|
| 307 |
seq_id.resize(batch.n_tokens + 1);
|
| 308 |
seq_id[batch.n_tokens] = NULL;
|
|
@@ -311,10 +351,221 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
|
|
| 311 |
}
|
| 312 |
batch.seq_id = seq_id.data();
|
| 313 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
if (!batch.logits) {
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
}
|
| 319 |
}
|
| 320 |
|
|
|
|
| 1 |
#include "llama-batch.h"
|
| 2 |
|
| 3 |
+
#include "llama-impl.h"
|
| 4 |
+
#include "llama-cparams.h"
|
| 5 |
+
#include "llama-vocab.h"
|
| 6 |
+
#include "llama-memory.h"
|
| 7 |
+
|
| 8 |
#include <cassert>
|
| 9 |
#include <cstring>
|
| 10 |
#include <algorithm>
|
| 11 |
+
#include <sstream>
|
| 12 |
|
| 13 |
llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
|
| 14 |
// clear empty sequences
|
|
|
|
| 111 |
ubatch.seq_id = batch->seq_id + seq.offset;
|
| 112 |
}
|
| 113 |
}
|
| 114 |
+
if (batch->logits) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
if (ubatch.equal_seqs) {
|
| 116 |
for (size_t i = 0; i < length; ++i) {
|
| 117 |
size_t id = ids[seq.offset + i];
|
|
|
|
| 198 |
return ubatch;
|
| 199 |
}
|
| 200 |
|
| 201 |
+
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
|
| 202 |
GGML_ASSERT(batch.n_tokens >= 0);
|
| 203 |
this->batch = &batch;
|
| 204 |
this->n_embd = n_embd;
|
|
|
|
| 205 |
|
| 206 |
n_tokens = batch.n_tokens;
|
| 207 |
ids.resize(n_tokens);
|
|
|
|
| 285 |
);
|
| 286 |
}
|
| 287 |
|
| 288 |
+
// Default constructor.
// Reads the debug verbosity from the LLAMA_BATCH_DEBUG environment variable and
// pre-sizes the per-sequence bookkeeping containers for LLAMA_MAX_SEQ sequences.
llama_batch_allocr::llama_batch_allocr() {
|
| 289 |
+
// debug level: 0 when the variable is not set, otherwise its value parsed with atoi()
const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
|
| 290 |
+
debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
|
| 291 |
+
|
| 292 |
+
// one entry per possible sequence id
seq_pos.resize(LLAMA_MAX_SEQ);
|
| 293 |
+
seq_cpl.resize(LLAMA_MAX_SEQ);
|
| 294 |
+
// each row of seq_cpl gets LLAMA_MAX_SEQ columns, so it can be indexed as seq_cpl[s0][s1]
for (auto & cur : seq_cpl) {
|
| 295 |
+
cur.resize(LLAMA_MAX_SEQ);
|
| 296 |
+
}
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
bool llama_batch_allocr::init(
|
| 300 |
+
const llama_batch & batch_inp,
|
| 301 |
+
const llama_vocab & vocab,
|
| 302 |
+
const llama_memory_i * memory,
|
| 303 |
+
bool embd_all) {
|
| 304 |
+
clear();
|
| 305 |
+
|
| 306 |
+
batch = batch_inp;
|
| 307 |
+
|
| 308 |
GGML_ASSERT(batch.n_tokens > 0);
|
| 309 |
+
|
| 310 |
+
//
|
| 311 |
+
// validate input batch
|
| 312 |
+
//
|
| 313 |
+
|
| 314 |
+
if (batch.token) {
|
| 315 |
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
| 316 |
+
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
|
| 317 |
+
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
| 318 |
+
return false;
|
| 319 |
+
}
|
| 320 |
+
}
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
if (batch.seq_id) {
|
| 324 |
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
| 325 |
+
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
| 326 |
+
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
|
| 327 |
+
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
|
| 328 |
+
return false;
|
| 329 |
+
}
|
| 330 |
+
}
|
| 331 |
}
|
|
|
|
| 332 |
}
|
| 333 |
+
|
| 334 |
+
//
|
| 335 |
+
// auto-generate missing fields
|
| 336 |
+
//
|
| 337 |
+
|
| 338 |
if (!batch.n_seq_id) {
|
| 339 |
n_seq_id.resize(batch.n_tokens);
|
| 340 |
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
|
|
|
| 342 |
}
|
| 343 |
batch.n_seq_id = n_seq_id.data();
|
| 344 |
}
|
| 345 |
+
|
| 346 |
if (!batch.seq_id) {
|
| 347 |
seq_id.resize(batch.n_tokens + 1);
|
| 348 |
seq_id[batch.n_tokens] = NULL;
|
|
|
|
| 351 |
}
|
| 352 |
batch.seq_id = seq_id.data();
|
| 353 |
}
|
| 354 |
+
|
| 355 |
+
if (!batch.pos) {
|
| 356 |
+
pos.resize(batch.n_tokens);
|
| 357 |
+
|
| 358 |
+
// initialize the starting position for each sequence based on the positions in the memory
|
| 359 |
+
llama_pos p0[LLAMA_MAX_SEQ];
|
| 360 |
+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 361 |
+
if (!memory) {
|
| 362 |
+
p0[s] = 0;
|
| 363 |
+
} else {
|
| 364 |
+
p0[s] = memory->seq_pos_max(s) + 1;
|
| 365 |
+
}
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
| 369 |
+
const llama_seq_id seq_id = batch.seq_id[i][0];
|
| 370 |
+
|
| 371 |
+
pos[i] = p0[seq_id];
|
| 372 |
+
|
| 373 |
+
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
| 374 |
+
p0[batch.seq_id[i][s]] = pos[i] + 1;
|
| 375 |
+
}
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
batch.pos = pos.data();
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
if (!batch.logits) {
|
| 382 |
+
if (embd_all) {
|
| 383 |
+
// return the output for all tokens
|
| 384 |
+
output.resize(batch.n_tokens, true);
|
| 385 |
+
} else {
|
| 386 |
+
// return the output only for the last token
|
| 387 |
+
output.resize(batch.n_tokens, false);
|
| 388 |
+
output[output.size() - 1] = true;
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
batch.logits = output.data();
|
| 392 |
+
} else if (embd_all) {
|
| 393 |
+
bool warn = false;
|
| 394 |
+
|
| 395 |
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
| 396 |
+
if (batch.logits[i] == 0) {
|
| 397 |
+
warn = true;
|
| 398 |
+
}
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
if (warn) {
|
| 402 |
+
LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
|
| 403 |
+
|
| 404 |
+
output.resize(batch.n_tokens, true);
|
| 405 |
+
batch.logits = output.data();
|
| 406 |
+
}
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
//
|
| 410 |
+
// compute stats
|
| 411 |
+
//
|
| 412 |
+
|
| 413 |
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
| 414 |
+
n_outputs += batch.logits[i] != 0;
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
// determine coupled sequences
|
| 418 |
+
// these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
|
| 419 |
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
| 420 |
+
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
| 421 |
+
seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
|
| 422 |
+
|
| 423 |
+
if (s > 0) {
|
| 424 |
+
const llama_seq_id s0 = batch.seq_id[i][0];
|
| 425 |
+
const llama_seq_id s1 = batch.seq_id[i][s];
|
| 426 |
+
|
| 427 |
+
// mark that sequence s1 is coupled to s0
|
| 428 |
+
seq_cpl[s1][s0] = true;
|
| 429 |
+
|
| 430 |
+
// note: the other way around is not necessary for now
|
| 431 |
+
//seq_cpl[s0][s1] = true;
|
| 432 |
+
}
|
| 433 |
+
}
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
if (debug > 0) {
|
| 437 |
+
LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
|
| 438 |
+
LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
|
| 439 |
+
LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
|
| 440 |
+
LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
|
| 441 |
+
LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
|
| 442 |
+
LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) batch.n_seq_id);
|
| 443 |
+
LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) batch.seq_id);
|
| 444 |
+
LLAMA_LOG_DEBUG("%s: logits = %p\n", __func__, (void *) batch.logits);
|
| 445 |
+
LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
|
| 446 |
+
|
| 447 |
+
if (debug > 1) {
|
| 448 |
+
int seq_id_max = 0;
|
| 449 |
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
| 450 |
+
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
|
| 451 |
+
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
|
| 452 |
+
seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]);
|
| 453 |
+
}
|
| 454 |
+
}
|
| 455 |
+
}
|
| 456 |
+
++seq_id_max;
|
| 457 |
+
|
| 458 |
+
LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
|
| 459 |
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
| 460 |
+
std::vector<int8_t> seq_id(seq_id_max);
|
| 461 |
+
|
| 462 |
+
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
|
| 463 |
+
seq_id[batch.seq_id[i][s]] = 1;
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
std::stringstream ss;
|
| 467 |
+
for (int s = 0; s < seq_id_max; ++s) {
|
| 468 |
+
if (seq_id[s]) {
|
| 469 |
+
ss << s%10;
|
| 470 |
+
} else {
|
| 471 |
+
ss << ".";
|
| 472 |
+
}
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
|
| 476 |
+
__func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(),
|
| 477 |
+
batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
|
| 478 |
+
}
|
| 479 |
+
LLAMA_LOG_DEBUG("%s: ]\n", __func__);
|
| 480 |
+
|
| 481 |
+
LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
|
| 482 |
+
for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
|
| 483 |
+
if (seq_pos[s0].empty()) {
|
| 484 |
+
continue;
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
std::stringstream ss;
|
| 488 |
+
for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
|
| 489 |
+
if (seq_cpl[s0][s1]) {
|
| 490 |
+
ss << s1 << " ";
|
| 491 |
+
}
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
|
| 495 |
+
__func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
|
| 496 |
+
}
|
| 497 |
+
LLAMA_LOG_DEBUG("%s: ]\n", __func__);
|
| 498 |
+
}
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
//
|
| 502 |
+
// consistency checks
|
| 503 |
+
//
|
| 504 |
+
|
| 505 |
+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 506 |
+
if (seq_pos[s].empty()) {
|
| 507 |
+
continue;
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
|
| 511 |
+
LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
|
| 512 |
+
return false;
|
| 513 |
+
}
|
| 514 |
+
|
| 515 |
+
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
|
| 516 |
+
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
|
| 517 |
+
return false;
|
| 518 |
+
}
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
if (memory) {
|
| 522 |
+
for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
|
| 523 |
+
for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
|
| 524 |
+
if (seq_cpl[s0][s1]) {
|
| 525 |
+
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
|
| 526 |
+
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
|
| 527 |
+
LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
|
| 528 |
+
return false;
|
| 529 |
+
}
|
| 530 |
+
}
|
| 531 |
+
}
|
| 532 |
+
}
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
return true;
|
| 536 |
+
}
|
| 537 |
+
|
| 538 |
+
// Returns a reference to the internally stored batch.
// init() assigns the input batch to this member and points any fields the caller
// left null at this object's own vectors.
// NOTE(review): presumably the returned batch must not be used after the next
// init() or after this object is destroyed - confirm against callers.
const llama_batch & llama_batch_allocr::get_batch() const {
|
| 539 |
+
return batch;
|
| 540 |
+
}
|
| 541 |
+
|
| 542 |
+
// Returns the output counter: init() increments it once for every token whose
// logits flag is non-zero, and clear() resets it to 0.
uint32_t llama_batch_allocr::get_n_outputs() const {
|
| 543 |
+
return n_outputs;
|
| 544 |
+
}
|
| 545 |
+
|
| 546 |
+
// Returns the smallest position recorded for sequence `seq_id`,
// or -1 when no position has been recorded for that sequence.
// seq_pos[seq_id] is a std::set, so its first element is the minimum.
// note: seq_id is used directly as an index into seq_pos, with no bounds check here
llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
|
| 547 |
+
return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
|
| 548 |
+
}
|
| 549 |
+
|
| 550 |
+
// Returns the largest position recorded for sequence `seq_id`,
// or -1 when no position has been recorded for that sequence.
// seq_pos[seq_id] is a std::set, so its last element (rbegin) is the maximum.
// note: seq_id is used directly as an index into seq_pos, with no bounds check here
llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
|
| 551 |
+
return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
|
| 552 |
+
}
|
| 553 |
+
|
| 554 |
+
// Resets all per-batch state so the object can be reused; init() calls this first.
// The outer sizes of seq_pos and seq_cpl (set in the constructor) are kept -
// only their contents are reset.
void llama_batch_allocr::clear() {
|
| 555 |
+
n_outputs = 0;
|
| 556 |
+
|
| 557 |
+
// value-initialize the stored batch, then empty the vectors that back its auto-generated fields
batch = {};
|
| 558 |
+
pos.clear();
|
| 559 |
+
n_seq_id.clear();
|
| 560 |
+
seq_id.clear();
|
| 561 |
+
output.clear();
|
| 562 |
+
|
| 563 |
+
// empty every per-sequence position set
for (auto & cur : seq_pos) {
|
| 564 |
+
cur.clear();
|
| 565 |
+
}
|
| 566 |
+
|
| 567 |
+
// mark every pair of sequences as not coupled, keeping the matrix dimensions
for (auto & cur : seq_cpl) {
|
| 568 |
+
std::fill(cur.begin(), cur.end(), false);
|
| 569 |
}
|
| 570 |
}
|
| 571 |
|
examples/talk-llama/llama-batch.h
CHANGED
|
@@ -4,6 +4,7 @@
|
|
| 4 |
|
| 5 |
#include <array>
|
| 6 |
#include <vector>
|
|
|
|
| 7 |
|
| 8 |
// very similar to llama_batch,
|
| 9 |
// but has more metadata about sequences
|
|
@@ -18,8 +19,8 @@ struct llama_ubatch {
|
|
| 18 |
llama_token * token; // [n_tokens]
|
| 19 |
float * embd; // [n_embd, n_tokens]
|
| 20 |
llama_pos * pos; // [n_tokens]
|
| 21 |
-
int32_t * n_seq_id; // [n_seqs]
|
| 22 |
-
llama_seq_id ** seq_id; // [n_seqs]
|
| 23 |
int8_t * output; // [n_tokens]
|
| 24 |
};
|
| 25 |
|
|
@@ -39,8 +40,6 @@ struct llama_sbatch {
|
|
| 39 |
|
| 40 |
size_t n_embd;
|
| 41 |
|
| 42 |
-
bool logits_all; // TODO: remove once lctx.logits_all is removed too
|
| 43 |
-
|
| 44 |
// sorted indices into the batch
|
| 45 |
std::vector<int64_t> ids;
|
| 46 |
// batch indices of the output
|
|
@@ -76,19 +75,45 @@ struct llama_sbatch {
|
|
| 76 |
llama_ubatch split_seq(size_t n_ubatch);
|
| 77 |
|
| 78 |
llama_sbatch() = default;
|
| 79 |
-
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false
|
| 80 |
};
|
| 81 |
|
| 82 |
-
//
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
|
|
|
|
| 87 |
std::vector<llama_pos> pos;
|
| 88 |
std::vector<int32_t> n_seq_id;
|
| 89 |
std::vector<llama_seq_id *> seq_id;
|
| 90 |
-
std::vector<int8_t>
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
-
|
| 93 |
-
llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
|
| 94 |
};
|
|
|
|
| 4 |
|
| 5 |
#include <array>
|
| 6 |
#include <vector>
|
| 7 |
+
#include <set>
|
| 8 |
|
| 9 |
// very similar to llama_batch,
|
| 10 |
// but has more metadata about sequences
|
|
|
|
| 19 |
llama_token * token; // [n_tokens]
|
| 20 |
float * embd; // [n_embd, n_tokens]
|
| 21 |
llama_pos * pos; // [n_tokens]
|
| 22 |
+
int32_t * n_seq_id; // [n_seqs]
|
| 23 |
+
llama_seq_id ** seq_id; // [n_seqs]
|
| 24 |
int8_t * output; // [n_tokens]
|
| 25 |
};
|
| 26 |
|
|
|
|
| 40 |
|
| 41 |
size_t n_embd;
|
| 42 |
|
|
|
|
|
|
|
| 43 |
// sorted indices into the batch
|
| 44 |
std::vector<int64_t> ids;
|
| 45 |
// batch indices of the output
|
|
|
|
| 75 |
llama_ubatch split_seq(size_t n_ubatch);
|
| 76 |
|
| 77 |
llama_sbatch() = default;
|
| 78 |
+
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
|
| 79 |
};
|
| 80 |
|
| 81 |
+
// a helper for sanitizing and fulfilling a batch
|
| 82 |
+
class llama_batch_allocr {
|
| 83 |
+
public:
|
| 84 |
+
// reads the LLAMA_BATCH_DEBUG environment variable and sizes seq_pos / seq_cpl for LLAMA_MAX_SEQ sequences
llama_batch_allocr();
|
| 85 |
+
|
| 86 |
+
// sanitize and auto-gen missing data in the input batch
|
| 87 |
+
// memory is optional. if provided will be used to check for sequence continuity and to determine the positions
// returns false (after logging an error) when validation or a consistency check fails, true otherwise
// embd_all: when the batch has no logits array, selects output for all tokens (true) or only the last token (false)
|
| 88 |
+
bool init(
|
| 89 |
+
const llama_batch & batch_inp,
|
| 90 |
+
const llama_vocab & vocab,
|
| 91 |
+
const llama_memory_i * memory,
|
| 92 |
+
bool embd_all);
|
| 93 |
+
|
| 94 |
+
// the internally stored batch, as prepared by the last init()
const llama_batch & get_batch() const;
|
| 95 |
+
|
| 96 |
+
// number of tokens in the batch with a non-zero logits flag
uint32_t get_n_outputs() const;
|
| 97 |
+
|
| 98 |
+
// smallest / largest position recorded for the sequence, or -1 if none was recorded
llama_pos seq_pos_min(llama_seq_id seq_id) const;
|
| 99 |
+
llama_pos seq_pos_max(llama_seq_id seq_id) const;
|
| 100 |
+
|
| 101 |
+
private:
|
| 102 |
+
// resets all per-batch state; called at the start of init()
void clear();
|
| 103 |
+
|
| 104 |
+
// copy of the input batch; fields the caller left null are pointed at the vectors below
llama_batch batch;
|
| 105 |
+
|
| 106 |
+
// set to 0 by clear() and counted in init(); has no in-class initializer
uint32_t n_outputs;
|
| 107 |
|
| 108 |
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
|
| 109 |
+
|
| 110 |
// backing storage for batch.pos, batch.n_seq_id, batch.seq_id and batch.logits when they are auto-generated
std::vector<llama_pos> pos;
|
| 111 |
std::vector<int32_t> n_seq_id;
|
| 112 |
std::vector<llama_seq_id *> seq_id;
|
| 113 |
+
std::vector<int8_t> output;
|
| 114 |
+
|
| 115 |
+
std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
|
| 116 |
+
std::vector<std::vector<bool>> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
|
| 117 |
|
| 118 |
+
// debug verbosity taken from LLAMA_BATCH_DEBUG: > 0 logs a batch summary, > 1 also logs per-token and per-sequence details
int debug;
|
|
|
|
|
|
|
| 119 |
};
|
examples/talk-llama/llama-chat.cpp
CHANGED
|
@@ -183,6 +183,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
|
| 183 |
return LLM_CHAT_TEMPLATE_BAILING;
|
| 184 |
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
|
| 185 |
return LLM_CHAT_TEMPLATE_LLAMA4;
|
|
|
|
|
|
|
| 186 |
}
|
| 187 |
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
| 188 |
}
|
|
@@ -643,6 +645,21 @@ int32_t llm_chat_apply_template(
|
|
| 643 |
if (add_ass) {
|
| 644 |
ss << "Assistant:";
|
| 645 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 646 |
} else {
|
| 647 |
// template not supported
|
| 648 |
return -1;
|
|
|
|
| 183 |
return LLM_CHAT_TEMPLATE_BAILING;
|
| 184 |
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
|
| 185 |
return LLM_CHAT_TEMPLATE_LLAMA4;
|
| 186 |
+
} else if (tmpl_contains("<|endofuserprompt|>")) {
|
| 187 |
+
return LLM_CHAT_TEMPLATE_DOTS1;
|
| 188 |
}
|
| 189 |
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
| 190 |
}
|
|
|
|
| 645 |
if (add_ass) {
|
| 646 |
ss << "Assistant:";
|
| 647 |
}
|
| 648 |
+
} else if (tmpl == LLM_CHAT_TEMPLATE_DOTS1) {
|
| 649 |
+
// dots.llm1.inst (DOTS1)
|
| 650 |
+
for (auto message : chat) {
|
| 651 |
+
std::string role(message->role);
|
| 652 |
+
if (role == "system") {
|
| 653 |
+
ss << "<|system|>" << message->content << "<|endofsystem|>";
|
| 654 |
+
} else if (role == "user") {
|
| 655 |
+
ss << "<|userprompt|>" << message->content << "<|endofuserprompt|>";
|
| 656 |
+
} else {
|
| 657 |
+
ss << "<|response|>" << message->content << "<|endofresponse|>";
|
| 658 |
+
}
|
| 659 |
+
}
|
| 660 |
+
if (add_ass) {
|
| 661 |
+
ss << "<|response|>";
|
| 662 |
+
}
|
| 663 |
} else {
|
| 664 |
// template not supported
|
| 665 |
return -1;
|
examples/talk-llama/llama-chat.h
CHANGED
|
@@ -43,6 +43,7 @@ enum llm_chat_template {
|
|
| 43 |
LLM_CHAT_TEMPLATE_BAILING,
|
| 44 |
LLM_CHAT_TEMPLATE_LLAMA4,
|
| 45 |
LLM_CHAT_TEMPLATE_SMOLVLM,
|
|
|
|
| 46 |
LLM_CHAT_TEMPLATE_UNKNOWN,
|
| 47 |
};
|
| 48 |
|
|
|
|
| 43 |
LLM_CHAT_TEMPLATE_BAILING,
|
| 44 |
LLM_CHAT_TEMPLATE_LLAMA4,
|
| 45 |
LLM_CHAT_TEMPLATE_SMOLVLM,
|
| 46 |
+
LLM_CHAT_TEMPLATE_DOTS1,
|
| 47 |
LLM_CHAT_TEMPLATE_UNKNOWN,
|
| 48 |
};
|
| 49 |
|
examples/talk-llama/llama-context.cpp
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
#include "llama-context.h"
|
| 2 |
|
| 3 |
#include "llama-impl.h"
|
|
|
|
| 4 |
#include "llama-io.h"
|
| 5 |
#include "llama-memory.h"
|
| 6 |
#include "llama-mmap.h"
|
|
@@ -18,7 +19,8 @@
|
|
| 18 |
llama_context::llama_context(
|
| 19 |
const llama_model & model,
|
| 20 |
llama_context_params params) :
|
| 21 |
-
model(model)
|
|
|
|
| 22 |
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
|
| 23 |
|
| 24 |
t_start_us = model.t_start_us;
|
|
@@ -27,8 +29,8 @@ llama_context::llama_context(
|
|
| 27 |
const auto & hparams = model.hparams;
|
| 28 |
|
| 29 |
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
| 30 |
-
if (cparams.n_seq_max >
|
| 31 |
-
throw std::runtime_error("n_seq_max must be <= " + std::to_string(
|
| 32 |
}
|
| 33 |
|
| 34 |
cparams.n_threads = params.n_threads;
|
|
@@ -494,7 +496,7 @@ float * llama_context::get_logits() {
|
|
| 494 |
}
|
| 495 |
|
| 496 |
float * llama_context::get_logits_ith(int32_t i) {
|
| 497 |
-
|
| 498 |
|
| 499 |
try {
|
| 500 |
if (logits == nullptr) {
|
|
@@ -517,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
|
|
| 517 |
}
|
| 518 |
if (j >= n_outputs) {
|
| 519 |
// This should not happen
|
| 520 |
-
throw std::runtime_error(format("corrupt output buffer (j=%
|
| 521 |
}
|
| 522 |
|
| 523 |
return logits + j*model.vocab.n_tokens();
|
|
@@ -536,7 +538,7 @@ float * llama_context::get_embeddings() {
|
|
| 536 |
}
|
| 537 |
|
| 538 |
float * llama_context::get_embeddings_ith(int32_t i) {
|
| 539 |
-
|
| 540 |
|
| 541 |
try {
|
| 542 |
if (embd == nullptr) {
|
|
@@ -559,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
|
|
| 559 |
}
|
| 560 |
if (j >= n_outputs) {
|
| 561 |
// This should not happen
|
| 562 |
-
throw std::runtime_error(format("corrupt output buffer (j=%
|
| 563 |
}
|
| 564 |
|
| 565 |
return embd + j*model.hparams.n_embd;
|
|
@@ -719,52 +721,41 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
|
|
| 719 |
return res;
|
| 720 |
}
|
| 721 |
|
| 722 |
-
int llama_context::encode(llama_batch &
|
| 723 |
-
if (
|
| 724 |
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
| 725 |
return -1;
|
| 726 |
}
|
| 727 |
|
| 728 |
-
// temporary allocate memory for the input batch if needed
|
| 729 |
// note: during encode, we always pass the full sequence starting from pos = 0
|
| 730 |
-
|
|
|
|
|
|
|
|
|
|
| 731 |
|
| 732 |
-
const llama_batch & batch = batch_allocr
|
| 733 |
-
const int32_t n_tokens = batch.n_tokens;
|
| 734 |
|
| 735 |
-
const
|
| 736 |
|
| 737 |
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
| 738 |
|
| 739 |
-
// TODO: move the validation to the llama_batch_allocr
|
| 740 |
-
if (batch.token) {
|
| 741 |
-
for (int32_t i = 0; i < n_tokens; ++i) {
|
| 742 |
-
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
| 743 |
-
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
| 744 |
-
return -1;
|
| 745 |
-
}
|
| 746 |
-
|
| 747 |
-
if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
|
| 748 |
-
LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
|
| 749 |
-
throw -1;
|
| 750 |
-
}
|
| 751 |
-
}
|
| 752 |
-
}
|
| 753 |
-
|
| 754 |
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
|
| 755 |
-
GGML_ASSERT(cparams.n_ubatch >=
|
| 756 |
|
| 757 |
if (t_compute_start_us == 0) {
|
| 758 |
t_compute_start_us = ggml_time_us();
|
| 759 |
}
|
| 760 |
|
|
|
|
| 761 |
embd_seq.clear();
|
| 762 |
|
| 763 |
n_queued_tokens += n_tokens;
|
| 764 |
|
|
|
|
|
|
|
| 765 |
const int64_t n_embd = hparams.n_embd;
|
| 766 |
|
| 767 |
-
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true
|
| 768 |
|
| 769 |
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
|
| 770 |
|
|
@@ -774,7 +765,7 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
| 774 |
return -2;
|
| 775 |
};
|
| 776 |
|
| 777 |
-
for (
|
| 778 |
output_ids[i] = i;
|
| 779 |
}
|
| 780 |
|
|
@@ -830,7 +821,8 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
| 830 |
|
| 831 |
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
|
| 832 |
|
| 833 |
-
|
|
|
|
| 834 |
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
| 835 |
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
| 836 |
continue;
|
|
@@ -845,6 +837,7 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
| 845 |
auto & embd_seq_out = embd_seq;
|
| 846 |
const uint32_t n_cls_out = hparams.n_cls_out;
|
| 847 |
|
|
|
|
| 848 |
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
| 849 |
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
| 850 |
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
|
@@ -878,10 +871,10 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
| 878 |
|
| 879 |
// remember the sequence ids used during the encoding - needed for cross attention later
|
| 880 |
cross.seq_ids_enc.resize(n_tokens);
|
| 881 |
-
for (
|
| 882 |
cross.seq_ids_enc[i].clear();
|
| 883 |
-
for (int s = 0; s <
|
| 884 |
-
llama_seq_id seq_id =
|
| 885 |
cross.seq_ids_enc[i].insert(seq_id);
|
| 886 |
}
|
| 887 |
}
|
|
@@ -890,51 +883,45 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
| 890 |
return 0;
|
| 891 |
}
|
| 892 |
|
| 893 |
-
int llama_context::decode(llama_batch &
|
| 894 |
if (!memory) {
|
| 895 |
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
|
| 896 |
-
return encode(
|
| 897 |
}
|
| 898 |
|
| 899 |
-
if (
|
| 900 |
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
| 901 |
return -1;
|
| 902 |
}
|
| 903 |
|
| 904 |
-
|
| 905 |
-
|
| 906 |
-
LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
|
| 907 |
-
return -1;
|
| 908 |
-
}
|
| 909 |
-
}
|
| 910 |
|
| 911 |
-
|
| 912 |
-
|
|
|
|
|
|
|
| 913 |
|
| 914 |
-
const llama_batch & batch = batch_allocr
|
| 915 |
|
| 916 |
const auto & vocab = model.vocab;
|
| 917 |
const auto & hparams = model.hparams;
|
| 918 |
|
| 919 |
const int32_t n_vocab = vocab.n_tokens();
|
|
|
|
| 920 |
|
| 921 |
-
const
|
| 922 |
-
const int64_t n_embd = hparams.n_embd;
|
| 923 |
|
| 924 |
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
| 925 |
|
| 926 |
-
|
| 927 |
-
if (batch.token) {
|
| 928 |
-
for (int64_t i = 0; i < n_tokens_all; ++i) {
|
| 929 |
-
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
| 930 |
-
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
|
| 931 |
-
return -1;
|
| 932 |
-
}
|
| 933 |
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
|
|
|
|
|
|
| 938 |
}
|
| 939 |
}
|
| 940 |
|
|
@@ -947,25 +934,9 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
| 947 |
}
|
| 948 |
n_queued_tokens += n_tokens_all;
|
| 949 |
|
| 950 |
-
// this
|
| 951 |
-
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
| 952 |
-
|
| 953 |
embd_seq.clear();
|
| 954 |
|
| 955 |
-
int64_t n_outputs_all = 0;
|
| 956 |
-
|
| 957 |
-
// count outputs
|
| 958 |
-
if (batch.logits && !embd_pooled) {
|
| 959 |
-
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
| 960 |
-
n_outputs_all += batch.logits[i] != 0;
|
| 961 |
-
}
|
| 962 |
-
} else if (embd_pooled) {
|
| 963 |
-
n_outputs_all = n_tokens_all;
|
| 964 |
-
} else {
|
| 965 |
-
// keep last output only
|
| 966 |
-
n_outputs_all = 1;
|
| 967 |
-
}
|
| 968 |
-
|
| 969 |
bool did_optimize = false;
|
| 970 |
|
| 971 |
// handle any pending defrags/shifts
|
|
@@ -974,7 +945,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
| 974 |
llama_memory_state_ptr mstate;
|
| 975 |
|
| 976 |
while (true) {
|
| 977 |
-
mstate = memory->init_batch(batch, cparams.n_ubatch,
|
| 978 |
if (!mstate) {
|
| 979 |
return -2;
|
| 980 |
}
|
|
@@ -1018,7 +989,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
| 1018 |
|
| 1019 |
// reserve output buffer
|
| 1020 |
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
| 1021 |
-
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %
|
| 1022 |
return -2;
|
| 1023 |
};
|
| 1024 |
|
|
@@ -1027,7 +998,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
| 1027 |
do {
|
| 1028 |
const auto & ubatch = mstate->get_ubatch();
|
| 1029 |
|
| 1030 |
-
// count the outputs in this
|
| 1031 |
{
|
| 1032 |
int32_t n_outputs_new = 0;
|
| 1033 |
|
|
@@ -1052,18 +1023,19 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
| 1052 |
|
| 1053 |
if (!res) {
|
| 1054 |
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
| 1055 |
-
llama_pos pos_min[
|
| 1056 |
-
for (int s = 0; s <
|
| 1057 |
pos_min[s] = std::numeric_limits<llama_pos>::max();
|
| 1058 |
}
|
| 1059 |
|
|
|
|
| 1060 |
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
| 1061 |
const auto & seq_id = ubatch.seq_id[i][0];
|
| 1062 |
|
| 1063 |
pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
|
| 1064 |
}
|
| 1065 |
|
| 1066 |
-
for (int s = 0; s <
|
| 1067 |
if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
|
| 1068 |
continue;
|
| 1069 |
}
|
|
@@ -1086,7 +1058,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
| 1086 |
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
| 1087 |
//}
|
| 1088 |
|
| 1089 |
-
auto * t_logits =
|
| 1090 |
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
| 1091 |
|
| 1092 |
if (t_embd && res->get_embd_pooled()) {
|
|
@@ -1170,14 +1142,14 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
| 1170 |
n_outputs = n_outputs_all;
|
| 1171 |
|
| 1172 |
// set output mappings
|
| 1173 |
-
{
|
| 1174 |
bool sorted_output = true;
|
| 1175 |
|
| 1176 |
auto & out_ids = mstate->out_ids();
|
| 1177 |
|
| 1178 |
-
GGML_ASSERT(out_ids.size() == (size_t)
|
| 1179 |
|
| 1180 |
-
for (int64_t i = 0; i <
|
| 1181 |
int64_t out_id = out_ids[i];
|
| 1182 |
output_ids[out_id] = i;
|
| 1183 |
if (out_id != i) {
|
|
@@ -1189,20 +1161,22 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
| 1189 |
// note: this is mostly relevant for recurrent models atm
|
| 1190 |
if (!sorted_output) {
|
| 1191 |
const uint32_t n_vocab = model.vocab.n_tokens();
|
| 1192 |
-
const
|
| 1193 |
|
| 1194 |
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
| 1195 |
|
| 1196 |
// TODO: is there something more efficient which also minimizes swaps?
|
| 1197 |
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
| 1198 |
-
for (
|
| 1199 |
-
|
| 1200 |
-
for (
|
| 1201 |
if (out_ids[j] < out_ids[j_min]) {
|
| 1202 |
j_min = j;
|
| 1203 |
}
|
| 1204 |
}
|
| 1205 |
-
if (j_min == i) {
|
|
|
|
|
|
|
| 1206 |
std::swap(out_ids[i], out_ids[j_min]);
|
| 1207 |
if (logits_size > 0) {
|
| 1208 |
for (uint32_t k = 0; k < n_vocab; k++) {
|
|
@@ -1215,8 +1189,10 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
| 1215 |
}
|
| 1216 |
}
|
| 1217 |
}
|
|
|
|
| 1218 |
std::fill(output_ids.begin(), output_ids.end(), -1);
|
| 1219 |
-
|
|
|
|
| 1220 |
output_ids[out_ids[i]] = i;
|
| 1221 |
}
|
| 1222 |
}
|
|
@@ -1236,7 +1212,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
| 1236 |
// output
|
| 1237 |
//
|
| 1238 |
|
| 1239 |
-
|
| 1240 |
const auto & hparams = model.hparams;
|
| 1241 |
const auto & vocab = model.vocab;
|
| 1242 |
|
|
@@ -1246,9 +1222,8 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
| 1246 |
const auto n_vocab = vocab.n_tokens();
|
| 1247 |
const auto n_embd = hparams.n_embd;
|
| 1248 |
|
| 1249 |
-
|
| 1250 |
-
bool
|
| 1251 |
-
bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
|
| 1252 |
|
| 1253 |
// TODO: hacky enc-dec support
|
| 1254 |
if (model.arch == LLM_ARCH_T5) {
|
|
@@ -1302,8 +1277,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
| 1302 |
// set all ids as invalid (negative)
|
| 1303 |
std::fill(output_ids.begin(), output_ids.end(), -1);
|
| 1304 |
|
| 1305 |
-
this->n_outputs
|
| 1306 |
-
this->n_outputs_max = n_outputs_max;
|
| 1307 |
|
| 1308 |
return n_outputs_max;
|
| 1309 |
}
|
|
@@ -1332,7 +1306,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
|
| 1332 |
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
| 1333 |
|
| 1334 |
if (n_tokens % n_seqs != 0) {
|
| 1335 |
-
n_tokens = (n_tokens / n_seqs) * n_seqs;
|
| 1336 |
n_outputs = std::min(n_outputs, n_tokens);
|
| 1337 |
|
| 1338 |
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
@@ -1794,14 +1768,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
|
| 1794 |
|
| 1795 |
std::vector<int32_t> w_output_pos;
|
| 1796 |
|
| 1797 |
-
GGML_ASSERT(n_outputs <= n_outputs_max);
|
| 1798 |
-
|
| 1799 |
w_output_pos.resize(n_outputs);
|
| 1800 |
|
| 1801 |
// build a more compact representation of the output ids
|
| 1802 |
for (size_t i = 0; i < n_batch(); ++i) {
|
| 1803 |
// map an output id to a position in the batch
|
| 1804 |
-
|
| 1805 |
if (pos >= 0) {
|
| 1806 |
GGML_ASSERT(pos < n_outputs);
|
| 1807 |
w_output_pos[pos] = i;
|
|
@@ -2071,14 +2043,11 @@ void llama_context::opt_epoch_iter(
|
|
| 2071 |
|
| 2072 |
n_queued_tokens += n_tokens_all;
|
| 2073 |
|
| 2074 |
-
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
| 2075 |
-
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
| 2076 |
-
|
| 2077 |
embd_seq.clear();
|
| 2078 |
|
| 2079 |
-
|
| 2080 |
|
| 2081 |
-
auto mstate = memory->init_batch(batch, cparams.n_ubatch,
|
| 2082 |
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
| 2083 |
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
|
| 2084 |
break;
|
|
@@ -2086,7 +2055,7 @@ void llama_context::opt_epoch_iter(
|
|
| 2086 |
|
| 2087 |
// reserve output buffer
|
| 2088 |
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
| 2089 |
-
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %
|
| 2090 |
GGML_ABORT("TODO: handle this error");
|
| 2091 |
};
|
| 2092 |
|
|
|
|
| 1 |
#include "llama-context.h"
|
| 2 |
|
| 3 |
#include "llama-impl.h"
|
| 4 |
+
#include "llama-batch.h"
|
| 5 |
#include "llama-io.h"
|
| 6 |
#include "llama-memory.h"
|
| 7 |
#include "llama-mmap.h"
|
|
|
|
| 19 |
llama_context::llama_context(
|
| 20 |
const llama_model & model,
|
| 21 |
llama_context_params params) :
|
| 22 |
+
model(model),
|
| 23 |
+
batch_allocr(std::make_unique<llama_batch_allocr>()) {
|
| 24 |
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
|
| 25 |
|
| 26 |
t_start_us = model.t_start_us;
|
|
|
|
| 29 |
const auto & hparams = model.hparams;
|
| 30 |
|
| 31 |
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
| 32 |
+
if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
|
| 33 |
+
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
|
| 34 |
}
|
| 35 |
|
| 36 |
cparams.n_threads = params.n_threads;
|
|
|
|
| 496 |
}
|
| 497 |
|
| 498 |
float * llama_context::get_logits_ith(int32_t i) {
|
| 499 |
+
int64_t j = -1;
|
| 500 |
|
| 501 |
try {
|
| 502 |
if (logits == nullptr) {
|
|
|
|
| 519 |
}
|
| 520 |
if (j >= n_outputs) {
|
| 521 |
// This should not happen
|
| 522 |
+
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
| 523 |
}
|
| 524 |
|
| 525 |
return logits + j*model.vocab.n_tokens();
|
|
|
|
| 538 |
}
|
| 539 |
|
| 540 |
float * llama_context::get_embeddings_ith(int32_t i) {
|
| 541 |
+
int64_t j = -1;
|
| 542 |
|
| 543 |
try {
|
| 544 |
if (embd == nullptr) {
|
|
|
|
| 561 |
}
|
| 562 |
if (j >= n_outputs) {
|
| 563 |
// This should not happen
|
| 564 |
+
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
| 565 |
}
|
| 566 |
|
| 567 |
return embd + j*model.hparams.n_embd;
|
|
|
|
| 721 |
return res;
|
| 722 |
}
|
| 723 |
|
| 724 |
+
int llama_context::encode(const llama_batch & batch_inp) {
|
| 725 |
+
if (batch_inp.n_tokens == 0) {
|
| 726 |
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
| 727 |
return -1;
|
| 728 |
}
|
| 729 |
|
|
|
|
| 730 |
// note: during encode, we always pass the full sequence starting from pos = 0
|
| 731 |
+
if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
|
| 732 |
+
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
| 733 |
+
return -1;
|
| 734 |
+
}
|
| 735 |
|
| 736 |
+
const llama_batch & batch = batch_allocr->get_batch();
|
|
|
|
| 737 |
|
| 738 |
+
const uint32_t n_tokens = batch.n_tokens;
|
| 739 |
|
| 740 |
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
| 741 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 742 |
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
|
| 743 |
+
GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
|
| 744 |
|
| 745 |
if (t_compute_start_us == 0) {
|
| 746 |
t_compute_start_us = ggml_time_us();
|
| 747 |
}
|
| 748 |
|
| 749 |
+
// TODO: this clear of the buffer can easily be forgotten - need something better
|
| 750 |
embd_seq.clear();
|
| 751 |
|
| 752 |
n_queued_tokens += n_tokens;
|
| 753 |
|
| 754 |
+
const auto & hparams = model.hparams;
|
| 755 |
+
|
| 756 |
const int64_t n_embd = hparams.n_embd;
|
| 757 |
|
| 758 |
+
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
|
| 759 |
|
| 760 |
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
|
| 761 |
|
|
|
|
| 765 |
return -2;
|
| 766 |
};
|
| 767 |
|
| 768 |
+
for (uint32_t i = 0; i < n_tokens; ++i) {
|
| 769 |
output_ids[i] = i;
|
| 770 |
}
|
| 771 |
|
|
|
|
| 821 |
|
| 822 |
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
|
| 823 |
|
| 824 |
+
// TODO: fix indexing [UBATCH_IDX]
|
| 825 |
+
for (uint32_t i = 0; i < n_tokens; i++) {
|
| 826 |
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
| 827 |
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
| 828 |
continue;
|
|
|
|
| 837 |
auto & embd_seq_out = embd_seq;
|
| 838 |
const uint32_t n_cls_out = hparams.n_cls_out;
|
| 839 |
|
| 840 |
+
// TODO: fix indexing [UBATCH_IDX]
|
| 841 |
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
| 842 |
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
| 843 |
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
|
|
|
| 871 |
|
| 872 |
// remember the sequence ids used during the encoding - needed for cross attention later
|
| 873 |
cross.seq_ids_enc.resize(n_tokens);
|
| 874 |
+
for (uint32_t i = 0; i < n_tokens; i++) {
|
| 875 |
cross.seq_ids_enc[i].clear();
|
| 876 |
+
for (int s = 0; s < batch.n_seq_id[i]; s++) {
|
| 877 |
+
llama_seq_id seq_id = batch.seq_id[i][s];
|
| 878 |
cross.seq_ids_enc[i].insert(seq_id);
|
| 879 |
}
|
| 880 |
}
|
|
|
|
| 883 |
return 0;
|
| 884 |
}
|
| 885 |
|
| 886 |
+
int llama_context::decode(const llama_batch & batch_inp) {
|
| 887 |
if (!memory) {
|
| 888 |
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
|
| 889 |
+
return encode(batch_inp);
|
| 890 |
}
|
| 891 |
|
| 892 |
+
if (batch_inp.n_tokens == 0) {
|
| 893 |
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
| 894 |
return -1;
|
| 895 |
}
|
| 896 |
|
| 897 |
+
// when computing embeddings, all tokens are output
|
| 898 |
+
const bool embd_all = cparams.embeddings;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 899 |
|
| 900 |
+
if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
|
| 901 |
+
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
| 902 |
+
return -1;
|
| 903 |
+
}
|
| 904 |
|
| 905 |
+
const llama_batch & batch = batch_allocr->get_batch();
|
| 906 |
|
| 907 |
const auto & vocab = model.vocab;
|
| 908 |
const auto & hparams = model.hparams;
|
| 909 |
|
| 910 |
const int32_t n_vocab = vocab.n_tokens();
|
| 911 |
+
const int64_t n_embd = hparams.n_embd;
|
| 912 |
|
| 913 |
+
const uint32_t n_tokens_all = batch.n_tokens;
|
|
|
|
| 914 |
|
| 915 |
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
| 916 |
|
| 917 |
+
const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 918 |
|
| 919 |
+
if (embd_all) {
|
| 920 |
+
// require that all tokens are output
|
| 921 |
+
if (n_outputs_all != n_tokens_all) {
|
| 922 |
+
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
|
| 923 |
+
__func__, n_outputs_all, n_tokens_all);
|
| 924 |
+
return -1;
|
| 925 |
}
|
| 926 |
}
|
| 927 |
|
|
|
|
| 934 |
}
|
| 935 |
n_queued_tokens += n_tokens_all;
|
| 936 |
|
| 937 |
+
// TODO: this clear of the buffer can easily be forgotten - need something better
|
|
|
|
|
|
|
| 938 |
embd_seq.clear();
|
| 939 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 940 |
bool did_optimize = false;
|
| 941 |
|
| 942 |
// handle any pending defrags/shifts
|
|
|
|
| 945 |
llama_memory_state_ptr mstate;
|
| 946 |
|
| 947 |
while (true) {
|
| 948 |
+
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
|
| 949 |
if (!mstate) {
|
| 950 |
return -2;
|
| 951 |
}
|
|
|
|
| 989 |
|
| 990 |
// reserve output buffer
|
| 991 |
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
| 992 |
+
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
|
| 993 |
return -2;
|
| 994 |
};
|
| 995 |
|
|
|
|
| 998 |
do {
|
| 999 |
const auto & ubatch = mstate->get_ubatch();
|
| 1000 |
|
| 1001 |
+
// count the outputs in this ubatch
|
| 1002 |
{
|
| 1003 |
int32_t n_outputs_new = 0;
|
| 1004 |
|
|
|
|
| 1023 |
|
| 1024 |
if (!res) {
|
| 1025 |
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
| 1026 |
+
llama_pos pos_min[LLAMA_MAX_SEQ];
|
| 1027 |
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 1028 |
pos_min[s] = std::numeric_limits<llama_pos>::max();
|
| 1029 |
}
|
| 1030 |
|
| 1031 |
+
// TODO: fix sequence indexing
|
| 1032 |
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
| 1033 |
const auto & seq_id = ubatch.seq_id[i][0];
|
| 1034 |
|
| 1035 |
pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
|
| 1036 |
}
|
| 1037 |
|
| 1038 |
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 1039 |
if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
|
| 1040 |
continue;
|
| 1041 |
}
|
|
|
|
| 1058 |
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
| 1059 |
//}
|
| 1060 |
|
| 1061 |
+
auto * t_logits = res->get_logits();
|
| 1062 |
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
| 1063 |
|
| 1064 |
if (t_embd && res->get_embd_pooled()) {
|
|
|
|
| 1142 |
n_outputs = n_outputs_all;
|
| 1143 |
|
| 1144 |
// set output mappings
|
| 1145 |
+
if (n_outputs > 0) {
|
| 1146 |
bool sorted_output = true;
|
| 1147 |
|
| 1148 |
auto & out_ids = mstate->out_ids();
|
| 1149 |
|
| 1150 |
+
GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
|
| 1151 |
|
| 1152 |
+
for (int64_t i = 0; i < n_outputs; ++i) {
|
| 1153 |
int64_t out_id = out_ids[i];
|
| 1154 |
output_ids[out_id] = i;
|
| 1155 |
if (out_id != i) {
|
|
|
|
| 1161 |
// note: this is mostly relevant for recurrent models atm
|
| 1162 |
if (!sorted_output) {
|
| 1163 |
const uint32_t n_vocab = model.vocab.n_tokens();
|
| 1164 |
+
const uint64_t n_embd = model.hparams.n_embd;
|
| 1165 |
|
| 1166 |
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
| 1167 |
|
| 1168 |
// TODO: is there something more efficient which also minimizes swaps?
|
| 1169 |
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
| 1170 |
+
for (uint32_t i = 0; i < n_outputs - 1; ++i) {
|
| 1171 |
+
uint32_t j_min = i;
|
| 1172 |
+
for (uint32_t j = i + 1; j < n_outputs; ++j) {
|
| 1173 |
if (out_ids[j] < out_ids[j_min]) {
|
| 1174 |
j_min = j;
|
| 1175 |
}
|
| 1176 |
}
|
| 1177 |
+
if (j_min == i) {
|
| 1178 |
+
continue;
|
| 1179 |
+
}
|
| 1180 |
std::swap(out_ids[i], out_ids[j_min]);
|
| 1181 |
if (logits_size > 0) {
|
| 1182 |
for (uint32_t k = 0; k < n_vocab; k++) {
|
|
|
|
| 1189 |
}
|
| 1190 |
}
|
| 1191 |
}
|
| 1192 |
+
|
| 1193 |
std::fill(output_ids.begin(), output_ids.end(), -1);
|
| 1194 |
+
|
| 1195 |
+
for (uint32_t i = 0; i < n_outputs; ++i) {
|
| 1196 |
output_ids[out_ids[i]] = i;
|
| 1197 |
}
|
| 1198 |
}
|
|
|
|
| 1212 |
// output
|
| 1213 |
//
|
| 1214 |
|
| 1215 |
+
uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
| 1216 |
const auto & hparams = model.hparams;
|
| 1217 |
const auto & vocab = model.vocab;
|
| 1218 |
|
|
|
|
| 1222 |
const auto n_vocab = vocab.n_tokens();
|
| 1223 |
const auto n_embd = hparams.n_embd;
|
| 1224 |
|
| 1225 |
+
bool has_logits = true;
|
| 1226 |
+
bool has_embd = cparams.embeddings;
|
|
|
|
| 1227 |
|
| 1228 |
// TODO: hacky enc-dec support
|
| 1229 |
if (model.arch == LLM_ARCH_T5) {
|
|
|
|
| 1277 |
// set all ids as invalid (negative)
|
| 1278 |
std::fill(output_ids.begin(), output_ids.end(), -1);
|
| 1279 |
|
| 1280 |
+
this->n_outputs = 0;
|
|
|
|
| 1281 |
|
| 1282 |
return n_outputs_max;
|
| 1283 |
}
|
|
|
|
| 1306 |
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
| 1307 |
|
| 1308 |
if (n_tokens % n_seqs != 0) {
|
| 1309 |
+
n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
|
| 1310 |
n_outputs = std::min(n_outputs, n_tokens);
|
| 1311 |
|
| 1312 |
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
|
|
| 1768 |
|
| 1769 |
std::vector<int32_t> w_output_pos;
|
| 1770 |
|
|
|
|
|
|
|
| 1771 |
w_output_pos.resize(n_outputs);
|
| 1772 |
|
| 1773 |
// build a more compact representation of the output ids
|
| 1774 |
for (size_t i = 0; i < n_batch(); ++i) {
|
| 1775 |
// map an output id to a position in the batch
|
| 1776 |
+
int64_t pos = output_ids[i];
|
| 1777 |
if (pos >= 0) {
|
| 1778 |
GGML_ASSERT(pos < n_outputs);
|
| 1779 |
w_output_pos[pos] = i;
|
|
|
|
| 2043 |
|
| 2044 |
n_queued_tokens += n_tokens_all;
|
| 2045 |
|
|
|
|
|
|
|
|
|
|
| 2046 |
embd_seq.clear();
|
| 2047 |
|
| 2048 |
+
uint32_t n_outputs_all = n_tokens_all;
|
| 2049 |
|
| 2050 |
+
auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
|
| 2051 |
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
| 2052 |
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
|
| 2053 |
break;
|
|
|
|
| 2055 |
|
| 2056 |
// reserve output buffer
|
| 2057 |
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
| 2058 |
+
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
|
| 2059 |
GGML_ABORT("TODO: handle this error");
|
| 2060 |
};
|
| 2061 |
|
examples/talk-llama/llama-context.h
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
#pragma once
|
| 2 |
|
| 3 |
#include "llama.h"
|
| 4 |
-
#include "llama-batch.h"
|
| 5 |
#include "llama-cparams.h"
|
| 6 |
#include "llama-graph.h"
|
| 7 |
#include "llama-adapter.h"
|
|
@@ -13,6 +12,7 @@
|
|
| 13 |
#include <vector>
|
| 14 |
|
| 15 |
struct llama_model;
|
|
|
|
| 16 |
|
| 17 |
class llama_io_read_i;
|
| 18 |
class llama_io_write_i;
|
|
@@ -102,8 +102,8 @@ struct llama_context {
|
|
| 102 |
llama_memory_state_i * mstate,
|
| 103 |
ggml_status & ret);
|
| 104 |
|
| 105 |
-
int encode(llama_batch &
|
| 106 |
-
int decode(llama_batch &
|
| 107 |
|
| 108 |
//
|
| 109 |
// state save/load
|
|
@@ -181,7 +181,7 @@ private:
|
|
| 181 |
|
| 182 |
// Make sure enough space is available for outputs.
|
| 183 |
// Returns max number of outputs for which space was reserved.
|
| 184 |
-
|
| 185 |
|
| 186 |
//
|
| 187 |
// graph
|
|
@@ -246,8 +246,10 @@ private:
|
|
| 246 |
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
|
| 247 |
std::map<llama_seq_id, std::vector<float>> embd_seq;
|
| 248 |
|
| 249 |
-
|
| 250 |
-
|
|
|
|
|
|
|
| 251 |
|
| 252 |
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
|
| 253 |
|
|
|
|
| 1 |
#pragma once
|
| 2 |
|
| 3 |
#include "llama.h"
|
|
|
|
| 4 |
#include "llama-cparams.h"
|
| 5 |
#include "llama-graph.h"
|
| 6 |
#include "llama-adapter.h"
|
|
|
|
| 12 |
#include <vector>
|
| 13 |
|
| 14 |
struct llama_model;
|
| 15 |
+
class llama_batch_allocr;
|
| 16 |
|
| 17 |
class llama_io_read_i;
|
| 18 |
class llama_io_write_i;
|
|
|
|
| 102 |
llama_memory_state_i * mstate,
|
| 103 |
ggml_status & ret);
|
| 104 |
|
| 105 |
+
int encode(const llama_batch & batch_inp);
|
| 106 |
+
int decode(const llama_batch & batch_inp);
|
| 107 |
|
| 108 |
//
|
| 109 |
// state save/load
|
|
|
|
| 181 |
|
| 182 |
// Make sure enough space is available for outputs.
|
| 183 |
// Returns max number of outputs for which space was reserved.
|
| 184 |
+
uint32_t output_reserve(int32_t n_outputs);
|
| 185 |
|
| 186 |
//
|
| 187 |
// graph
|
|
|
|
| 246 |
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
|
| 247 |
std::map<llama_seq_id, std::vector<float>> embd_seq;
|
| 248 |
|
| 249 |
+
// reuse the batch_allocr to avoid unnecessary memory allocations
|
| 250 |
+
std::unique_ptr<llama_batch_allocr> batch_allocr;
|
| 251 |
+
|
| 252 |
+
uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
|
| 253 |
|
| 254 |
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
|
| 255 |
|
examples/talk-llama/llama-cparams.cpp
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
#include "llama-cparams.h"
|
| 2 |
|
| 3 |
size_t llama_max_parallel_sequences(void) {
|
| 4 |
-
return
|
| 5 |
}
|
|
|
|
| 1 |
#include "llama-cparams.h"
|
| 2 |
|
| 3 |
size_t llama_max_parallel_sequences(void) {
|
| 4 |
+
return LLAMA_MAX_SEQ;
|
| 5 |
}
|
examples/talk-llama/llama-cparams.h
CHANGED
|
@@ -4,7 +4,7 @@
|
|
| 4 |
|
| 5 |
#include <cstdint>
|
| 6 |
|
| 7 |
-
#define
|
| 8 |
|
| 9 |
struct llama_cparams {
|
| 10 |
uint32_t n_ctx; // context size used during inference
|
|
|
|
| 4 |
|
| 5 |
#include <cstdint>
|
| 6 |
|
| 7 |
+
#define LLAMA_MAX_SEQ 64
|
| 8 |
|
| 9 |
struct llama_cparams {
|
| 10 |
uint32_t n_ctx; // context size used during inference
|
examples/talk-llama/llama-graph.cpp
CHANGED
|
@@ -139,6 +139,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
|
|
| 139 |
|
| 140 |
std::vector<uint64_t> sum(n_tokens, 0);
|
| 141 |
|
|
|
|
| 142 |
for (int s = 0; s < n_seqs; ++s) {
|
| 143 |
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 144 |
|
|
@@ -156,6 +157,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
|
|
| 156 |
}
|
| 157 |
}
|
| 158 |
|
|
|
|
| 159 |
for (int s = 0; s < n_seqs; ++s) {
|
| 160 |
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 161 |
|
|
@@ -180,6 +182,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|
| 180 |
uint32_t * data = (uint32_t *) cls->data;
|
| 181 |
memset(cls->data, 0, n_tokens * ggml_element_size(cls));
|
| 182 |
|
|
|
|
| 183 |
for (int s = 0; s < n_seqs; ++s) {
|
| 184 |
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 185 |
|
|
@@ -210,6 +213,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|
| 210 |
std::vector<int> last_pos(n_tokens, -1);
|
| 211 |
std::vector<int> last_row(n_tokens, -1);
|
| 212 |
|
|
|
|
| 213 |
for (int s = 0; s < n_seqs; ++s) {
|
| 214 |
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 215 |
|
|
@@ -250,22 +254,6 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
|
| 250 |
}
|
| 251 |
}
|
| 252 |
|
| 253 |
-
void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
|
| 254 |
-
GGML_UNUSED(ubatch);
|
| 255 |
-
|
| 256 |
-
const int64_t n_kv = kv_state->get_n_kv();
|
| 257 |
-
|
| 258 |
-
if (s_mask) {
|
| 259 |
-
GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
|
| 260 |
-
float * data = (float *) s_mask->data;
|
| 261 |
-
|
| 262 |
-
// clear unused states
|
| 263 |
-
for (int i = 0; i < n_kv; ++i) {
|
| 264 |
-
data[i] = kv_state->s_mask(i);
|
| 265 |
-
}
|
| 266 |
-
}
|
| 267 |
-
}
|
| 268 |
-
|
| 269 |
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
| 270 |
GGML_UNUSED(ubatch);
|
| 271 |
|
|
@@ -299,6 +287,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|
| 299 |
const int32_t ti = s0*n_seq_tokens + i;
|
| 300 |
float f = -INFINITY;
|
| 301 |
|
|
|
|
| 302 |
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
| 303 |
if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
|
| 304 |
if (hparams.use_alibi) {
|
|
@@ -338,6 +327,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|
| 338 |
const int32_t ti = s0*n_seq_tokens + i;
|
| 339 |
float f = -INFINITY;
|
| 340 |
|
|
|
|
| 341 |
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
| 342 |
if (ubatch->seq_id[s0][s] == seq_id) {
|
| 343 |
if (hparams.use_alibi) {
|
|
@@ -393,6 +383,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|
| 393 |
for (int j = 0; j < n_tokens; ++j) {
|
| 394 |
for (int i = 0; i < n_enc; ++i) {
|
| 395 |
float f = -INFINITY;
|
|
|
|
| 396 |
for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
|
| 397 |
const llama_seq_id seq_id = ubatch->seq_id[j][s];
|
| 398 |
if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
|
|
@@ -650,6 +641,7 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
| 650 |
{
|
| 651 |
// Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
|
| 652 |
int64_t split_point = cur->ne[0] / 2;
|
|
|
|
| 653 |
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
| 654 |
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
| 655 |
|
|
@@ -663,7 +655,7 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
| 663 |
{
|
| 664 |
// Split into two equal parts
|
| 665 |
int64_t split_point = cur->ne[0] / 2;
|
| 666 |
-
// TODO: these conts should not be needed
|
| 667 |
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
| 668 |
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
| 669 |
|
|
@@ -986,23 +978,6 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
|
|
| 986 |
return cur;
|
| 987 |
}
|
| 988 |
|
| 989 |
-
ggml_tensor * llm_graph_context::build_inp_s_mask() const {
|
| 990 |
-
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
| 991 |
-
|
| 992 |
-
auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
|
| 993 |
-
|
| 994 |
-
const auto n_kv = kv_state->get_n_kv();
|
| 995 |
-
|
| 996 |
-
auto & cur = inp->s_mask;
|
| 997 |
-
|
| 998 |
-
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
|
| 999 |
-
ggml_set_input(cur);
|
| 1000 |
-
|
| 1001 |
-
res->add_input(std::move(inp));
|
| 1002 |
-
|
| 1003 |
-
return cur;
|
| 1004 |
-
}
|
| 1005 |
-
|
| 1006 |
ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
|
| 1007 |
auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
|
| 1008 |
|
|
@@ -1455,43 +1430,53 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
| 1455 |
return cur;
|
| 1456 |
}
|
| 1457 |
|
| 1458 |
-
ggml_tensor * llm_graph_context::
|
| 1459 |
ggml_cgraph * gf,
|
| 1460 |
ggml_tensor * s,
|
| 1461 |
ggml_tensor * state_copy,
|
| 1462 |
-
|
| 1463 |
-
int32_t
|
| 1464 |
-
|
| 1465 |
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
| 1466 |
|
| 1467 |
const auto n_kv = kv_state->get_n_kv();
|
| 1468 |
const auto kv_head = kv_state->get_head();
|
|
|
|
|
|
|
|
|
|
| 1469 |
|
| 1470 |
-
|
|
|
|
|
|
|
|
|
|
| 1471 |
|
| 1472 |
-
|
| 1473 |
-
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
|
| 1474 |
-
// this shrinks the tensors's ne[1] to n_kv
|
| 1475 |
-
states = ggml_get_rows(ctx0, states, state_copy);
|
| 1476 |
|
| 1477 |
-
|
| 1478 |
-
|
| 1479 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1480 |
|
| 1481 |
-
// copy states which won't be changed further (between n_seqs and n_kv)
|
|
|
|
| 1482 |
ggml_build_forward_expand(gf,
|
| 1483 |
ggml_cpy(ctx0,
|
| 1484 |
-
|
| 1485 |
-
ggml_view_1d(ctx0, s,
|
| 1486 |
|
| 1487 |
-
|
| 1488 |
-
return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
|
| 1489 |
}
|
| 1490 |
|
| 1491 |
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
| 1492 |
ggml_cgraph * gf,
|
| 1493 |
ggml_tensor * state_copy,
|
| 1494 |
-
ggml_tensor * state_mask,
|
| 1495 |
const llama_ubatch & ubatch,
|
| 1496 |
int il) const {
|
| 1497 |
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
|
@@ -1502,8 +1487,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
|
| 1502 |
|
| 1503 |
ggml_tensor * token_shift_all = kv_state->get_k_l(il);
|
| 1504 |
|
| 1505 |
-
ggml_tensor * token_shift =
|
| 1506 |
-
gf, token_shift_all, state_copy,
|
| 1507 |
hparams.n_embd_k_s(), n_seqs);
|
| 1508 |
|
| 1509 |
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
|
|
@@ -1578,23 +1563,30 @@ void llm_graph_context::build_pooling(
|
|
| 1578 |
ggml_tensor * inp_cls = build_inp_cls();
|
| 1579 |
inp = ggml_get_rows(ctx0, inp, inp_cls);
|
| 1580 |
|
| 1581 |
-
if (cls
|
| 1582 |
// classification head
|
| 1583 |
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
| 1584 |
-
cur =
|
|
|
|
|
|
|
|
|
|
| 1585 |
cur = ggml_tanh(ctx0, cur);
|
| 1586 |
|
| 1587 |
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
| 1588 |
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
|
| 1589 |
if (cls_out) {
|
| 1590 |
-
|
| 1591 |
-
|
|
|
|
|
|
|
| 1592 |
}
|
| 1593 |
} else if (cls_out) {
|
| 1594 |
// Single layer classification head (direct projection)
|
| 1595 |
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
|
| 1596 |
-
|
| 1597 |
-
|
|
|
|
|
|
|
| 1598 |
} else {
|
| 1599 |
GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
|
| 1600 |
}
|
|
|
|
| 139 |
|
| 140 |
std::vector<uint64_t> sum(n_tokens, 0);
|
| 141 |
|
| 142 |
+
// TODO: fix indexing [UBATCH_IDX]
|
| 143 |
for (int s = 0; s < n_seqs; ++s) {
|
| 144 |
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 145 |
|
|
|
|
| 157 |
}
|
| 158 |
}
|
| 159 |
|
| 160 |
+
// TODO: fix indexing [UBATCH_IDX]
|
| 161 |
for (int s = 0; s < n_seqs; ++s) {
|
| 162 |
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 163 |
|
|
|
|
| 182 |
uint32_t * data = (uint32_t *) cls->data;
|
| 183 |
memset(cls->data, 0, n_tokens * ggml_element_size(cls));
|
| 184 |
|
| 185 |
+
// TODO: fix indexing [UBATCH_IDX]
|
| 186 |
for (int s = 0; s < n_seqs; ++s) {
|
| 187 |
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 188 |
|
|
|
|
| 213 |
std::vector<int> last_pos(n_tokens, -1);
|
| 214 |
std::vector<int> last_row(n_tokens, -1);
|
| 215 |
|
| 216 |
+
// TODO: fix indexing [UBATCH_IDX]
|
| 217 |
for (int s = 0; s < n_seqs; ++s) {
|
| 218 |
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 219 |
|
|
|
|
| 254 |
}
|
| 255 |
}
|
| 256 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
| 258 |
GGML_UNUSED(ubatch);
|
| 259 |
|
|
|
|
| 287 |
const int32_t ti = s0*n_seq_tokens + i;
|
| 288 |
float f = -INFINITY;
|
| 289 |
|
| 290 |
+
// TODO: fix indexing [UBATCH_IDX]
|
| 291 |
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
| 292 |
if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
|
| 293 |
if (hparams.use_alibi) {
|
|
|
|
| 327 |
const int32_t ti = s0*n_seq_tokens + i;
|
| 328 |
float f = -INFINITY;
|
| 329 |
|
| 330 |
+
// TODO: fix indexing [UBATCH_IDX]
|
| 331 |
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
| 332 |
if (ubatch->seq_id[s0][s] == seq_id) {
|
| 333 |
if (hparams.use_alibi) {
|
|
|
|
| 383 |
for (int j = 0; j < n_tokens; ++j) {
|
| 384 |
for (int i = 0; i < n_enc; ++i) {
|
| 385 |
float f = -INFINITY;
|
| 386 |
+
// TODO: fix indexing [UBATCH_IDX]
|
| 387 |
for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
|
| 388 |
const llama_seq_id seq_id = ubatch->seq_id[j][s];
|
| 389 |
if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
|
|
|
|
| 641 |
{
|
| 642 |
// Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
|
| 643 |
int64_t split_point = cur->ne[0] / 2;
|
| 644 |
+
// TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
|
| 645 |
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
| 646 |
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
| 647 |
|
|
|
|
| 655 |
{
|
| 656 |
// Split into two equal parts
|
| 657 |
int64_t split_point = cur->ne[0] / 2;
|
| 658 |
+
// TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
|
| 659 |
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
| 660 |
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
| 661 |
|
|
|
|
| 978 |
return cur;
|
| 979 |
}
|
| 980 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 981 |
ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
|
| 982 |
auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
|
| 983 |
|
|
|
|
| 1430 |
return cur;
|
| 1431 |
}
|
| 1432 |
|
| 1433 |
+
ggml_tensor * llm_graph_context::build_recurrent_state(
|
| 1434 |
ggml_cgraph * gf,
|
| 1435 |
ggml_tensor * s,
|
| 1436 |
ggml_tensor * state_copy,
|
| 1437 |
+
int32_t state_size,
|
| 1438 |
+
int32_t n_seqs,
|
| 1439 |
+
bool avoid_copies) const {
|
| 1440 |
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
| 1441 |
|
| 1442 |
const auto n_kv = kv_state->get_n_kv();
|
| 1443 |
const auto kv_head = kv_state->get_head();
|
| 1444 |
+
const auto rs_zero = kv_state->get_rs_z();
|
| 1445 |
+
|
| 1446 |
+
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
|
| 1447 |
|
| 1448 |
+
// Clear a single state which will then be copied to the other cleared states.
|
| 1449 |
+
// Note that this is a no-op when the view is zero-sized.
|
| 1450 |
+
ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
|
| 1451 |
+
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
|
| 1452 |
|
| 1453 |
+
ggml_tensor * output_states;
|
|
|
|
|
|
|
|
|
|
| 1454 |
|
| 1455 |
+
if (!avoid_copies) {
|
| 1456 |
+
// copy states
|
| 1457 |
+
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
|
| 1458 |
+
// {state_size, kv_size} -> {state_size, n_seqs}
|
| 1459 |
+
output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
|
| 1460 |
+
ggml_build_forward_expand(gf, output_states);
|
| 1461 |
+
} else {
|
| 1462 |
+
// FIXME: make the gathering operation happen before the copy below
|
| 1463 |
+
// (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
|
| 1464 |
+
output_states = states;
|
| 1465 |
+
}
|
| 1466 |
|
| 1467 |
+
// copy extra states which won't be changed further (between n_seqs and n_kv)
|
| 1468 |
+
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
|
| 1469 |
ggml_build_forward_expand(gf,
|
| 1470 |
ggml_cpy(ctx0,
|
| 1471 |
+
states_extra,
|
| 1472 |
+
ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
|
| 1473 |
|
| 1474 |
+
return output_states;
|
|
|
|
| 1475 |
}
|
| 1476 |
|
| 1477 |
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
| 1478 |
ggml_cgraph * gf,
|
| 1479 |
ggml_tensor * state_copy,
|
|
|
|
| 1480 |
const llama_ubatch & ubatch,
|
| 1481 |
int il) const {
|
| 1482 |
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
|
|
|
| 1487 |
|
| 1488 |
ggml_tensor * token_shift_all = kv_state->get_k_l(il);
|
| 1489 |
|
| 1490 |
+
ggml_tensor * token_shift = build_recurrent_state(
|
| 1491 |
+
gf, token_shift_all, state_copy,
|
| 1492 |
hparams.n_embd_k_s(), n_seqs);
|
| 1493 |
|
| 1494 |
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
|
|
|
|
| 1563 |
ggml_tensor * inp_cls = build_inp_cls();
|
| 1564 |
inp = ggml_get_rows(ctx0, inp, inp_cls);
|
| 1565 |
|
| 1566 |
+
if (cls) {
|
| 1567 |
// classification head
|
| 1568 |
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
| 1569 |
+
cur = ggml_mul_mat(ctx0, cls, inp);
|
| 1570 |
+
if (cls_b) {
|
| 1571 |
+
cur = ggml_add(ctx0, cur, cls_b);
|
| 1572 |
+
}
|
| 1573 |
cur = ggml_tanh(ctx0, cur);
|
| 1574 |
|
| 1575 |
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
| 1576 |
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
|
| 1577 |
if (cls_out) {
|
| 1578 |
+
cur = ggml_mul_mat(ctx0, cls_out, cur);
|
| 1579 |
+
if (cls_out_b) {
|
| 1580 |
+
cur = ggml_add(ctx0, cur, cls_out_b);
|
| 1581 |
+
}
|
| 1582 |
}
|
| 1583 |
} else if (cls_out) {
|
| 1584 |
// Single layer classification head (direct projection)
|
| 1585 |
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
|
| 1586 |
+
cur = ggml_mul_mat(ctx0, cls_out, inp);
|
| 1587 |
+
if (cls_out_b) {
|
| 1588 |
+
cur = ggml_add(ctx0, cur, cls_out_b);
|
| 1589 |
+
}
|
| 1590 |
} else {
|
| 1591 |
GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
|
| 1592 |
}
|
examples/talk-llama/llama-graph.h
CHANGED
|
@@ -200,18 +200,6 @@ public:
|
|
| 200 |
const llama_kv_cache_recurrent_state * kv_state;
|
| 201 |
};
|
| 202 |
|
| 203 |
-
class llm_graph_input_s_mask : public llm_graph_input_i {
|
| 204 |
-
public:
|
| 205 |
-
llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
|
| 206 |
-
virtual ~llm_graph_input_s_mask() = default;
|
| 207 |
-
|
| 208 |
-
void set_input(const llama_ubatch * ubatch) override;
|
| 209 |
-
|
| 210 |
-
ggml_tensor * s_mask; // F32 [1, n_kv]
|
| 211 |
-
|
| 212 |
-
const llama_kv_cache_recurrent_state * kv_state;
|
| 213 |
-
};
|
| 214 |
-
|
| 215 |
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
| 216 |
public:
|
| 217 |
llm_graph_input_cross_embd(
|
|
@@ -390,7 +378,7 @@ struct llm_graph_params {
|
|
| 390 |
const llama_memory_state_i * mstate;
|
| 391 |
const llama_cross * cross;
|
| 392 |
|
| 393 |
-
|
| 394 |
|
| 395 |
const llm_graph_cb & cb;
|
| 396 |
};
|
|
@@ -424,8 +412,8 @@ struct llm_graph_context {
|
|
| 424 |
const float norm_eps;
|
| 425 |
const float norm_rms_eps;
|
| 426 |
|
| 427 |
-
const
|
| 428 |
-
const
|
| 429 |
const int32_t n_ctx_orig; // yarn
|
| 430 |
|
| 431 |
const enum llama_pooling_type pooling_type;
|
|
@@ -521,7 +509,6 @@ struct llm_graph_context {
|
|
| 521 |
ggml_tensor * build_inp_mean() const;
|
| 522 |
ggml_tensor * build_inp_cls() const;
|
| 523 |
ggml_tensor * build_inp_s_copy() const;
|
| 524 |
-
ggml_tensor * build_inp_s_mask() const;
|
| 525 |
|
| 526 |
ggml_tensor * build_inp_cross_embd() const;
|
| 527 |
ggml_tensor * build_inp_pos_bucket_enc() const;
|
|
@@ -606,18 +593,17 @@ struct llm_graph_context {
|
|
| 606 |
// recurrent
|
| 607 |
//
|
| 608 |
|
| 609 |
-
ggml_tensor *
|
| 610 |
ggml_cgraph * gf,
|
| 611 |
ggml_tensor * s,
|
| 612 |
ggml_tensor * state_copy,
|
| 613 |
-
|
| 614 |
-
int32_t
|
| 615 |
-
|
| 616 |
|
| 617 |
ggml_tensor * build_rwkv_token_shift_load(
|
| 618 |
ggml_cgraph * gf,
|
| 619 |
ggml_tensor * state_copy,
|
| 620 |
-
ggml_tensor * state_mask,
|
| 621 |
const llama_ubatch & ubatch,
|
| 622 |
int il) const;
|
| 623 |
|
|
|
|
| 200 |
const llama_kv_cache_recurrent_state * kv_state;
|
| 201 |
};
|
| 202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
| 204 |
public:
|
| 205 |
llm_graph_input_cross_embd(
|
|
|
|
| 378 |
const llama_memory_state_i * mstate;
|
| 379 |
const llama_cross * cross;
|
| 380 |
|
| 381 |
+
uint32_t n_outputs;
|
| 382 |
|
| 383 |
const llm_graph_cb & cb;
|
| 384 |
};
|
|
|
|
| 412 |
const float norm_eps;
|
| 413 |
const float norm_rms_eps;
|
| 414 |
|
| 415 |
+
const int64_t n_tokens;
|
| 416 |
+
const int64_t n_outputs;
|
| 417 |
const int32_t n_ctx_orig; // yarn
|
| 418 |
|
| 419 |
const enum llama_pooling_type pooling_type;
|
|
|
|
| 509 |
ggml_tensor * build_inp_mean() const;
|
| 510 |
ggml_tensor * build_inp_cls() const;
|
| 511 |
ggml_tensor * build_inp_s_copy() const;
|
|
|
|
| 512 |
|
| 513 |
ggml_tensor * build_inp_cross_embd() const;
|
| 514 |
ggml_tensor * build_inp_pos_bucket_enc() const;
|
|
|
|
| 593 |
// recurrent
|
| 594 |
//
|
| 595 |
|
| 596 |
+
ggml_tensor * build_recurrent_state(
|
| 597 |
ggml_cgraph * gf,
|
| 598 |
ggml_tensor * s,
|
| 599 |
ggml_tensor * state_copy,
|
| 600 |
+
int32_t state_size,
|
| 601 |
+
int32_t n_seqs,
|
| 602 |
+
bool avoid_copies = false) const;
|
| 603 |
|
| 604 |
ggml_tensor * build_rwkv_token_shift_load(
|
| 605 |
ggml_cgraph * gf,
|
| 606 |
ggml_tensor * state_copy,
|
|
|
|
| 607 |
const llama_ubatch & ubatch,
|
| 608 |
int il) const;
|
| 609 |
|
examples/talk-llama/llama-kv-cache-recurrent.cpp
CHANGED
|
@@ -359,18 +359,16 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
|
| 359 |
return result;
|
| 360 |
}
|
| 361 |
|
| 362 |
-
llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
|
| 366 |
|
| 367 |
std::vector<llama_ubatch> ubatches;
|
| 368 |
|
| 369 |
while (sbatch.n_tokens > 0) {
|
| 370 |
llama_ubatch ubatch;
|
| 371 |
|
| 372 |
-
if (
|
| 373 |
-
//
|
| 374 |
ubatch = sbatch.split_seq(n_ubatch);
|
| 375 |
} else {
|
| 376 |
ubatch = sbatch.split_equal(n_ubatch);
|
|
@@ -406,21 +404,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
|
|
| 406 |
|
| 407 |
bool success = true;
|
| 408 |
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
// recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
|
| 416 |
-
//
|
| 417 |
-
GGML_UNUSED(ubatches);
|
| 418 |
-
//for (const auto & ubatch : ubatches) {
|
| 419 |
-
// if (!find_slot(ubatch)) {
|
| 420 |
-
// success = false;
|
| 421 |
-
// break;
|
| 422 |
-
// }
|
| 423 |
-
//}
|
| 424 |
|
| 425 |
// restore the original state
|
| 426 |
cells = std::move(org_cells);
|
|
@@ -431,14 +420,13 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
|
|
| 431 |
}
|
| 432 |
|
| 433 |
bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
| 434 |
-
const uint32_t
|
| 435 |
-
const uint32_t n_seqs = ubatch.n_seqs;
|
| 436 |
|
| 437 |
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
| 438 |
|
| 439 |
// if we have enough unused cells before the current head ->
|
| 440 |
// better to start searching from the beginning of the cache, hoping to fill it
|
| 441 |
-
if (head > used + 2*
|
| 442 |
head = 0;
|
| 443 |
}
|
| 444 |
|
|
@@ -534,16 +522,16 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
| 534 |
empty_cell.src = orig_cell.src;
|
| 535 |
orig_cell.seq_id.erase(seq_id);
|
| 536 |
empty_cell.seq_id.insert(seq_id); // will be overwritten
|
|
|
|
| 537 |
}
|
| 538 |
seq_meta.tail = next_empty_cell;
|
| 539 |
// find next empty cell
|
| 540 |
if (s + 1 < n_seqs) {
|
| 541 |
-
next_empty_cell += 1;
|
| 542 |
for (uint32_t i = 0; i < size; ++i) {
|
|
|
|
| 543 |
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
| 544 |
kv_cell & cell = cells[next_empty_cell];
|
| 545 |
if (cell.is_empty()) { break; }
|
| 546 |
-
next_empty_cell += 1;
|
| 547 |
}
|
| 548 |
}
|
| 549 |
}
|
|
@@ -553,8 +541,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
| 553 |
|
| 554 |
// gather and re-order
|
| 555 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 556 |
-
int32_t dst_id = s + min;
|
| 557 |
-
int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
|
| 558 |
if (dst_id != src_id) {
|
| 559 |
kv_cell & dst_cell = cells[dst_id];
|
| 560 |
kv_cell & src_cell = cells[src_id];
|
|
@@ -563,12 +551,14 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
| 563 |
std::swap(dst_cell.src, src_cell.src);
|
| 564 |
std::swap(dst_cell.seq_id, src_cell.seq_id);
|
| 565 |
|
| 566 |
-
// swap tails
|
| 567 |
-
for (
|
| 568 |
-
cells[
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
|
|
|
|
|
|
| 572 |
}
|
| 573 |
}
|
| 574 |
}
|
|
@@ -576,7 +566,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
| 576 |
// update the pos of the used seqs
|
| 577 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 578 |
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
|
| 579 |
-
int32_t cell_id = s + min;
|
| 580 |
kv_cell & cell = cells[cell_id];
|
| 581 |
|
| 582 |
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
|
|
@@ -594,6 +584,38 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
| 594 |
}
|
| 595 |
}
|
| 596 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 597 |
// allow getting the range of used cells, from head to head + n
|
| 598 |
head = min;
|
| 599 |
n = max - min + 1;
|
|
@@ -605,47 +627,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
| 605 |
}
|
| 606 |
|
| 607 |
bool llama_kv_cache_recurrent::get_can_shift() const {
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
int32_t llama_kv_cache_recurrent::s_copy(int i) const {
|
| 612 |
-
const uint32_t cell_id = i + head;
|
| 613 |
-
|
| 614 |
-
//////////////////////////////////////////////
|
| 615 |
-
// TODO: this should not mutate the KV cache !
|
| 616 |
-
kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
|
| 617 |
-
|
| 618 |
-
// prevent out-of-bound sources
|
| 619 |
-
if (cell.src < 0 || (uint32_t) cell.src >= size) {
|
| 620 |
-
cell.src = cell_id;
|
| 621 |
-
}
|
| 622 |
-
|
| 623 |
-
int32_t res = cell.src;
|
| 624 |
-
|
| 625 |
-
// TODO: do not mutate the KV cache
|
| 626 |
-
// ensure copy only happens once
|
| 627 |
-
if (cell.src != (int32_t) cell_id) {
|
| 628 |
-
cell.src = cell_id;
|
| 629 |
-
}
|
| 630 |
-
|
| 631 |
-
return res;
|
| 632 |
-
}
|
| 633 |
-
|
| 634 |
-
float llama_kv_cache_recurrent::s_mask(int i) const {
|
| 635 |
-
const uint32_t cell_id = i + head;
|
| 636 |
-
|
| 637 |
-
//////////////////////////////////////////////
|
| 638 |
-
// TODO: this should not mutate the KV cache !
|
| 639 |
-
kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
|
| 640 |
-
|
| 641 |
-
float res = (float) (cell.src >= 0);
|
| 642 |
-
|
| 643 |
-
// only clear once
|
| 644 |
-
if (cell.src < 0) {
|
| 645 |
-
cell.src = cell_id;
|
| 646 |
-
}
|
| 647 |
-
|
| 648 |
-
return res;
|
| 649 |
}
|
| 650 |
|
| 651 |
size_t llama_kv_cache_recurrent::total_size() const {
|
|
@@ -1111,6 +1094,10 @@ uint32_t llama_kv_cache_recurrent_state::get_head() const {
|
|
| 1111 |
return is_full ? 0 : kv->head;
|
| 1112 |
}
|
| 1113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1114 |
uint32_t llama_kv_cache_recurrent_state::get_size() const {
|
| 1115 |
return kv->size;
|
| 1116 |
}
|
|
@@ -1124,9 +1111,5 @@ ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
|
|
| 1124 |
}
|
| 1125 |
|
| 1126 |
int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
|
| 1127 |
-
return kv->
|
| 1128 |
-
}
|
| 1129 |
-
|
| 1130 |
-
float llama_kv_cache_recurrent_state::s_mask(int i) const {
|
| 1131 |
-
return kv->s_mask(i);
|
| 1132 |
}
|
|
|
|
| 359 |
return result;
|
| 360 |
}
|
| 361 |
|
| 362 |
+
llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
|
| 363 |
+
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
|
|
|
|
|
|
|
| 364 |
|
| 365 |
std::vector<llama_ubatch> ubatches;
|
| 366 |
|
| 367 |
while (sbatch.n_tokens > 0) {
|
| 368 |
llama_ubatch ubatch;
|
| 369 |
|
| 370 |
+
if (embd_all) {
|
| 371 |
+
// if all tokens are output, split by sequence
|
| 372 |
ubatch = sbatch.split_seq(n_ubatch);
|
| 373 |
} else {
|
| 374 |
ubatch = sbatch.split_equal(n_ubatch);
|
|
|
|
| 404 |
|
| 405 |
bool success = true;
|
| 406 |
|
| 407 |
+
for (const auto & ubatch : ubatches) {
|
| 408 |
+
if (!find_slot(ubatch)) {
|
| 409 |
+
success = false;
|
| 410 |
+
break;
|
| 411 |
+
}
|
| 412 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
|
| 414 |
// restore the original state
|
| 415 |
cells = std::move(org_cells);
|
|
|
|
| 420 |
}
|
| 421 |
|
| 422 |
bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
| 423 |
+
const uint32_t n_seqs = ubatch.n_seqs;
|
|
|
|
| 424 |
|
| 425 |
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
| 426 |
|
| 427 |
// if we have enough unused cells before the current head ->
|
| 428 |
// better to start searching from the beginning of the cache, hoping to fill it
|
| 429 |
+
if (head > used + 2*n_seqs) {
|
| 430 |
head = 0;
|
| 431 |
}
|
| 432 |
|
|
|
|
| 522 |
empty_cell.src = orig_cell.src;
|
| 523 |
orig_cell.seq_id.erase(seq_id);
|
| 524 |
empty_cell.seq_id.insert(seq_id); // will be overwritten
|
| 525 |
+
GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
|
| 526 |
}
|
| 527 |
seq_meta.tail = next_empty_cell;
|
| 528 |
// find next empty cell
|
| 529 |
if (s + 1 < n_seqs) {
|
|
|
|
| 530 |
for (uint32_t i = 0; i < size; ++i) {
|
| 531 |
+
next_empty_cell += 1;
|
| 532 |
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
| 533 |
kv_cell & cell = cells[next_empty_cell];
|
| 534 |
if (cell.is_empty()) { break; }
|
|
|
|
| 535 |
}
|
| 536 |
}
|
| 537 |
}
|
|
|
|
| 541 |
|
| 542 |
// gather and re-order
|
| 543 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 544 |
+
const int32_t dst_id = s + min;
|
| 545 |
+
const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
|
| 546 |
if (dst_id != src_id) {
|
| 547 |
kv_cell & dst_cell = cells[dst_id];
|
| 548 |
kv_cell & src_cell = cells[src_id];
|
|
|
|
| 551 |
std::swap(dst_cell.src, src_cell.src);
|
| 552 |
std::swap(dst_cell.seq_id, src_cell.seq_id);
|
| 553 |
|
| 554 |
+
// swap tails
|
| 555 |
+
for (uint32_t i = 0; i < size; ++i) {
|
| 556 |
+
int32_t & tail = cells[i].tail;
|
| 557 |
+
if (tail == src_id) {
|
| 558 |
+
tail = dst_id;
|
| 559 |
+
} else if (tail == dst_id) {
|
| 560 |
+
tail = src_id;
|
| 561 |
+
}
|
| 562 |
}
|
| 563 |
}
|
| 564 |
}
|
|
|
|
| 566 |
// update the pos of the used seqs
|
| 567 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 568 |
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
|
| 569 |
+
const int32_t cell_id = s + min;
|
| 570 |
kv_cell & cell = cells[cell_id];
|
| 571 |
|
| 572 |
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
|
|
|
|
| 584 |
}
|
| 585 |
}
|
| 586 |
|
| 587 |
+
// Find first cell without src refs, to use as the zero-ed state
|
| 588 |
+
{
|
| 589 |
+
// TODO: bake-in src refcounts in the cell metadata
|
| 590 |
+
std::vector<int32_t> refcounts(size, 0);
|
| 591 |
+
for (size_t i = 0; i < size; ++i) {
|
| 592 |
+
const int32_t src = cells[i].src;
|
| 593 |
+
if (src >= 0) {
|
| 594 |
+
refcounts[src] += 1;
|
| 595 |
+
}
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
rs_z = -1;
|
| 599 |
+
for (int i = min; i <= max; ++i) {
|
| 600 |
+
if (refcounts[i] == 0) {
|
| 601 |
+
rs_z = i;
|
| 602 |
+
break;
|
| 603 |
+
}
|
| 604 |
+
}
|
| 605 |
+
|
| 606 |
+
for (int i = min; i <= max; ++i) {
|
| 607 |
+
if (cells[i].src < 0) {
|
| 608 |
+
GGML_ASSERT(rs_z >= 0);
|
| 609 |
+
cells[i].src0 = rs_z;
|
| 610 |
+
} else {
|
| 611 |
+
// Stage the source ids for all used cells to allow correct seq_* behavior
|
| 612 |
+
// and still make these values available when setting the inputs
|
| 613 |
+
cells[i].src0 = cells[i].src;
|
| 614 |
+
}
|
| 615 |
+
cells[i].src = i; // avoid moving or clearing twice
|
| 616 |
+
}
|
| 617 |
+
}
|
| 618 |
+
|
| 619 |
// allow getting the range of used cells, from head to head + n
|
| 620 |
head = min;
|
| 621 |
n = max - min + 1;
|
|
|
|
| 627 |
}
|
| 628 |
|
| 629 |
bool llama_kv_cache_recurrent::get_can_shift() const {
|
| 630 |
+
// shifting the pos is trivial for recurrent models
|
| 631 |
+
return true;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 632 |
}
|
| 633 |
|
| 634 |
size_t llama_kv_cache_recurrent::total_size() const {
|
|
|
|
| 1094 |
return is_full ? 0 : kv->head;
|
| 1095 |
}
|
| 1096 |
|
| 1097 |
+
int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
|
| 1098 |
+
return is_full ? 0 : kv->rs_z;
|
| 1099 |
+
}
|
| 1100 |
+
|
| 1101 |
uint32_t llama_kv_cache_recurrent_state::get_size() const {
|
| 1102 |
return kv->size;
|
| 1103 |
}
|
|
|
|
| 1111 |
}
|
| 1112 |
|
| 1113 |
int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
|
| 1114 |
+
return kv->cells[i + kv->head].src0;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1115 |
}
|
examples/talk-llama/llama-kv-cache-recurrent.h
CHANGED
|
@@ -32,8 +32,7 @@ public:
|
|
| 32 |
llama_memory_state_ptr init_batch(
|
| 33 |
const llama_batch & batch,
|
| 34 |
uint32_t n_ubatch,
|
| 35 |
-
bool
|
| 36 |
-
bool logits_all) override;
|
| 37 |
|
| 38 |
llama_memory_state_ptr init_full() override;
|
| 39 |
|
|
@@ -57,10 +56,6 @@ public:
|
|
| 57 |
|
| 58 |
bool get_can_shift() const override;
|
| 59 |
|
| 60 |
-
// TODO: temporary methods - they are not really const as they do const_cast<>, fix this
|
| 61 |
-
int32_t s_copy(int i) const;
|
| 62 |
-
float s_mask(int i) const;
|
| 63 |
-
|
| 64 |
// state write/load
|
| 65 |
|
| 66 |
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
|
@@ -73,10 +68,14 @@ public:
|
|
| 73 |
// computed before each graph build
|
| 74 |
uint32_t n = 0;
|
| 75 |
|
|
|
|
|
|
|
|
|
|
| 76 |
// TODO: optimize for recurrent state needs
|
| 77 |
struct kv_cell {
|
| 78 |
llama_pos pos = -1;
|
| 79 |
-
int32_t src = -1; // used to
|
|
|
|
| 80 |
int32_t tail = -1;
|
| 81 |
|
| 82 |
std::set<llama_seq_id> seq_id;
|
|
@@ -157,13 +156,13 @@ public:
|
|
| 157 |
|
| 158 |
uint32_t get_n_kv() const;
|
| 159 |
uint32_t get_head() const;
|
|
|
|
| 160 |
uint32_t get_size() const;
|
| 161 |
|
| 162 |
ggml_tensor * get_k_l(int32_t il) const;
|
| 163 |
ggml_tensor * get_v_l(int32_t il) const;
|
| 164 |
|
| 165 |
int32_t s_copy(int i) const;
|
| 166 |
-
float s_mask(int i) const;
|
| 167 |
|
| 168 |
private:
|
| 169 |
const llama_memory_status status;
|
|
|
|
| 32 |
llama_memory_state_ptr init_batch(
|
| 33 |
const llama_batch & batch,
|
| 34 |
uint32_t n_ubatch,
|
| 35 |
+
bool embd_all) override;
|
|
|
|
| 36 |
|
| 37 |
llama_memory_state_ptr init_full() override;
|
| 38 |
|
|
|
|
| 56 |
|
| 57 |
bool get_can_shift() const override;
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
// state write/load
|
| 60 |
|
| 61 |
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
|
|
|
| 68 |
// computed before each graph build
|
| 69 |
uint32_t n = 0;
|
| 70 |
|
| 71 |
+
// first zero-ed state
|
| 72 |
+
int32_t rs_z = -1;
|
| 73 |
+
|
| 74 |
// TODO: optimize for recurrent state needs
|
| 75 |
struct kv_cell {
|
| 76 |
llama_pos pos = -1;
|
| 77 |
+
int32_t src = -1; // used to know where states should be copied from
|
| 78 |
+
int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
|
| 79 |
int32_t tail = -1;
|
| 80 |
|
| 81 |
std::set<llama_seq_id> seq_id;
|
|
|
|
| 156 |
|
| 157 |
uint32_t get_n_kv() const;
|
| 158 |
uint32_t get_head() const;
|
| 159 |
+
int32_t get_rs_z() const;
|
| 160 |
uint32_t get_size() const;
|
| 161 |
|
| 162 |
ggml_tensor * get_k_l(int32_t il) const;
|
| 163 |
ggml_tensor * get_v_l(int32_t il) const;
|
| 164 |
|
| 165 |
int32_t s_copy(int i) const;
|
|
|
|
| 166 |
|
| 167 |
private:
|
| 168 |
const llama_memory_status status;
|
examples/talk-llama/llama-kv-cache-unified-iswa.cpp
CHANGED
|
@@ -95,36 +95,69 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
|
|
| 95 |
return kv_swa->seq_pos_max(seq_id);
|
| 96 |
}
|
| 97 |
|
| 98 |
-
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool
|
| 99 |
-
GGML_UNUSED(
|
| 100 |
|
| 101 |
-
//
|
| 102 |
-
|
|
|
|
| 103 |
|
| 104 |
-
|
| 105 |
|
| 106 |
-
|
|
|
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
| 110 |
|
| 111 |
-
ubatches
|
| 112 |
-
|
|
|
|
|
|
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
-
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
| 127 |
-
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
| 128 |
}
|
| 129 |
|
| 130 |
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
|
|
|
|
| 95 |
return kv_swa->seq_pos_max(seq_id);
|
| 96 |
}
|
| 97 |
|
| 98 |
+
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
|
| 99 |
+
GGML_UNUSED(embd_all);
|
| 100 |
|
| 101 |
+
// first try simple split
|
| 102 |
+
do {
|
| 103 |
+
auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
|
| 104 |
|
| 105 |
+
std::vector<llama_ubatch> ubatches;
|
| 106 |
|
| 107 |
+
while (sbatch.n_tokens > 0) {
|
| 108 |
+
auto ubatch = sbatch.split_simple(n_ubatch);
|
| 109 |
|
| 110 |
+
ubatches.push_back(ubatch);
|
| 111 |
+
}
|
| 112 |
|
| 113 |
+
auto heads_base = kv_base->prepare(ubatches);
|
| 114 |
+
if (heads_base.empty()) {
|
| 115 |
+
break;
|
| 116 |
+
}
|
| 117 |
|
| 118 |
+
auto heads_swa = kv_swa->prepare(ubatches);
|
| 119 |
+
if (heads_swa.empty()) {
|
| 120 |
+
break;
|
| 121 |
+
}
|
| 122 |
|
| 123 |
+
assert(heads_base.size() == heads_swa.size());
|
| 124 |
+
|
| 125 |
+
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
| 126 |
+
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
| 127 |
+
} while (false);
|
| 128 |
+
|
| 129 |
+
// if it fails, try equal split
|
| 130 |
+
do {
|
| 131 |
+
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
|
| 132 |
+
|
| 133 |
+
std::vector<llama_ubatch> ubatches;
|
| 134 |
|
| 135 |
+
while (sbatch.n_tokens > 0) {
|
| 136 |
+
auto ubatch = sbatch.split_equal(n_ubatch);
|
| 137 |
+
|
| 138 |
+
ubatches.push_back(ubatch);
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
auto heads_base = kv_base->prepare(ubatches);
|
| 142 |
+
if (heads_base.empty()) {
|
| 143 |
+
break;
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
auto heads_swa = kv_swa->prepare(ubatches);
|
| 147 |
+
if (heads_swa.empty()) {
|
| 148 |
+
break;
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
assert(heads_base.size() == heads_swa.size());
|
| 152 |
+
|
| 153 |
+
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
| 154 |
+
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
| 155 |
+
} while (false);
|
| 156 |
+
|
| 157 |
+
// TODO: if we fail again, we should attempt different splitting strategies
|
| 158 |
+
// but to do that properly, we first have to refactor the batches to be more flexible
|
| 159 |
|
| 160 |
+
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
|
|
| 161 |
}
|
| 162 |
|
| 163 |
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
|
examples/talk-llama/llama-kv-cache-unified-iswa.h
CHANGED
|
@@ -34,8 +34,7 @@ public:
|
|
| 34 |
llama_memory_state_ptr init_batch(
|
| 35 |
const llama_batch & batch,
|
| 36 |
uint32_t n_ubatch,
|
| 37 |
-
bool
|
| 38 |
-
bool logits_all) override;
|
| 39 |
|
| 40 |
llama_memory_state_ptr init_full() override;
|
| 41 |
|
|
|
|
| 34 |
llama_memory_state_ptr init_batch(
|
| 35 |
const llama_batch & batch,
|
| 36 |
uint32_t n_ubatch,
|
| 37 |
+
bool embd_all) override;
|
|
|
|
| 38 |
|
| 39 |
llama_memory_state_ptr init_full() override;
|
| 40 |
|
examples/talk-llama/llama-kv-cache-unified.cpp
CHANGED
|
@@ -127,6 +127,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|
| 127 |
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
| 128 |
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
| 129 |
}
|
|
|
|
|
|
|
|
|
|
| 130 |
}
|
| 131 |
|
| 132 |
void llama_kv_cache_unified::clear(bool data) {
|
|
@@ -307,24 +310,27 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
|
| 307 |
llama_memory_state_ptr llama_kv_cache_unified::init_batch(
|
| 308 |
const llama_batch & batch,
|
| 309 |
uint32_t n_ubatch,
|
| 310 |
-
bool
|
| 311 |
-
|
| 312 |
-
GGML_UNUSED(embd_pooled);
|
| 313 |
|
| 314 |
-
|
|
|
|
| 315 |
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
|
| 326 |
-
return std::make_unique<llama_kv_cache_unified_state>(
|
| 327 |
-
this, std::move(sbatch), std::move(heads), std::move(ubatches));
|
| 328 |
}
|
| 329 |
|
| 330 |
llama_memory_state_ptr llama_kv_cache_unified::init_full() {
|
|
@@ -512,43 +518,68 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
|
| 512 |
head_cur = 0;
|
| 513 |
}
|
| 514 |
|
| 515 |
-
// otherwise, one cell per token.
|
| 516 |
-
|
| 517 |
if (n_tokens > cells.size()) {
|
| 518 |
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
|
| 519 |
return -1;
|
| 520 |
}
|
| 521 |
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", cells.used_max_p1(), cells.get_used(), head, n_swa);
|
| 525 |
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
std::string ss;
|
| 529 |
-
if (n_swa > 0) {
|
| 530 |
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 531 |
if (cells.is_empty(i)) {
|
| 532 |
ss += '.';
|
| 533 |
} else {
|
| 534 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
}
|
| 536 |
if (i%256 == 255) {
|
|
|
|
| 537 |
ss += '\n';
|
| 538 |
}
|
| 539 |
}
|
|
|
|
| 540 |
}
|
| 541 |
-
LLAMA_LOG_WARN("\n%s\n", ss.c_str());
|
| 542 |
-
}
|
| 543 |
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
}
|
| 548 |
|
| 549 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 550 |
}
|
| 551 |
-
#endif
|
| 552 |
|
| 553 |
uint32_t n_tested = 0;
|
| 554 |
|
|
@@ -559,21 +590,15 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
|
| 559 |
continue;
|
| 560 |
}
|
| 561 |
|
| 562 |
-
// keep track of what the minimum sequence positions would be if we accept the ubatch
|
| 563 |
-
llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
|
| 564 |
-
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
| 565 |
-
seq_pos_min[s] = cells.seq_pos_min(s);
|
| 566 |
-
}
|
| 567 |
-
|
| 568 |
bool found = true;
|
| 569 |
for (uint32_t i = 0; i < n_tokens; i++) {
|
| 570 |
-
const llama_pos pos = ubatch.pos[i];
|
| 571 |
-
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
| 572 |
|
| 573 |
// can we use this cell? either:
|
| 574 |
// - the cell is empty
|
| 575 |
// - the cell is occupied only by one sequence:
|
| 576 |
-
// - mask causally, if the sequence is the same as the one we are inserting
|
| 577 |
// - mask SWA, using current max pos for that sequence in the cache
|
| 578 |
// always insert in the cell with minimum pos
|
| 579 |
bool can_use = cells.is_empty(head_cur + i);
|
|
@@ -581,21 +606,17 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
|
| 581 |
if (!can_use && cells.seq_count(head_cur + i) == 1) {
|
| 582 |
const llama_pos pos_cell = cells.pos_get(head_cur + i);
|
| 583 |
|
| 584 |
-
// causal mask
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
|
|
|
| 588 |
|
| 589 |
if (!can_use) {
|
| 590 |
const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
|
| 591 |
|
| 592 |
// SWA mask
|
| 593 |
-
|
| 594 |
-
// all positions between [pos_min, pos_max] for each sequence will be present in the cache
|
| 595 |
-
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
|
| 596 |
-
if (pos_cell == seq_pos_min[seq_id_cell] &&
|
| 597 |
-
is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
|
| 598 |
-
seq_pos_min[seq_id_cell]++;
|
| 599 |
can_use = true;
|
| 600 |
}
|
| 601 |
}
|
|
@@ -623,18 +644,58 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
|
| 623 |
}
|
| 624 |
|
| 625 |
void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 630 |
|
| 631 |
-
|
| 632 |
|
| 633 |
-
|
| 634 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 635 |
}
|
| 636 |
}
|
| 637 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 638 |
// move the head at the end of the slot
|
| 639 |
head = head_cur + ubatch.n_tokens;
|
| 640 |
}
|
|
@@ -731,14 +792,14 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
|
|
| 731 |
}
|
| 732 |
|
| 733 |
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
| 734 |
-
const
|
| 735 |
-
const
|
| 736 |
-
const
|
| 737 |
|
| 738 |
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
| 739 |
float * data = (float *) dst->data;
|
| 740 |
|
| 741 |
-
const
|
| 742 |
|
| 743 |
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
| 744 |
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
|
@@ -752,12 +813,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
|
| 752 |
// xxxxx-----
|
| 753 |
// xxxxx-----
|
| 754 |
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
| 755 |
-
for (
|
| 756 |
-
for (
|
| 757 |
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 758 |
|
| 759 |
-
for (
|
| 760 |
-
const
|
|
|
|
|
|
|
| 761 |
|
| 762 |
for (uint32_t i = 0; i < n_kv; ++i) {
|
| 763 |
float f = 0.0f;
|
|
@@ -787,16 +850,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
|
| 787 |
f = -INFINITY;
|
| 788 |
}
|
| 789 |
|
| 790 |
-
data[h*(n_kv*n_tokens) +
|
| 791 |
}
|
| 792 |
}
|
| 793 |
}
|
| 794 |
|
| 795 |
// mask padded tokens
|
| 796 |
if (data) {
|
| 797 |
-
for (
|
| 798 |
-
for (uint32_t
|
| 799 |
-
data[h*(n_kv*n_tokens) +
|
| 800 |
}
|
| 801 |
}
|
| 802 |
}
|
|
@@ -1447,9 +1510,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|
| 1447 |
seq_rm(dest_seq_id, -1, -1);
|
| 1448 |
|
| 1449 |
llama_sbatch sbatch;
|
| 1450 |
-
llama_ubatch
|
| 1451 |
|
| 1452 |
-
|
|
|
|
|
|
|
| 1453 |
|
| 1454 |
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 1455 |
llama_pos pos;
|
|
@@ -1469,18 +1534,18 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|
| 1469 |
io.read_to(&seq_id, sizeof(seq_id));
|
| 1470 |
}
|
| 1471 |
|
| 1472 |
-
|
| 1473 |
-
|
| 1474 |
-
|
| 1475 |
}
|
| 1476 |
|
| 1477 |
-
const auto head_cur = find_slot(
|
| 1478 |
if (head_cur < 0) {
|
| 1479 |
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
| 1480 |
return false;
|
| 1481 |
}
|
| 1482 |
|
| 1483 |
-
apply_ubatch(head_cur,
|
| 1484 |
|
| 1485 |
// keep the head at the old position because we will read the KV data into it in state_read_data()
|
| 1486 |
head = head_cur;
|
|
@@ -1488,8 +1553,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|
| 1488 |
// DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
| 1489 |
// Assume that this is one contiguous block of cells
|
| 1490 |
GGML_ASSERT(head_cur + cell_count <= cells.size());
|
| 1491 |
-
GGML_ASSERT(cells.pos_get(head_cur) ==
|
| 1492 |
-
GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) ==
|
| 1493 |
GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
|
| 1494 |
GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
|
| 1495 |
} else {
|
|
@@ -1674,7 +1739,7 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
|
| 1674 |
llama_context * lctx,
|
| 1675 |
bool do_shift,
|
| 1676 |
defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
|
| 1677 |
-
if (!do_shift && dinfo.empty()) {
|
| 1678 |
status = LLAMA_MEMORY_STATUS_NO_UPDATE;
|
| 1679 |
}
|
| 1680 |
}
|
|
|
|
| 127 |
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
| 128 |
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
| 129 |
}
|
| 130 |
+
|
| 131 |
+
const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
|
| 132 |
+
debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
|
| 133 |
}
|
| 134 |
|
| 135 |
void llama_kv_cache_unified::clear(bool data) {
|
|
|
|
| 310 |
llama_memory_state_ptr llama_kv_cache_unified::init_batch(
|
| 311 |
const llama_batch & batch,
|
| 312 |
uint32_t n_ubatch,
|
| 313 |
+
bool embd_all) {
|
| 314 |
+
GGML_UNUSED(embd_all);
|
|
|
|
| 315 |
|
| 316 |
+
do {
|
| 317 |
+
auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
|
| 318 |
|
| 319 |
+
std::vector<llama_ubatch> ubatches;
|
| 320 |
+
while (sbatch.n_tokens > 0) {
|
| 321 |
+
ubatches.push_back(sbatch.split_simple(n_ubatch));
|
| 322 |
+
}
|
| 323 |
|
| 324 |
+
auto heads = prepare(ubatches);
|
| 325 |
+
if (heads.empty()) {
|
| 326 |
+
break;
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
return std::make_unique<llama_kv_cache_unified_state>(
|
| 330 |
+
this, std::move(sbatch), std::move(heads), std::move(ubatches));
|
| 331 |
+
} while (false);
|
| 332 |
|
| 333 |
+
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
|
|
| 334 |
}
|
| 335 |
|
| 336 |
llama_memory_state_ptr llama_kv_cache_unified::init_full() {
|
|
|
|
| 518 |
head_cur = 0;
|
| 519 |
}
|
| 520 |
|
|
|
|
|
|
|
| 521 |
if (n_tokens > cells.size()) {
|
| 522 |
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
|
| 523 |
return -1;
|
| 524 |
}
|
| 525 |
|
| 526 |
+
if (debug > 0) {
|
| 527 |
+
LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
|
|
|
|
| 528 |
|
| 529 |
+
if ((debug == 2 && n_swa > 0) || debug > 2) {
|
| 530 |
+
std::string ss;
|
|
|
|
|
|
|
| 531 |
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 532 |
if (cells.is_empty(i)) {
|
| 533 |
ss += '.';
|
| 534 |
} else {
|
| 535 |
+
assert(cells.seq_count(i) >= 1);
|
| 536 |
+
|
| 537 |
+
if (cells.seq_count(i) == 1) {
|
| 538 |
+
ss += std::to_string(cells.seq_get(i));
|
| 539 |
+
} else {
|
| 540 |
+
ss += 'M';
|
| 541 |
+
}
|
| 542 |
}
|
| 543 |
if (i%256 == 255) {
|
| 544 |
+
ss += " *";
|
| 545 |
ss += '\n';
|
| 546 |
}
|
| 547 |
}
|
| 548 |
+
LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
|
| 549 |
}
|
|
|
|
|
|
|
| 550 |
|
| 551 |
+
if ((debug == 2 && n_swa > 0) || debug > 2) {
|
| 552 |
+
std::string ss;
|
| 553 |
+
for (uint32_t i = 0; i < cells.size(); ++i) {
|
| 554 |
+
std::string cur;
|
| 555 |
+
if (cells.is_empty(i)) {
|
| 556 |
+
cur = '.';
|
| 557 |
+
} else {
|
| 558 |
+
cur = std::to_string(cells.pos_get(i));
|
| 559 |
+
}
|
| 560 |
+
const int n = cur.size();
|
| 561 |
+
for (int j = 0; j < 5 - n; ++j) {
|
| 562 |
+
cur += ' ';
|
| 563 |
+
}
|
| 564 |
+
ss += cur;
|
| 565 |
+
if (i%256 == 255) {
|
| 566 |
+
ss += " *";
|
| 567 |
+
}
|
| 568 |
+
if (i%64 == 63) {
|
| 569 |
+
ss += '\n';
|
| 570 |
+
}
|
| 571 |
+
}
|
| 572 |
+
LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
|
| 573 |
}
|
| 574 |
|
| 575 |
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 576 |
+
if (cells.seq_pos_min(s) < 0) {
|
| 577 |
+
continue;
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
LLAMA_LOG_DEBUG("%s: min[%d] = %5d, max[%d] = %5d\n", __func__, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
|
| 581 |
+
}
|
| 582 |
}
|
|
|
|
| 583 |
|
| 584 |
uint32_t n_tested = 0;
|
| 585 |
|
|
|
|
| 590 |
continue;
|
| 591 |
}
|
| 592 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 593 |
bool found = true;
|
| 594 |
for (uint32_t i = 0; i < n_tokens; i++) {
|
| 595 |
+
//const llama_pos pos = ubatch.pos[i];
|
| 596 |
+
//const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
| 597 |
|
| 598 |
// can we use this cell? either:
|
| 599 |
// - the cell is empty
|
| 600 |
// - the cell is occupied only by one sequence:
|
| 601 |
+
// - (disabled) mask causally, if the sequence is the same as the one we are inserting
|
| 602 |
// - mask SWA, using current max pos for that sequence in the cache
|
| 603 |
// always insert in the cell with minimum pos
|
| 604 |
bool can_use = cells.is_empty(head_cur + i);
|
|
|
|
| 606 |
if (!can_use && cells.seq_count(head_cur + i) == 1) {
|
| 607 |
const llama_pos pos_cell = cells.pos_get(head_cur + i);
|
| 608 |
|
| 609 |
+
// (disabled) causal mask
|
| 610 |
+
// note: it's better to purge any "future" tokens beforehand
|
| 611 |
+
//if (cells.seq_has(head_cur + i, seq_id)) {
|
| 612 |
+
// can_use = pos_cell >= pos;
|
| 613 |
+
//}
|
| 614 |
|
| 615 |
if (!can_use) {
|
| 616 |
const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
|
| 617 |
|
| 618 |
// SWA mask
|
| 619 |
+
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 620 |
can_use = true;
|
| 621 |
}
|
| 622 |
}
|
|
|
|
| 644 |
}
|
| 645 |
|
| 646 |
void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
|
| 647 |
+
if (debug > 0) {
|
| 648 |
+
LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
|
| 649 |
+
LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
|
| 650 |
+
LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
|
| 651 |
+
}
|
| 652 |
+
|
| 653 |
+
// keep track of the max sequence position that we would overwrite with this ubatch
|
| 654 |
+
// for non-SWA cache, this would be always empty
|
| 655 |
+
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
|
| 656 |
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 657 |
+
seq_pos_max_rm[s] = -1;
|
| 658 |
+
}
|
| 659 |
+
|
| 660 |
+
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
| 661 |
+
for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
|
| 662 |
+
const uint32_t idx = s*ubatch.n_seq_tokens + j;
|
| 663 |
+
|
| 664 |
+
if (!cells.is_empty(head_cur + idx)) {
|
| 665 |
+
assert(cells.seq_count(head_cur + idx) == 1);
|
| 666 |
+
|
| 667 |
+
const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
|
| 668 |
+
const llama_pos pos = cells.pos_get(head_cur + idx);
|
| 669 |
|
| 670 |
+
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
|
| 671 |
|
| 672 |
+
cells.rm(head_cur + idx);
|
| 673 |
+
}
|
| 674 |
+
|
| 675 |
+
cells.pos_set(head_cur + idx, ubatch.pos[idx]);
|
| 676 |
+
|
| 677 |
+
// TODO: fix indexing [UBATCH_IDX]
|
| 678 |
+
for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
|
| 679 |
+
cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
|
| 680 |
+
}
|
| 681 |
}
|
| 682 |
}
|
| 683 |
|
| 684 |
+
// note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
|
| 685 |
+
// will be present in the cache. so we have to purge any position which is less than those we would overwrite
|
| 686 |
+
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
|
| 687 |
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 688 |
+
if (seq_pos_max_rm[s] == -1) {
|
| 689 |
+
continue;
|
| 690 |
+
}
|
| 691 |
+
|
| 692 |
+
if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
|
| 693 |
+
LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
|
| 694 |
+
__func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
|
| 695 |
+
|
| 696 |
+
seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
|
| 697 |
+
}
|
| 698 |
+
}
|
| 699 |
// move the head at the end of the slot
|
| 700 |
head = head_cur + ubatch.n_tokens;
|
| 701 |
}
|
|
|
|
| 792 |
}
|
| 793 |
|
| 794 |
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
| 795 |
+
const uint32_t n_tokens = ubatch->n_tokens;
|
| 796 |
+
const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
|
| 797 |
+
const uint32_t n_seqs = ubatch->n_seqs;
|
| 798 |
|
| 799 |
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
| 800 |
float * data = (float *) dst->data;
|
| 801 |
|
| 802 |
+
const int64_t n_kv = dst->ne[0];
|
| 803 |
|
| 804 |
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
| 805 |
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
|
|
|
| 813 |
// xxxxx-----
|
| 814 |
// xxxxx-----
|
| 815 |
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
| 816 |
+
for (uint32_t h = 0; h < 1; ++h) {
|
| 817 |
+
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 818 |
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
| 819 |
|
| 820 |
+
for (uint32_t j = 0; j < n_seq_tokens; ++j) {
|
| 821 |
+
const uint32_t idx = s*n_seq_tokens + j;
|
| 822 |
+
|
| 823 |
+
const llama_pos p1 = ubatch->pos[idx];
|
| 824 |
|
| 825 |
for (uint32_t i = 0; i < n_kv; ++i) {
|
| 826 |
float f = 0.0f;
|
|
|
|
| 850 |
f = -INFINITY;
|
| 851 |
}
|
| 852 |
|
| 853 |
+
data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
|
| 854 |
}
|
| 855 |
}
|
| 856 |
}
|
| 857 |
|
| 858 |
// mask padded tokens
|
| 859 |
if (data) {
|
| 860 |
+
for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
|
| 861 |
+
for (uint32_t i = 0; i < n_kv; ++i) {
|
| 862 |
+
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
|
| 863 |
}
|
| 864 |
}
|
| 865 |
}
|
|
|
|
| 1510 |
seq_rm(dest_seq_id, -1, -1);
|
| 1511 |
|
| 1512 |
llama_sbatch sbatch;
|
| 1513 |
+
llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
| 1514 |
|
| 1515 |
+
ubatch.n_tokens = cell_count;
|
| 1516 |
+
ubatch.n_seq_tokens = cell_count;
|
| 1517 |
+
ubatch.n_seqs = 1;
|
| 1518 |
|
| 1519 |
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 1520 |
llama_pos pos;
|
|
|
|
| 1534 |
io.read_to(&seq_id, sizeof(seq_id));
|
| 1535 |
}
|
| 1536 |
|
| 1537 |
+
ubatch.pos[i] = pos;
|
| 1538 |
+
ubatch.n_seq_id[i] = n_seq_id;
|
| 1539 |
+
ubatch.seq_id[i] = &dest_seq_id;
|
| 1540 |
}
|
| 1541 |
|
| 1542 |
+
const auto head_cur = find_slot(ubatch);
|
| 1543 |
if (head_cur < 0) {
|
| 1544 |
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
| 1545 |
return false;
|
| 1546 |
}
|
| 1547 |
|
| 1548 |
+
apply_ubatch(head_cur, ubatch);
|
| 1549 |
|
| 1550 |
// keep the head at the old position because we will read the KV data into it in state_read_data()
|
| 1551 |
head = head_cur;
|
|
|
|
| 1553 |
// DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
| 1554 |
// Assume that this is one contiguous block of cells
|
| 1555 |
GGML_ASSERT(head_cur + cell_count <= cells.size());
|
| 1556 |
+
GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
|
| 1557 |
+
GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
|
| 1558 |
GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
|
| 1559 |
GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
|
| 1560 |
} else {
|
|
|
|
| 1739 |
llama_context * lctx,
|
| 1740 |
bool do_shift,
|
| 1741 |
defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
|
| 1742 |
+
if (!do_shift && this->dinfo.empty()) {
|
| 1743 |
status = LLAMA_MEMORY_STATUS_NO_UPDATE;
|
| 1744 |
}
|
| 1745 |
}
|
examples/talk-llama/llama-kv-cache-unified.h
CHANGED
|
@@ -59,8 +59,7 @@ public:
|
|
| 59 |
llama_memory_state_ptr init_batch(
|
| 60 |
const llama_batch & batch,
|
| 61 |
uint32_t n_ubatch,
|
| 62 |
-
bool
|
| 63 |
-
bool logits_all) override;
|
| 64 |
|
| 65 |
llama_memory_state_ptr init_full() override;
|
| 66 |
|
|
@@ -158,6 +157,8 @@ private:
|
|
| 158 |
// SWA
|
| 159 |
const uint32_t n_swa = 0;
|
| 160 |
|
|
|
|
|
|
|
| 161 |
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
| 162 |
|
| 163 |
std::vector<ggml_context_ptr> ctxs;
|
|
|
|
| 59 |
llama_memory_state_ptr init_batch(
|
| 60 |
const llama_batch & batch,
|
| 61 |
uint32_t n_ubatch,
|
| 62 |
+
bool embd_all) override;
|
|
|
|
| 63 |
|
| 64 |
llama_memory_state_ptr init_full() override;
|
| 65 |
|
|
|
|
| 157 |
// SWA
|
| 158 |
const uint32_t n_swa = 0;
|
| 159 |
|
| 160 |
+
int debug = 0;
|
| 161 |
+
|
| 162 |
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
| 163 |
|
| 164 |
std::vector<ggml_context_ptr> ctxs;
|
examples/talk-llama/llama-kv-cells.h
CHANGED
|
@@ -23,7 +23,7 @@ public:
|
|
| 23 |
|
| 24 |
used.clear();
|
| 25 |
|
| 26 |
-
for (uint32_t s = 0; s <
|
| 27 |
seq_pos[s].clear();
|
| 28 |
}
|
| 29 |
}
|
|
@@ -240,7 +240,7 @@ public:
|
|
| 240 |
llama_seq_id seq_get(uint32_t i) const {
|
| 241 |
assert(seq[i].count() == 1);
|
| 242 |
|
| 243 |
-
for (int s = 0; s <
|
| 244 |
if (seq[i].test(s)) {
|
| 245 |
return s;
|
| 246 |
}
|
|
@@ -253,7 +253,7 @@ public:
|
|
| 253 |
// return -1 if the sequence is not present
|
| 254 |
llama_pos seq_pos_min(llama_seq_id seq_id) const {
|
| 255 |
assert(seq_id >= 0);
|
| 256 |
-
assert(seq_id <
|
| 257 |
|
| 258 |
if (seq_pos[seq_id].empty()) {
|
| 259 |
return -1;
|
|
@@ -266,7 +266,7 @@ public:
|
|
| 266 |
// return -1 if the sequence is not present
|
| 267 |
llama_pos seq_pos_max(llama_seq_id seq_id) const {
|
| 268 |
assert(seq_id >= 0);
|
| 269 |
-
assert(seq_id <
|
| 270 |
|
| 271 |
if (seq_pos[seq_id].empty()) {
|
| 272 |
return -1;
|
|
@@ -384,20 +384,20 @@ private:
|
|
| 384 |
//
|
| 385 |
std::vector<llama_pos> shift;
|
| 386 |
|
| 387 |
-
using bits_t = std::bitset<
|
| 388 |
|
| 389 |
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
|
| 390 |
std::vector<bits_t> seq;
|
| 391 |
|
| 392 |
// the set seq_pos[s] tells us which positions are currently present for sequence s
|
| 393 |
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
|
| 394 |
-
std::set<llama_pos> seq_pos[
|
| 395 |
|
| 396 |
// helper functions for updating `seq_pos`, once cell at a time:
|
| 397 |
|
| 398 |
// remove cell i
|
| 399 |
void seq_pos_rm(uint32_t i) {
|
| 400 |
-
for (int s = 0; s <
|
| 401 |
if (seq[i].test(s)) {
|
| 402 |
seq_pos[s].erase(pos[i]);
|
| 403 |
}
|
|
@@ -406,7 +406,7 @@ private:
|
|
| 406 |
|
| 407 |
// add cell i
|
| 408 |
void seq_pos_add(uint32_t i) {
|
| 409 |
-
for (int s = 0; s <
|
| 410 |
if (seq[i].test(s)) {
|
| 411 |
seq_pos[s].insert(pos[i]);
|
| 412 |
}
|
|
|
|
| 23 |
|
| 24 |
used.clear();
|
| 25 |
|
| 26 |
+
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 27 |
seq_pos[s].clear();
|
| 28 |
}
|
| 29 |
}
|
|
|
|
| 240 |
llama_seq_id seq_get(uint32_t i) const {
|
| 241 |
assert(seq[i].count() == 1);
|
| 242 |
|
| 243 |
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 244 |
if (seq[i].test(s)) {
|
| 245 |
return s;
|
| 246 |
}
|
|
|
|
| 253 |
// return -1 if the sequence is not present
|
| 254 |
llama_pos seq_pos_min(llama_seq_id seq_id) const {
|
| 255 |
assert(seq_id >= 0);
|
| 256 |
+
assert(seq_id < LLAMA_MAX_SEQ);
|
| 257 |
|
| 258 |
if (seq_pos[seq_id].empty()) {
|
| 259 |
return -1;
|
|
|
|
| 266 |
// return -1 if the sequence is not present
|
| 267 |
llama_pos seq_pos_max(llama_seq_id seq_id) const {
|
| 268 |
assert(seq_id >= 0);
|
| 269 |
+
assert(seq_id < LLAMA_MAX_SEQ);
|
| 270 |
|
| 271 |
if (seq_pos[seq_id].empty()) {
|
| 272 |
return -1;
|
|
|
|
| 384 |
//
|
| 385 |
std::vector<llama_pos> shift;
|
| 386 |
|
| 387 |
+
using bits_t = std::bitset<LLAMA_MAX_SEQ>;
|
| 388 |
|
| 389 |
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
|
| 390 |
std::vector<bits_t> seq;
|
| 391 |
|
| 392 |
// the set seq_pos[s] tells us which positions are currently present for sequence s
|
| 393 |
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
|
| 394 |
+
std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
|
| 395 |
|
| 396 |
// helper functions for updating `seq_pos`, once cell at a time:
|
| 397 |
|
| 398 |
// remove cell i
|
| 399 |
void seq_pos_rm(uint32_t i) {
|
| 400 |
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 401 |
if (seq[i].test(s)) {
|
| 402 |
seq_pos[s].erase(pos[i]);
|
| 403 |
}
|
|
|
|
| 406 |
|
| 407 |
// add cell i
|
| 408 |
void seq_pos_add(uint32_t i) {
|
| 409 |
+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 410 |
if (seq[i].test(s)) {
|
| 411 |
seq_pos[s].insert(pos[i]);
|
| 412 |
}
|
examples/talk-llama/llama-memory.h
CHANGED
|
@@ -73,8 +73,7 @@ struct llama_memory_i {
|
|
| 73 |
virtual llama_memory_state_ptr init_batch(
|
| 74 |
const llama_batch & batch,
|
| 75 |
uint32_t n_ubatch,
|
| 76 |
-
bool
|
| 77 |
-
bool logits_all) = 0;
|
| 78 |
|
| 79 |
// simulate full cache, used for allocating worst-case compute buffers
|
| 80 |
virtual llama_memory_state_ptr init_full() = 0;
|
|
|
|
| 73 |
virtual llama_memory_state_ptr init_batch(
|
| 74 |
const llama_batch & batch,
|
| 75 |
uint32_t n_ubatch,
|
| 76 |
+
bool embd_all) = 0;
|
|
|
|
| 77 |
|
| 78 |
// simulate full cache, used for allocating worst-case compute buffers
|
| 79 |
virtual llama_memory_state_ptr init_full() = 0;
|
examples/talk-llama/llama-model.cpp
CHANGED
|
@@ -80,6 +80,7 @@ const char * llm_type_name(llm_type type) {
|
|
| 80 |
case LLM_TYPE_40B: return "40B";
|
| 81 |
case LLM_TYPE_65B: return "65B";
|
| 82 |
case LLM_TYPE_70B: return "70B";
|
|
|
|
| 83 |
case LLM_TYPE_236B: return "236B";
|
| 84 |
case LLM_TYPE_290B: return "290B";
|
| 85 |
case LLM_TYPE_314B: return "314B";
|
|
@@ -598,6 +599,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|
| 598 |
hparams.use_kq_norm = false;
|
| 599 |
}
|
| 600 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 601 |
case LLM_ARCH_DECI:
|
| 602 |
{
|
| 603 |
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
|
@@ -738,6 +749,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|
| 738 |
}
|
| 739 |
}
|
| 740 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 741 |
case LLM_ARCH_BLOOM:
|
| 742 |
{
|
| 743 |
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
|
@@ -1444,6 +1465,20 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|
| 1444 |
default: type = LLM_TYPE_UNKNOWN;
|
| 1445 |
}
|
| 1446 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1447 |
default: throw std::runtime_error("unsupported model architecture");
|
| 1448 |
}
|
| 1449 |
|
|
@@ -2187,6 +2222,32 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|
| 2187 |
layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0);
|
| 2188 |
}
|
| 2189 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2190 |
case LLM_ARCH_JINA_BERT_V2:
|
| 2191 |
{
|
| 2192 |
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings
|
|
@@ -2224,8 +2285,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|
| 2224 |
layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
| 2225 |
layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
| 2226 |
|
| 2227 |
-
layer.
|
| 2228 |
-
layer.
|
| 2229 |
|
| 2230 |
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
| 2231 |
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
|
|
@@ -4123,6 +4184,89 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|
| 4123 |
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
|
| 4124 |
}
|
| 4125 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4126 |
default:
|
| 4127 |
throw std::runtime_error("unknown architecture");
|
| 4128 |
}
|
|
@@ -6043,7 +6187,7 @@ struct llm_build_bert : public llm_graph_context {
|
|
| 6043 |
model.layers[il].ffn_gate, NULL, NULL,
|
| 6044 |
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
| 6045 |
NULL,
|
| 6046 |
-
LLM_FFN_GELU, LLM_FFN_PAR, il);
|
| 6047 |
cb(cur, "ffn_out", il);
|
| 6048 |
} else {
|
| 6049 |
cur = build_ffn(cur,
|
|
@@ -6074,6 +6218,117 @@ struct llm_build_bert : public llm_graph_context {
|
|
| 6074 |
}
|
| 6075 |
};
|
| 6076 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6077 |
struct llm_build_bloom : public llm_graph_context {
|
| 6078 |
llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
| 6079 |
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
@@ -8857,7 +9112,6 @@ struct llm_build_mamba : public llm_graph_context {
|
|
| 8857 |
inpL = build_inp_embd(model.tok_embd);
|
| 8858 |
|
| 8859 |
ggml_tensor * state_copy = build_inp_s_copy();
|
| 8860 |
-
ggml_tensor * state_mask = build_inp_s_mask();
|
| 8861 |
|
| 8862 |
for (int il = 0; il < n_layer; ++il) {
|
| 8863 |
// norm
|
|
@@ -8866,8 +9120,7 @@ struct llm_build_mamba : public llm_graph_context {
|
|
| 8866 |
LLM_NORM_RMS, il);
|
| 8867 |
cb(cur, "attn_norm", il);
|
| 8868 |
|
| 8869 |
-
|
| 8870 |
-
cur = build_mamba_layer(gf, cur, state_copy, state_mask, ubatch, il);
|
| 8871 |
|
| 8872 |
if (il == n_layer - 1) {
|
| 8873 |
// skip computing output for unused tokens
|
|
@@ -8908,7 +9161,6 @@ struct llm_build_mamba : public llm_graph_context {
|
|
| 8908 |
ggml_cgraph * gf,
|
| 8909 |
ggml_tensor * cur,
|
| 8910 |
ggml_tensor * state_copy,
|
| 8911 |
-
ggml_tensor * state_mask,
|
| 8912 |
const llama_ubatch & ubatch,
|
| 8913 |
int il) const {
|
| 8914 |
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
|
@@ -8935,12 +9187,12 @@ struct llm_build_mamba : public llm_graph_context {
|
|
| 8935 |
ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
|
| 8936 |
|
| 8937 |
// (ab)using the KV cache to store the states
|
| 8938 |
-
ggml_tensor * conv =
|
| 8939 |
-
gf, conv_states_all, state_copy,
|
| 8940 |
hparams.n_embd_k_s(), n_seqs);
|
| 8941 |
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
|
| 8942 |
-
ggml_tensor * ssm =
|
| 8943 |
-
gf, ssm_states_all, state_copy,
|
| 8944 |
hparams.n_embd_v_s(), n_seqs);
|
| 8945 |
ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
|
| 8946 |
|
|
@@ -11656,7 +11908,6 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|
| 11656 |
ggml_tensor * cur,
|
| 11657 |
ggml_tensor * x_prev,
|
| 11658 |
ggml_tensor * state_copy,
|
| 11659 |
-
ggml_tensor * state_mask,
|
| 11660 |
const llama_ubatch & ubatch,
|
| 11661 |
int il) const {
|
| 11662 |
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
|
@@ -11780,8 +12031,8 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|
| 11780 |
k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
|
| 11781 |
}
|
| 11782 |
|
| 11783 |
-
ggml_tensor * wkv_state =
|
| 11784 |
-
gf, kv_state->get_v_l(il), state_copy,
|
| 11785 |
hparams.n_embd_v_s(), n_seqs);
|
| 11786 |
|
| 11787 |
ggml_tensor * wkv_output;
|
|
@@ -11837,7 +12088,6 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
|
| 11837 |
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
|
| 11838 |
|
| 11839 |
ggml_tensor * state_copy = build_inp_s_copy();
|
| 11840 |
-
ggml_tensor * state_mask = build_inp_s_mask();
|
| 11841 |
|
| 11842 |
const auto n_embd = hparams.n_embd;
|
| 11843 |
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
|
@@ -11848,7 +12098,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
|
| 11848 |
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
| 11849 |
|
| 11850 |
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
| 11851 |
-
gf, state_copy,
|
| 11852 |
);
|
| 11853 |
|
| 11854 |
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
|
|
@@ -11864,7 +12114,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
|
| 11864 |
1
|
| 11865 |
);
|
| 11866 |
|
| 11867 |
-
cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy,
|
| 11868 |
|
| 11869 |
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
| 11870 |
cb(ffn_inp, "ffn_inp", il);
|
|
@@ -11935,7 +12185,6 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
|
| 11935 |
inpL = build_inp_embd(model.tok_embd);
|
| 11936 |
|
| 11937 |
ggml_tensor * state_copy = build_inp_s_copy();
|
| 11938 |
-
ggml_tensor * state_mask = build_inp_s_mask();
|
| 11939 |
|
| 11940 |
const auto n_embd = hparams.n_embd;
|
| 11941 |
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
|
@@ -11946,7 +12195,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
|
| 11946 |
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
| 11947 |
|
| 11948 |
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
| 11949 |
-
gf, state_copy,
|
| 11950 |
);
|
| 11951 |
|
| 11952 |
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
|
|
@@ -11959,7 +12208,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
|
| 11959 |
1
|
| 11960 |
);
|
| 11961 |
|
| 11962 |
-
cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy,
|
| 11963 |
|
| 11964 |
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
|
| 11965 |
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
|
@@ -12051,7 +12300,6 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|
| 12051 |
ggml_tensor * cur,
|
| 12052 |
ggml_tensor * x_prev,
|
| 12053 |
ggml_tensor * state_copy,
|
| 12054 |
-
ggml_tensor * state_mask,
|
| 12055 |
ggml_tensor *& first_layer_value,
|
| 12056 |
const llama_ubatch & ubatch,
|
| 12057 |
int il) const {
|
|
@@ -12134,8 +12382,8 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|
| 12134 |
v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
|
| 12135 |
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
|
| 12136 |
|
| 12137 |
-
ggml_tensor * wkv_state =
|
| 12138 |
-
gf, kv_state->get_v_l(il), state_copy,
|
| 12139 |
hparams.n_embd_v_s(), n_seqs);
|
| 12140 |
|
| 12141 |
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
|
|
@@ -12193,7 +12441,6 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
|
|
| 12193 |
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
|
| 12194 |
|
| 12195 |
ggml_tensor * state_copy = build_inp_s_copy();
|
| 12196 |
-
ggml_tensor * state_mask = build_inp_s_mask();
|
| 12197 |
|
| 12198 |
const auto n_embd = hparams.n_embd;
|
| 12199 |
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
|
@@ -12204,7 +12451,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
|
|
| 12204 |
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
| 12205 |
|
| 12206 |
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
| 12207 |
-
gf, state_copy,
|
| 12208 |
);
|
| 12209 |
|
| 12210 |
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
|
|
@@ -12220,7 +12467,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
|
|
| 12220 |
1
|
| 12221 |
);
|
| 12222 |
|
| 12223 |
-
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy,
|
| 12224 |
|
| 12225 |
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
| 12226 |
cb(ffn_inp, "ffn_inp", il);
|
|
@@ -12287,7 +12534,6 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
|
|
| 12287 |
inpL = build_inp_embd(model.tok_embd);
|
| 12288 |
|
| 12289 |
ggml_tensor * state_copy = build_inp_s_copy();
|
| 12290 |
-
ggml_tensor * state_mask = build_inp_s_mask();
|
| 12291 |
|
| 12292 |
const auto n_embd = hparams.n_embd;
|
| 12293 |
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
|
@@ -12298,7 +12544,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
|
|
| 12298 |
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
| 12299 |
|
| 12300 |
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
| 12301 |
-
gf, state_copy,
|
| 12302 |
);
|
| 12303 |
|
| 12304 |
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
|
|
@@ -12311,7 +12557,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
|
|
| 12311 |
1
|
| 12312 |
);
|
| 12313 |
|
| 12314 |
-
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy,
|
| 12315 |
|
| 12316 |
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
|
| 12317 |
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
|
@@ -13203,6 +13449,291 @@ struct llm_build_bailingmoe : public llm_graph_context {
|
|
| 13203 |
}
|
| 13204 |
};
|
| 13205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13206 |
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
|
| 13207 |
llama_memory_i * res;
|
| 13208 |
|
|
@@ -13211,6 +13742,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|
| 13211 |
case LLM_ARCH_JINA_BERT_V2:
|
| 13212 |
case LLM_ARCH_NOMIC_BERT:
|
| 13213 |
case LLM_ARCH_NOMIC_BERT_MOE:
|
|
|
|
| 13214 |
case LLM_ARCH_WAVTOKENIZER_DEC:
|
| 13215 |
{
|
| 13216 |
res = nullptr;
|
|
@@ -13319,6 +13851,10 @@ llm_graph_result_ptr llama_model::build_graph(
|
|
| 13319 |
{
|
| 13320 |
llm = std::make_unique<llm_build_bert>(*this, params, gf);
|
| 13321 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13322 |
case LLM_ARCH_BLOOM:
|
| 13323 |
{
|
| 13324 |
llm = std::make_unique<llm_build_bloom>(*this, params, gf);
|
|
@@ -13541,6 +14077,14 @@ llm_graph_result_ptr llama_model::build_graph(
|
|
| 13541 |
{
|
| 13542 |
llm = std::make_unique<llm_build_bailingmoe>(*this, params, gf);
|
| 13543 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13544 |
default:
|
| 13545 |
GGML_ABORT("fatal error");
|
| 13546 |
}
|
|
@@ -13690,6 +14234,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|
| 13690 |
case LLM_ARCH_GRANITE_MOE:
|
| 13691 |
case LLM_ARCH_CHAMELEON:
|
| 13692 |
case LLM_ARCH_BAILINGMOE:
|
|
|
|
|
|
|
| 13693 |
return LLAMA_ROPE_TYPE_NORM;
|
| 13694 |
|
| 13695 |
// the pairs of head values are offset by n_rot/2
|
|
@@ -13723,6 +14269,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|
| 13723 |
case LLM_ARCH_NEMOTRON:
|
| 13724 |
case LLM_ARCH_EXAONE:
|
| 13725 |
case LLM_ARCH_MINICPM3:
|
|
|
|
| 13726 |
return LLAMA_ROPE_TYPE_NEOX;
|
| 13727 |
|
| 13728 |
case LLM_ARCH_QWEN2VL:
|
|
|
|
| 80 |
case LLM_TYPE_40B: return "40B";
|
| 81 |
case LLM_TYPE_65B: return "65B";
|
| 82 |
case LLM_TYPE_70B: return "70B";
|
| 83 |
+
case LLM_TYPE_142B: return "142B";
|
| 84 |
case LLM_TYPE_236B: return "236B";
|
| 85 |
case LLM_TYPE_290B: return "290B";
|
| 86 |
case LLM_TYPE_314B: return "314B";
|
|
|
|
| 599 |
hparams.use_kq_norm = false;
|
| 600 |
}
|
| 601 |
} break;
|
| 602 |
+
case LLM_ARCH_ARCEE:
|
| 603 |
+
{
|
| 604 |
+
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
| 605 |
+
|
| 606 |
+
// Arcee uses the same structure as Llama
|
| 607 |
+
switch (hparams.n_layer) {
|
| 608 |
+
case 36: type = LLM_TYPE_4B; break;
|
| 609 |
+
default: type = LLM_TYPE_UNKNOWN;
|
| 610 |
+
}
|
| 611 |
+
} break;
|
| 612 |
case LLM_ARCH_DECI:
|
| 613 |
{
|
| 614 |
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
|
|
|
| 749 |
}
|
| 750 |
}
|
| 751 |
} break;
|
| 752 |
+
case LLM_ARCH_NEO_BERT:
|
| 753 |
+
{
|
| 754 |
+
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
| 755 |
+
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
| 756 |
+
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
|
| 757 |
+
|
| 758 |
+
if (hparams.n_layer == 28) {
|
| 759 |
+
type = LLM_TYPE_250M;
|
| 760 |
+
}
|
| 761 |
+
} break;
|
| 762 |
case LLM_ARCH_BLOOM:
|
| 763 |
{
|
| 764 |
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
|
|
|
| 1465 |
default: type = LLM_TYPE_UNKNOWN;
|
| 1466 |
}
|
| 1467 |
} break;
|
| 1468 |
+
case LLM_ARCH_DOTS1:
|
| 1469 |
+
{
|
| 1470 |
+
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
| 1471 |
+
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
|
| 1472 |
+
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
| 1473 |
+
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
|
| 1474 |
+
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
|
| 1475 |
+
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
|
| 1476 |
+
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
|
| 1477 |
+
switch (hparams.n_layer) {
|
| 1478 |
+
case 62: type = LLM_TYPE_142B; break;
|
| 1479 |
+
default: type = LLM_TYPE_UNKNOWN;
|
| 1480 |
+
}
|
| 1481 |
+
} break;
|
| 1482 |
default: throw std::runtime_error("unsupported model architecture");
|
| 1483 |
}
|
| 1484 |
|
|
|
|
| 2222 |
layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0);
|
| 2223 |
}
|
| 2224 |
} break;
|
| 2225 |
+
case LLM_ARCH_NEO_BERT:
|
| 2226 |
+
{
|
| 2227 |
+
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
| 2228 |
+
|
| 2229 |
+
cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
|
| 2230 |
+
cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED);
|
| 2231 |
+
|
| 2232 |
+
cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
|
| 2233 |
+
cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
|
| 2234 |
+
|
| 2235 |
+
output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
| 2236 |
+
|
| 2237 |
+
for (int i = 0; i < n_layer; ++i) {
|
| 2238 |
+
auto & layer = layers[i];
|
| 2239 |
+
|
| 2240 |
+
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
| 2241 |
+
|
| 2242 |
+
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
|
| 2243 |
+
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
| 2244 |
+
|
| 2245 |
+
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
| 2246 |
+
|
| 2247 |
+
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff*2}, 0);
|
| 2248 |
+
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
| 2249 |
+
}
|
| 2250 |
+
} break;
|
| 2251 |
case LLM_ARCH_JINA_BERT_V2:
|
| 2252 |
{
|
| 2253 |
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings
|
|
|
|
| 2285 |
layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
| 2286 |
layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
| 2287 |
|
| 2288 |
+
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
|
| 2289 |
+
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, layer.ffn_gate ? n_ff : n_ff * 2}, 0);
|
| 2290 |
|
| 2291 |
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
| 2292 |
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
|
|
|
|
| 4184 |
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
|
| 4185 |
}
|
| 4186 |
} break;
|
| 4187 |
+
case LLM_ARCH_DOTS1:
|
| 4188 |
+
{
|
| 4189 |
+
const int64_t n_ff_exp = hparams.n_ff_exp;
|
| 4190 |
+
const int64_t n_expert_shared = hparams.n_expert_shared;
|
| 4191 |
+
|
| 4192 |
+
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
| 4193 |
+
|
| 4194 |
+
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
| 4195 |
+
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
| 4196 |
+
|
| 4197 |
+
for (int i = 0; i < n_layer; ++i) {
|
| 4198 |
+
auto & layer = layers[i];
|
| 4199 |
+
|
| 4200 |
+
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
| 4201 |
+
|
| 4202 |
+
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
| 4203 |
+
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
| 4204 |
+
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
| 4205 |
+
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
| 4206 |
+
|
| 4207 |
+
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
|
| 4208 |
+
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
|
| 4209 |
+
|
| 4210 |
+
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
| 4211 |
+
|
| 4212 |
+
if (i < (int) hparams.n_layer_dense_lead) {
|
| 4213 |
+
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
| 4214 |
+
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
| 4215 |
+
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
| 4216 |
+
} else {
|
| 4217 |
+
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
| 4218 |
+
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
|
| 4219 |
+
|
| 4220 |
+
if (n_expert == 0) {
|
| 4221 |
+
throw std::runtime_error("n_expert must be > 0");
|
| 4222 |
+
}
|
| 4223 |
+
if (n_expert_used == 0) {
|
| 4224 |
+
throw std::runtime_error("n_expert_used must be > 0");
|
| 4225 |
+
}
|
| 4226 |
+
|
| 4227 |
+
// MoE branch
|
| 4228 |
+
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
| 4229 |
+
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
|
| 4230 |
+
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
| 4231 |
+
|
| 4232 |
+
// Shared expert branch
|
| 4233 |
+
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
|
| 4234 |
+
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0);
|
| 4235 |
+
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
|
| 4236 |
+
}
|
| 4237 |
+
}
|
| 4238 |
+
} break;
|
| 4239 |
+
case LLM_ARCH_ARCEE:
|
| 4240 |
+
{
|
| 4241 |
+
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
| 4242 |
+
|
| 4243 |
+
// output
|
| 4244 |
+
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
| 4245 |
+
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
| 4246 |
+
|
| 4247 |
+
// if output is NULL, init from the input tok embed
|
| 4248 |
+
if (output == NULL) {
|
| 4249 |
+
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
| 4250 |
+
}
|
| 4251 |
+
|
| 4252 |
+
for (int i = 0; i < n_layer; ++i) {
|
| 4253 |
+
auto & layer = layers[i];
|
| 4254 |
+
|
| 4255 |
+
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
| 4256 |
+
|
| 4257 |
+
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
| 4258 |
+
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
| 4259 |
+
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
| 4260 |
+
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
| 4261 |
+
|
| 4262 |
+
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
| 4263 |
+
|
| 4264 |
+
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
| 4265 |
+
|
| 4266 |
+
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
| 4267 |
+
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
| 4268 |
+
}
|
| 4269 |
+
} break;
|
| 4270 |
default:
|
| 4271 |
throw std::runtime_error("unknown architecture");
|
| 4272 |
}
|
|
|
|
| 6187 |
model.layers[il].ffn_gate, NULL, NULL,
|
| 6188 |
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
| 6189 |
NULL,
|
| 6190 |
+
model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_GEGLU, LLM_FFN_PAR, il);
|
| 6191 |
cb(cur, "ffn_out", il);
|
| 6192 |
} else {
|
| 6193 |
cur = build_ffn(cur,
|
|
|
|
| 6218 |
}
|
| 6219 |
};
|
| 6220 |
|
| 6221 |
+
struct llm_build_neo_bert : public llm_graph_context {
|
| 6222 |
+
llm_build_neo_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
| 6223 |
+
const int64_t n_embd_head = hparams.n_embd_head_v;
|
| 6224 |
+
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
| 6225 |
+
|
| 6226 |
+
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
| 6227 |
+
|
| 6228 |
+
ggml_tensor * cur;
|
| 6229 |
+
ggml_tensor * inpL;
|
| 6230 |
+
ggml_tensor * inp_pos = build_inp_pos();
|
| 6231 |
+
|
| 6232 |
+
// construct input embeddings (token, type, position)
|
| 6233 |
+
inpL = build_inp_embd(model.tok_embd);
|
| 6234 |
+
cb(inpL, "inp_embd", -1);
|
| 6235 |
+
|
| 6236 |
+
auto * inp_attn = build_attn_inp_no_cache();
|
| 6237 |
+
|
| 6238 |
+
// iterate layers
|
| 6239 |
+
for (int il = 0; il < n_layer; ++il) {
|
| 6240 |
+
ggml_tensor * cur = inpL;
|
| 6241 |
+
|
| 6242 |
+
ggml_tensor * Qcur;
|
| 6243 |
+
ggml_tensor * Kcur;
|
| 6244 |
+
ggml_tensor * Vcur;
|
| 6245 |
+
|
| 6246 |
+
// pre-norm
|
| 6247 |
+
cur = build_norm(inpL,
|
| 6248 |
+
model.layers[il].attn_norm, NULL,
|
| 6249 |
+
LLM_NORM_RMS, il);
|
| 6250 |
+
|
| 6251 |
+
// self-attention
|
| 6252 |
+
cur = build_lora_mm(model.layers[il].wqkv, cur);
|
| 6253 |
+
cb(cur, "wqkv", il);
|
| 6254 |
+
|
| 6255 |
+
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
| 6256 |
+
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
| 6257 |
+
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
| 6258 |
+
|
| 6259 |
+
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
| 6260 |
+
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
| 6261 |
+
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
| 6262 |
+
|
| 6263 |
+
// RoPE
|
| 6264 |
+
Qcur = ggml_rope_ext(
|
| 6265 |
+
ctx0, Qcur, inp_pos, nullptr,
|
| 6266 |
+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
| 6267 |
+
ext_factor, attn_factor, beta_fast, beta_slow
|
| 6268 |
+
);
|
| 6269 |
+
|
| 6270 |
+
Kcur = ggml_rope_ext(
|
| 6271 |
+
ctx0, Kcur, inp_pos, nullptr,
|
| 6272 |
+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
| 6273 |
+
ext_factor, attn_factor, beta_fast, beta_slow
|
| 6274 |
+
);
|
| 6275 |
+
|
| 6276 |
+
cb(Qcur, "Qcur", il);
|
| 6277 |
+
cb(Kcur, "Kcur", il);
|
| 6278 |
+
cb(Vcur, "Vcur", il);
|
| 6279 |
+
|
| 6280 |
+
cur = build_attn(inp_attn, gf,
|
| 6281 |
+
model.layers[il].wo, nullptr,
|
| 6282 |
+
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
| 6283 |
+
cb(cur, "kqv_out", il);
|
| 6284 |
+
|
| 6285 |
+
if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
| 6286 |
+
// skip computing output for unused tokens
|
| 6287 |
+
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
| 6288 |
+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
| 6289 |
+
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
| 6290 |
+
}
|
| 6291 |
+
|
| 6292 |
+
// re-add the layer input
|
| 6293 |
+
cur = ggml_add(ctx0, cur, inpL);
|
| 6294 |
+
|
| 6295 |
+
ggml_tensor * ffn_inp = cur;
|
| 6296 |
+
cb(ffn_inp, "ffn_inp", il);
|
| 6297 |
+
|
| 6298 |
+
// pre-norm
|
| 6299 |
+
cur = build_norm(ffn_inp,
|
| 6300 |
+
model.layers[il].ffn_norm, NULL,
|
| 6301 |
+
LLM_NORM_RMS, il);
|
| 6302 |
+
cb(cur, "ffn_norm", il);
|
| 6303 |
+
|
| 6304 |
+
// feed-forward network
|
| 6305 |
+
cur = build_ffn(cur,
|
| 6306 |
+
model.layers[il].ffn_up,
|
| 6307 |
+
NULL, NULL, NULL, NULL, NULL,
|
| 6308 |
+
model.layers[il].ffn_down,
|
| 6309 |
+
NULL, NULL, NULL,
|
| 6310 |
+
LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
|
| 6311 |
+
|
| 6312 |
+
// attentions bypass the intermediate layer
|
| 6313 |
+
cur = ggml_add(ctx0, cur, ffn_inp);
|
| 6314 |
+
|
| 6315 |
+
// input for next layer
|
| 6316 |
+
inpL = cur;
|
| 6317 |
+
}
|
| 6318 |
+
|
| 6319 |
+
cur = inpL;
|
| 6320 |
+
|
| 6321 |
+
cur = build_norm(cur,
|
| 6322 |
+
model.output_norm_enc, NULL,
|
| 6323 |
+
LLM_NORM_RMS, -1);
|
| 6324 |
+
|
| 6325 |
+
cb(cur, "result_embd", -1);
|
| 6326 |
+
res->t_embd = cur;
|
| 6327 |
+
|
| 6328 |
+
ggml_build_forward_expand(gf, cur);
|
| 6329 |
+
}
|
| 6330 |
+
};
|
| 6331 |
+
|
| 6332 |
struct llm_build_bloom : public llm_graph_context {
|
| 6333 |
llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
| 6334 |
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
|
|
| 9112 |
inpL = build_inp_embd(model.tok_embd);
|
| 9113 |
|
| 9114 |
ggml_tensor * state_copy = build_inp_s_copy();
|
|
|
|
| 9115 |
|
| 9116 |
for (int il = 0; il < n_layer; ++il) {
|
| 9117 |
// norm
|
|
|
|
| 9120 |
LLM_NORM_RMS, il);
|
| 9121 |
cb(cur, "attn_norm", il);
|
| 9122 |
|
| 9123 |
+
cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
|
|
|
|
| 9124 |
|
| 9125 |
if (il == n_layer - 1) {
|
| 9126 |
// skip computing output for unused tokens
|
|
|
|
| 9161 |
ggml_cgraph * gf,
|
| 9162 |
ggml_tensor * cur,
|
| 9163 |
ggml_tensor * state_copy,
|
|
|
|
| 9164 |
const llama_ubatch & ubatch,
|
| 9165 |
int il) const {
|
| 9166 |
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
|
|
|
| 9187 |
ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
|
| 9188 |
|
| 9189 |
// (ab)using the KV cache to store the states
|
| 9190 |
+
ggml_tensor * conv = build_recurrent_state(
|
| 9191 |
+
gf, conv_states_all, state_copy,
|
| 9192 |
hparams.n_embd_k_s(), n_seqs);
|
| 9193 |
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
|
| 9194 |
+
ggml_tensor * ssm = build_recurrent_state(
|
| 9195 |
+
gf, ssm_states_all, state_copy,
|
| 9196 |
hparams.n_embd_v_s(), n_seqs);
|
| 9197 |
ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
|
| 9198 |
|
|
|
|
| 11908 |
ggml_tensor * cur,
|
| 11909 |
ggml_tensor * x_prev,
|
| 11910 |
ggml_tensor * state_copy,
|
|
|
|
| 11911 |
const llama_ubatch & ubatch,
|
| 11912 |
int il) const {
|
| 11913 |
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
|
|
|
| 12031 |
k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
|
| 12032 |
}
|
| 12033 |
|
| 12034 |
+
ggml_tensor * wkv_state = build_recurrent_state(
|
| 12035 |
+
gf, kv_state->get_v_l(il), state_copy,
|
| 12036 |
hparams.n_embd_v_s(), n_seqs);
|
| 12037 |
|
| 12038 |
ggml_tensor * wkv_output;
|
|
|
|
| 12088 |
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
|
| 12089 |
|
| 12090 |
ggml_tensor * state_copy = build_inp_s_copy();
|
|
|
|
| 12091 |
|
| 12092 |
const auto n_embd = hparams.n_embd;
|
| 12093 |
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
|
|
|
| 12098 |
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
| 12099 |
|
| 12100 |
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
| 12101 |
+
gf, state_copy, ubatch, il
|
| 12102 |
);
|
| 12103 |
|
| 12104 |
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
|
|
|
|
| 12114 |
1
|
| 12115 |
);
|
| 12116 |
|
| 12117 |
+
cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
|
| 12118 |
|
| 12119 |
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
| 12120 |
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
| 12185 |
inpL = build_inp_embd(model.tok_embd);
|
| 12186 |
|
| 12187 |
ggml_tensor * state_copy = build_inp_s_copy();
|
|
|
|
| 12188 |
|
| 12189 |
const auto n_embd = hparams.n_embd;
|
| 12190 |
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
|
|
|
| 12195 |
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
| 12196 |
|
| 12197 |
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
| 12198 |
+
gf, state_copy, ubatch, il
|
| 12199 |
);
|
| 12200 |
|
| 12201 |
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
|
|
|
|
| 12208 |
1
|
| 12209 |
);
|
| 12210 |
|
| 12211 |
+
cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
|
| 12212 |
|
| 12213 |
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
|
| 12214 |
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
|
|
|
| 12300 |
ggml_tensor * cur,
|
| 12301 |
ggml_tensor * x_prev,
|
| 12302 |
ggml_tensor * state_copy,
|
|
|
|
| 12303 |
ggml_tensor *& first_layer_value,
|
| 12304 |
const llama_ubatch & ubatch,
|
| 12305 |
int il) const {
|
|
|
|
| 12382 |
v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
|
| 12383 |
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
|
| 12384 |
|
| 12385 |
+
ggml_tensor * wkv_state = build_recurrent_state(
|
| 12386 |
+
gf, kv_state->get_v_l(il), state_copy,
|
| 12387 |
hparams.n_embd_v_s(), n_seqs);
|
| 12388 |
|
| 12389 |
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
|
|
|
|
| 12441 |
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
|
| 12442 |
|
| 12443 |
ggml_tensor * state_copy = build_inp_s_copy();
|
|
|
|
| 12444 |
|
| 12445 |
const auto n_embd = hparams.n_embd;
|
| 12446 |
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
|
|
|
| 12451 |
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
| 12452 |
|
| 12453 |
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
| 12454 |
+
gf, state_copy, ubatch, il
|
| 12455 |
);
|
| 12456 |
|
| 12457 |
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
|
|
|
|
| 12467 |
1
|
| 12468 |
);
|
| 12469 |
|
| 12470 |
+
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
|
| 12471 |
|
| 12472 |
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
| 12473 |
cb(ffn_inp, "ffn_inp", il);
|
|
|
|
| 12534 |
inpL = build_inp_embd(model.tok_embd);
|
| 12535 |
|
| 12536 |
ggml_tensor * state_copy = build_inp_s_copy();
|
|
|
|
| 12537 |
|
| 12538 |
const auto n_embd = hparams.n_embd;
|
| 12539 |
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
|
|
|
| 12544 |
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
| 12545 |
|
| 12546 |
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
| 12547 |
+
gf, state_copy, ubatch, il
|
| 12548 |
);
|
| 12549 |
|
| 12550 |
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
|
|
|
|
| 12557 |
1
|
| 12558 |
);
|
| 12559 |
|
| 12560 |
+
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
|
| 12561 |
|
| 12562 |
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
|
| 12563 |
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
|
|
|
| 13449 |
}
|
| 13450 |
};
|
| 13451 |
|
| 13452 |
+
struct llm_build_dots1 : public llm_graph_context {
|
| 13453 |
+
llm_build_dots1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
| 13454 |
+
const int64_t n_embd_head = hparams.n_embd_head_v;
|
| 13455 |
+
|
| 13456 |
+
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
| 13457 |
+
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
| 13458 |
+
|
| 13459 |
+
ggml_tensor * cur;
|
| 13460 |
+
ggml_tensor * inpL;
|
| 13461 |
+
|
| 13462 |
+
inpL = build_inp_embd(model.tok_embd);
|
| 13463 |
+
|
| 13464 |
+
// inp_pos - contains the positions
|
| 13465 |
+
ggml_tensor * inp_pos = build_inp_pos();
|
| 13466 |
+
|
| 13467 |
+
auto * inp_attn = build_attn_inp_kv_unified();
|
| 13468 |
+
|
| 13469 |
+
for (int il = 0; il < n_layer; ++il) {
|
| 13470 |
+
ggml_tensor * inpSA = inpL;
|
| 13471 |
+
|
| 13472 |
+
// norm
|
| 13473 |
+
cur = build_norm(inpL,
|
| 13474 |
+
model.layers[il].attn_norm, NULL,
|
| 13475 |
+
LLM_NORM_RMS, il);
|
| 13476 |
+
cb(cur, "attn_norm", il);
|
| 13477 |
+
|
| 13478 |
+
// self_attention
|
| 13479 |
+
{
|
| 13480 |
+
// compute Q and K and RoPE them
|
| 13481 |
+
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
| 13482 |
+
cb(Qcur, "Qcur", il);
|
| 13483 |
+
|
| 13484 |
+
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
| 13485 |
+
cb(Kcur, "Kcur", il);
|
| 13486 |
+
|
| 13487 |
+
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
| 13488 |
+
cb(Vcur, "Vcur", il);
|
| 13489 |
+
|
| 13490 |
+
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
| 13491 |
+
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
| 13492 |
+
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
| 13493 |
+
|
| 13494 |
+
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
|
| 13495 |
+
cb(Qcur, "Qcur_normed", il);
|
| 13496 |
+
|
| 13497 |
+
Qcur = ggml_rope_ext(
|
| 13498 |
+
ctx0, Qcur, inp_pos, nullptr,
|
| 13499 |
+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
| 13500 |
+
ext_factor, attn_factor, beta_fast, beta_slow
|
| 13501 |
+
);
|
| 13502 |
+
|
| 13503 |
+
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
|
| 13504 |
+
cb(Kcur, "Kcur_normed", il);
|
| 13505 |
+
|
| 13506 |
+
Kcur = ggml_rope_ext(
|
| 13507 |
+
ctx0, Kcur, inp_pos, nullptr,
|
| 13508 |
+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
| 13509 |
+
ext_factor, attn_factor, beta_fast, beta_slow
|
| 13510 |
+
);
|
| 13511 |
+
|
| 13512 |
+
cb(Qcur, "Qcur", il);
|
| 13513 |
+
cb(Kcur, "Kcur", il);
|
| 13514 |
+
cb(Vcur, "Vcur", il);
|
| 13515 |
+
|
| 13516 |
+
cur = build_attn(inp_attn, gf,
|
| 13517 |
+
model.layers[il].wo, model.layers[il].bo,
|
| 13518 |
+
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
| 13519 |
+
}
|
| 13520 |
+
|
| 13521 |
+
if (il == n_layer - 1) {
|
| 13522 |
+
// skip computing output for unused tokens
|
| 13523 |
+
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
| 13524 |
+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
| 13525 |
+
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
| 13526 |
+
}
|
| 13527 |
+
|
| 13528 |
+
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
| 13529 |
+
cb(ffn_inp, "ffn_inp", il);
|
| 13530 |
+
|
| 13531 |
+
// MoE branch
|
| 13532 |
+
cur = build_norm(ffn_inp,
|
| 13533 |
+
model.layers[il].ffn_norm, NULL,
|
| 13534 |
+
LLM_NORM_RMS, il);
|
| 13535 |
+
cb(cur, "ffn_norm", il);
|
| 13536 |
+
|
| 13537 |
+
if ((uint32_t) il < hparams.n_layer_dense_lead) {
|
| 13538 |
+
cur = build_ffn(cur,
|
| 13539 |
+
model.layers[il].ffn_up, NULL, NULL,
|
| 13540 |
+
model.layers[il].ffn_gate, NULL, NULL,
|
| 13541 |
+
model.layers[il].ffn_down, NULL, NULL,
|
| 13542 |
+
NULL,
|
| 13543 |
+
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
| 13544 |
+
cb(cur, "ffn_out", il);
|
| 13545 |
+
} else {
|
| 13546 |
+
ggml_tensor * moe_out =
|
| 13547 |
+
build_moe_ffn(cur,
|
| 13548 |
+
model.layers[il].ffn_gate_inp,
|
| 13549 |
+
model.layers[il].ffn_up_exps,
|
| 13550 |
+
model.layers[il].ffn_gate_exps,
|
| 13551 |
+
model.layers[il].ffn_down_exps,
|
| 13552 |
+
model.layers[il].ffn_exp_probs_b,
|
| 13553 |
+
n_expert, n_expert_used,
|
| 13554 |
+
LLM_FFN_SILU, hparams.expert_weights_norm,
|
| 13555 |
+
true, hparams.expert_weights_scale,
|
| 13556 |
+
(llama_expert_gating_func_type) hparams.expert_gating_func,
|
| 13557 |
+
il);
|
| 13558 |
+
cb(moe_out, "ffn_moe_out", il);
|
| 13559 |
+
|
| 13560 |
+
{
|
| 13561 |
+
ggml_tensor * ffn_shexp = build_ffn(cur,
|
| 13562 |
+
model.layers[il].ffn_up_shexp, NULL, NULL,
|
| 13563 |
+
model.layers[il].ffn_gate_shexp, NULL, NULL,
|
| 13564 |
+
model.layers[il].ffn_down_shexp, NULL, NULL,
|
| 13565 |
+
NULL,
|
| 13566 |
+
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
| 13567 |
+
cb(ffn_shexp, "ffn_shexp", il);
|
| 13568 |
+
|
| 13569 |
+
cur = ggml_add(ctx0, moe_out, ffn_shexp);
|
| 13570 |
+
cb(cur, "ffn_out", il);
|
| 13571 |
+
}
|
| 13572 |
+
}
|
| 13573 |
+
|
| 13574 |
+
cur = ggml_add(ctx0, cur, ffn_inp);
|
| 13575 |
+
|
| 13576 |
+
cur = build_cvec(cur, il);
|
| 13577 |
+
cb(cur, "l_out", il);
|
| 13578 |
+
|
| 13579 |
+
// input for next layer
|
| 13580 |
+
inpL = cur;
|
| 13581 |
+
}
|
| 13582 |
+
|
| 13583 |
+
cur = inpL;
|
| 13584 |
+
|
| 13585 |
+
cur = build_norm(cur,
|
| 13586 |
+
model.output_norm, NULL,
|
| 13587 |
+
LLM_NORM_RMS, -1);
|
| 13588 |
+
|
| 13589 |
+
cb(cur, "result_norm", -1);
|
| 13590 |
+
res->t_embd = cur;
|
| 13591 |
+
|
| 13592 |
+
// lm_head
|
| 13593 |
+
cur = build_lora_mm(model.output, cur);
|
| 13594 |
+
|
| 13595 |
+
cb(cur, "result_output", -1);
|
| 13596 |
+
res->t_logits = cur;
|
| 13597 |
+
|
| 13598 |
+
ggml_build_forward_expand(gf, cur);
|
| 13599 |
+
}
|
| 13600 |
+
};
|
| 13601 |
+
|
| 13602 |
+
struct llm_build_arcee : public llm_graph_context {
|
| 13603 |
+
llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
| 13604 |
+
const int64_t n_embd_head = hparams.n_embd_head_v;
|
| 13605 |
+
|
| 13606 |
+
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
| 13607 |
+
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
| 13608 |
+
|
| 13609 |
+
ggml_tensor * cur;
|
| 13610 |
+
ggml_tensor * inpL;
|
| 13611 |
+
|
| 13612 |
+
inpL = build_inp_embd(model.tok_embd);
|
| 13613 |
+
|
| 13614 |
+
// inp_pos - contains the positions
|
| 13615 |
+
ggml_tensor * inp_pos = build_inp_pos();
|
| 13616 |
+
|
| 13617 |
+
auto * inp_attn = build_attn_inp_kv_unified();
|
| 13618 |
+
|
| 13619 |
+
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
| 13620 |
+
|
| 13621 |
+
for (int il = 0; il < n_layer; ++il) {
|
| 13622 |
+
ggml_tensor * inpSA = inpL;
|
| 13623 |
+
|
| 13624 |
+
// norm
|
| 13625 |
+
cur = build_norm(inpL,
|
| 13626 |
+
model.layers[il].attn_norm, NULL,
|
| 13627 |
+
LLM_NORM_RMS, il);
|
| 13628 |
+
cb(cur, "attn_norm", il);
|
| 13629 |
+
|
| 13630 |
+
// self-attention
|
| 13631 |
+
{
|
| 13632 |
+
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
| 13633 |
+
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
|
| 13634 |
+
|
| 13635 |
+
// compute Q and K and RoPE them
|
| 13636 |
+
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
| 13637 |
+
cb(Qcur, "Qcur", il);
|
| 13638 |
+
if (model.layers[il].bq) {
|
| 13639 |
+
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
| 13640 |
+
cb(Qcur, "Qcur", il);
|
| 13641 |
+
}
|
| 13642 |
+
|
| 13643 |
+
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
| 13644 |
+
cb(Kcur, "Kcur", il);
|
| 13645 |
+
if (model.layers[il].bk) {
|
| 13646 |
+
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
| 13647 |
+
cb(Kcur, "Kcur", il);
|
| 13648 |
+
}
|
| 13649 |
+
|
| 13650 |
+
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
| 13651 |
+
cb(Vcur, "Vcur", il);
|
| 13652 |
+
if (model.layers[il].bv) {
|
| 13653 |
+
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
| 13654 |
+
cb(Vcur, "Vcur", il);
|
| 13655 |
+
}
|
| 13656 |
+
|
| 13657 |
+
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
| 13658 |
+
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
| 13659 |
+
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
| 13660 |
+
|
| 13661 |
+
Qcur = ggml_rope_ext(
|
| 13662 |
+
ctx0, Qcur, inp_pos, rope_factors,
|
| 13663 |
+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
| 13664 |
+
ext_factor, attn_factor, beta_fast, beta_slow
|
| 13665 |
+
);
|
| 13666 |
+
|
| 13667 |
+
Kcur = ggml_rope_ext(
|
| 13668 |
+
ctx0, Kcur, inp_pos, rope_factors,
|
| 13669 |
+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
| 13670 |
+
ext_factor, attn_factor, beta_fast, beta_slow
|
| 13671 |
+
);
|
| 13672 |
+
|
| 13673 |
+
cb(Qcur, "Qcur", il);
|
| 13674 |
+
cb(Kcur, "Kcur", il);
|
| 13675 |
+
cb(Vcur, "Vcur", il);
|
| 13676 |
+
|
| 13677 |
+
cur = build_attn(inp_attn, gf,
|
| 13678 |
+
model.layers[il].wo, model.layers[il].bo,
|
| 13679 |
+
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
| 13680 |
+
cb(cur, "attn_out", il);
|
| 13681 |
+
}
|
| 13682 |
+
|
| 13683 |
+
if (il == n_layer - 1) {
|
| 13684 |
+
// skip computing output for unused tokens
|
| 13685 |
+
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
| 13686 |
+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
| 13687 |
+
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
| 13688 |
+
}
|
| 13689 |
+
|
| 13690 |
+
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
| 13691 |
+
cb(ffn_inp, "ffn_inp", il);
|
| 13692 |
+
|
| 13693 |
+
// feed-forward network
|
| 13694 |
+
// ARCEE uses relu^2 instead of silu
|
| 13695 |
+
cur = build_norm(ffn_inp,
|
| 13696 |
+
model.layers[il].ffn_norm, NULL,
|
| 13697 |
+
LLM_NORM_RMS, il);
|
| 13698 |
+
cb(cur, "ffn_norm", il);
|
| 13699 |
+
|
| 13700 |
+
cur = build_ffn(cur,
|
| 13701 |
+
model.layers[il].ffn_up, NULL, NULL,
|
| 13702 |
+
NULL, NULL, NULL,
|
| 13703 |
+
model.layers[il].ffn_down, NULL, NULL,
|
| 13704 |
+
NULL,
|
| 13705 |
+
LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
|
| 13706 |
+
cb(cur, "ffn_out", il);
|
| 13707 |
+
|
| 13708 |
+
cur = ggml_add(ctx0, cur, ffn_inp);
|
| 13709 |
+
cb(cur, "ffn_out", il);
|
| 13710 |
+
|
| 13711 |
+
cur = build_cvec(cur, il);
|
| 13712 |
+
cb(cur, "l_out", il);
|
| 13713 |
+
|
| 13714 |
+
// input for next layer
|
| 13715 |
+
inpL = cur;
|
| 13716 |
+
}
|
| 13717 |
+
|
| 13718 |
+
cur = inpL;
|
| 13719 |
+
|
| 13720 |
+
cur = build_norm(cur,
|
| 13721 |
+
model.output_norm, NULL,
|
| 13722 |
+
LLM_NORM_RMS, -1);
|
| 13723 |
+
|
| 13724 |
+
cb(cur, "result_norm", -1);
|
| 13725 |
+
res->t_embd = cur;
|
| 13726 |
+
|
| 13727 |
+
// lm_head
|
| 13728 |
+
cur = build_lora_mm(model.output, cur);
|
| 13729 |
+
|
| 13730 |
+
cb(cur, "result_output", -1);
|
| 13731 |
+
res->t_logits = cur;
|
| 13732 |
+
|
| 13733 |
+
ggml_build_forward_expand(gf, cur);
|
| 13734 |
+
}
|
| 13735 |
+
};
|
| 13736 |
+
|
| 13737 |
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
|
| 13738 |
llama_memory_i * res;
|
| 13739 |
|
|
|
|
| 13742 |
case LLM_ARCH_JINA_BERT_V2:
|
| 13743 |
case LLM_ARCH_NOMIC_BERT:
|
| 13744 |
case LLM_ARCH_NOMIC_BERT_MOE:
|
| 13745 |
+
case LLM_ARCH_NEO_BERT:
|
| 13746 |
case LLM_ARCH_WAVTOKENIZER_DEC:
|
| 13747 |
{
|
| 13748 |
res = nullptr;
|
|
|
|
| 13851 |
{
|
| 13852 |
llm = std::make_unique<llm_build_bert>(*this, params, gf);
|
| 13853 |
} break;
|
| 13854 |
+
case LLM_ARCH_NEO_BERT:
|
| 13855 |
+
{
|
| 13856 |
+
llm = std::make_unique<llm_build_neo_bert>(*this, params, gf);
|
| 13857 |
+
} break;
|
| 13858 |
case LLM_ARCH_BLOOM:
|
| 13859 |
{
|
| 13860 |
llm = std::make_unique<llm_build_bloom>(*this, params, gf);
|
|
|
|
| 14077 |
{
|
| 14078 |
llm = std::make_unique<llm_build_bailingmoe>(*this, params, gf);
|
| 14079 |
} break;
|
| 14080 |
+
case LLM_ARCH_DOTS1:
|
| 14081 |
+
{
|
| 14082 |
+
llm = std::make_unique<llm_build_dots1>(*this, params, gf);
|
| 14083 |
+
} break;
|
| 14084 |
+
case LLM_ARCH_ARCEE:
|
| 14085 |
+
{
|
| 14086 |
+
llm = std::make_unique<llm_build_arcee>(*this, params, gf);
|
| 14087 |
+
} break;
|
| 14088 |
default:
|
| 14089 |
GGML_ABORT("fatal error");
|
| 14090 |
}
|
|
|
|
| 14234 |
case LLM_ARCH_GRANITE_MOE:
|
| 14235 |
case LLM_ARCH_CHAMELEON:
|
| 14236 |
case LLM_ARCH_BAILINGMOE:
|
| 14237 |
+
case LLM_ARCH_NEO_BERT:
|
| 14238 |
+
case LLM_ARCH_ARCEE:
|
| 14239 |
return LLAMA_ROPE_TYPE_NORM;
|
| 14240 |
|
| 14241 |
// the pairs of head values are offset by n_rot/2
|
|
|
|
| 14269 |
case LLM_ARCH_NEMOTRON:
|
| 14270 |
case LLM_ARCH_EXAONE:
|
| 14271 |
case LLM_ARCH_MINICPM3:
|
| 14272 |
+
case LLM_ARCH_DOTS1:
|
| 14273 |
return LLAMA_ROPE_TYPE_NEOX;
|
| 14274 |
|
| 14275 |
case LLM_ARCH_QWEN2VL:
|
examples/talk-llama/llama-model.h
CHANGED
|
@@ -73,6 +73,7 @@ enum llm_type {
|
|
| 73 |
LLM_TYPE_40B,
|
| 74 |
LLM_TYPE_65B,
|
| 75 |
LLM_TYPE_70B,
|
|
|
|
| 76 |
LLM_TYPE_236B,
|
| 77 |
LLM_TYPE_290B,
|
| 78 |
LLM_TYPE_314B,
|
|
|
|
| 73 |
LLM_TYPE_40B,
|
| 74 |
LLM_TYPE_65B,
|
| 75 |
LLM_TYPE_70B,
|
| 76 |
+
LLM_TYPE_142B,
|
| 77 |
LLM_TYPE_236B,
|
| 78 |
LLM_TYPE_290B,
|
| 79 |
LLM_TYPE_314B,
|
examples/talk-llama/llama-quant.cpp
CHANGED
|
@@ -585,7 +585,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|
| 585 |
if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
|
| 586 |
gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
|
| 587 |
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
|
| 588 |
-
|
|
|
|
| 589 |
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
|
| 590 |
gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
|
| 591 |
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
|
|
|
|
| 585 |
if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
|
| 586 |
gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
|
| 587 |
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
|
| 588 |
+
// Setting type to UINT32. See https://github.com/ggml-org/llama.cpp/pull/14182 for context
|
| 589 |
+
gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)abs(o.val_i64));
|
| 590 |
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
|
| 591 |
gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
|
| 592 |
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
|
examples/talk-llama/llama-vocab.cpp
CHANGED
|
@@ -9,16 +9,16 @@
|
|
| 9 |
|
| 10 |
#include <algorithm>
|
| 11 |
#include <cassert>
|
|
|
|
| 12 |
#include <cfloat>
|
| 13 |
-
#include <climits>
|
| 14 |
#include <cstdarg>
|
| 15 |
#include <cstring>
|
| 16 |
#include <forward_list>
|
|
|
|
| 17 |
#include <map>
|
| 18 |
#include <queue>
|
| 19 |
#include <set>
|
| 20 |
#include <unordered_map>
|
| 21 |
-
#include <cctype>
|
| 22 |
|
| 23 |
//
|
| 24 |
// helpers
|
|
@@ -1987,6 +1987,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|
| 1987 |
|| t.first == "<|eom_id|>"
|
| 1988 |
|| t.first == "<EOT>"
|
| 1989 |
|| t.first == "_<EOT>"
|
|
|
|
| 1990 |
) {
|
| 1991 |
special_eog_ids.insert(t.second);
|
| 1992 |
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
|
@@ -2572,6 +2573,10 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t
|
|
| 2572 |
// copy piece chars to output text buffer
|
| 2573 |
// skip up to 'lstrip' leading spaces before copying
|
| 2574 |
auto _try_copy = [=] (const char * token, size_t size) -> int32_t {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2575 |
for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) {
|
| 2576 |
token++;
|
| 2577 |
size--;
|
|
@@ -2768,26 +2773,26 @@ void llama_vocab::impl::print_info() const {
|
|
| 2768 |
LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) bpe_ranks.size());
|
| 2769 |
|
| 2770 |
// special tokens
|
| 2771 |
-
if (special_bos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, special_bos_id, id_to_token
|
| 2772 |
-
if (special_eos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, special_eos_id, id_to_token
|
| 2773 |
-
if (special_eot_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, special_eot_id, id_to_token
|
| 2774 |
-
if (special_eom_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, special_eom_id, id_to_token
|
| 2775 |
-
if (special_unk_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, special_unk_id, id_to_token
|
| 2776 |
-
if (special_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, special_sep_id, id_to_token
|
| 2777 |
-
if (special_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, special_pad_id, id_to_token
|
| 2778 |
-
if (special_mask_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, special_mask_id, id_to_token
|
| 2779 |
-
|
| 2780 |
-
if (linefeed_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, linefeed_id, id_to_token
|
| 2781 |
-
|
| 2782 |
-
if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token = %d '%s'\n", __func__, special_fim_pre_id, id_to_token
|
| 2783 |
-
if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token = %d '%s'\n", __func__, special_fim_suf_id, id_to_token
|
| 2784 |
-
if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token = %d '%s'\n", __func__, special_fim_mid_id, id_to_token
|
| 2785 |
-
if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token = %d '%s'\n", __func__, special_fim_pad_id, id_to_token
|
| 2786 |
-
if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token = %d '%s'\n", __func__, special_fim_rep_id, id_to_token
|
| 2787 |
-
if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token = %d '%s'\n", __func__, special_fim_sep_id, id_to_token
|
| 2788 |
|
| 2789 |
for (const auto & id : special_eog_ids) {
|
| 2790 |
-
LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, id_to_token
|
| 2791 |
}
|
| 2792 |
|
| 2793 |
LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
|
|
|
|
| 9 |
|
| 10 |
#include <algorithm>
|
| 11 |
#include <cassert>
|
| 12 |
+
#include <cctype>
|
| 13 |
#include <cfloat>
|
|
|
|
| 14 |
#include <cstdarg>
|
| 15 |
#include <cstring>
|
| 16 |
#include <forward_list>
|
| 17 |
+
#include <limits>
|
| 18 |
#include <map>
|
| 19 |
#include <queue>
|
| 20 |
#include <set>
|
| 21 |
#include <unordered_map>
|
|
|
|
| 22 |
|
| 23 |
//
|
| 24 |
// helpers
|
|
|
|
| 1987 |
|| t.first == "<|eom_id|>"
|
| 1988 |
|| t.first == "<EOT>"
|
| 1989 |
|| t.first == "_<EOT>"
|
| 1990 |
+
|| t.first == "<|end_of_text|>"
|
| 1991 |
) {
|
| 1992 |
special_eog_ids.insert(t.second);
|
| 1993 |
if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
|
|
|
|
| 2573 |
// copy piece chars to output text buffer
|
| 2574 |
// skip up to 'lstrip' leading spaces before copying
|
| 2575 |
auto _try_copy = [=] (const char * token, size_t size) -> int32_t {
|
| 2576 |
+
if (size >= static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
|
| 2577 |
+
GGML_ABORT("invalid token size: %zu exceeds int32_t limit", size);
|
| 2578 |
+
}
|
| 2579 |
+
|
| 2580 |
for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) {
|
| 2581 |
token++;
|
| 2582 |
size--;
|
|
|
|
| 2773 |
LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) bpe_ranks.size());
|
| 2774 |
|
| 2775 |
// special tokens
|
| 2776 |
+
if (special_bos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, special_bos_id, id_to_token.at(special_bos_id).text.c_str() ); }
|
| 2777 |
+
if (special_eos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, special_eos_id, id_to_token.at(special_eos_id).text.c_str() ); }
|
| 2778 |
+
if (special_eot_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, special_eot_id, id_to_token.at(special_eot_id).text.c_str() ); }
|
| 2779 |
+
if (special_eom_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, special_eom_id, id_to_token.at(special_eom_id).text.c_str() ); }
|
| 2780 |
+
if (special_unk_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, special_unk_id, id_to_token.at(special_unk_id).text.c_str() ); }
|
| 2781 |
+
if (special_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, special_sep_id, id_to_token.at(special_sep_id).text.c_str() ); }
|
| 2782 |
+
if (special_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, special_pad_id, id_to_token.at(special_pad_id).text.c_str() ); }
|
| 2783 |
+
if (special_mask_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, special_mask_id, id_to_token.at(special_mask_id).text.c_str() ); }
|
| 2784 |
+
|
| 2785 |
+
if (linefeed_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, linefeed_id, id_to_token.at(linefeed_id).text.c_str() ); }
|
| 2786 |
+
|
| 2787 |
+
if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); }
|
| 2788 |
+
if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); }
|
| 2789 |
+
if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); }
|
| 2790 |
+
if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); }
|
| 2791 |
+
if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); }
|
| 2792 |
+
if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); }
|
| 2793 |
|
| 2794 |
for (const auto & id : special_eog_ids) {
|
| 2795 |
+
LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() );
|
| 2796 |
}
|
| 2797 |
|
| 2798 |
LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
|
examples/talk-llama/llama.cpp
CHANGED
|
@@ -198,14 +198,18 @@ static struct llama_model * llama_model_load_from_file_impl(
|
|
| 198 |
|
| 199 |
// if using single GPU mode, remove all except the main GPU
|
| 200 |
if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
|
| 201 |
-
if (params.main_gpu < 0
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
}
|
| 206 |
-
ggml_backend_dev_t main_gpu = model->devices[params.main_gpu];
|
| 207 |
-
model->devices.clear();
|
| 208 |
-
model->devices.push_back(main_gpu);
|
| 209 |
}
|
| 210 |
|
| 211 |
for (auto * dev : model->devices) {
|
|
|
|
| 198 |
|
| 199 |
// if using single GPU mode, remove all except the main GPU
|
| 200 |
if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
|
| 201 |
+
if (params.main_gpu < 0) {
|
| 202 |
+
model->devices.clear();
|
| 203 |
+
} else {
|
| 204 |
+
if (params.main_gpu >= (int)model->devices.size()) {
|
| 205 |
+
LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %zu)\n", __func__, params.main_gpu, model->devices.size());
|
| 206 |
+
llama_model_free(model);
|
| 207 |
+
return nullptr;
|
| 208 |
+
}
|
| 209 |
+
ggml_backend_dev_t main_gpu = model->devices[params.main_gpu];
|
| 210 |
+
model->devices.clear();
|
| 211 |
+
model->devices.push_back(main_gpu);
|
| 212 |
}
|
|
|
|
|
|
|
|
|
|
| 213 |
}
|
| 214 |
|
| 215 |
for (auto * dev : model->devices) {
|
examples/talk-llama/llama.h
CHANGED
|
@@ -243,18 +243,21 @@ extern "C" {
|
|
| 243 |
|
| 244 |
typedef bool (*llama_progress_callback)(float progress, void * user_data);
|
| 245 |
|
| 246 |
-
// Input data for llama_decode
|
| 247 |
// A llama_batch object can contain input about one or many sequences
|
| 248 |
// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
|
| 249 |
//
|
| 250 |
// - token : the token ids of the input (used when embd is NULL)
|
| 251 |
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
|
| 252 |
// - pos : the positions of the respective token in the sequence
|
| 253 |
-
// (if set to NULL, the token position will be tracked automatically by llama_decode)
|
| 254 |
// - seq_id : the sequence to which the respective token belongs
|
| 255 |
// (if set to NULL, the sequence ID will be assumed to be 0)
|
| 256 |
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
|
| 257 |
-
// (if set to NULL
|
|
|
|
|
|
|
|
|
|
| 258 |
//
|
| 259 |
typedef struct llama_batch {
|
| 260 |
int32_t n_tokens;
|
|
@@ -262,8 +265,8 @@ extern "C" {
|
|
| 262 |
llama_token * token;
|
| 263 |
float * embd;
|
| 264 |
llama_pos * pos;
|
| 265 |
-
int32_t * n_seq_id;
|
| 266 |
-
llama_seq_id ** seq_id;
|
| 267 |
int8_t * logits; // TODO: rename this to "output"
|
| 268 |
} llama_batch;
|
| 269 |
|
|
@@ -961,8 +964,8 @@ extern "C" {
|
|
| 961 |
// Get the number of threads used for prompt and batch processing (multiple token).
|
| 962 |
LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
|
| 963 |
|
| 964 |
-
// Set whether the
|
| 965 |
-
//
|
| 966 |
LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
|
| 967 |
|
| 968 |
// Set whether to use causal attention or not
|
|
|
|
| 243 |
|
| 244 |
typedef bool (*llama_progress_callback)(float progress, void * user_data);
|
| 245 |
|
| 246 |
+
// Input data for llama_encode/llama_decode
|
| 247 |
// A llama_batch object can contain input about one or many sequences
|
| 248 |
// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
|
| 249 |
//
|
| 250 |
// - token : the token ids of the input (used when embd is NULL)
|
| 251 |
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
|
| 252 |
// - pos : the positions of the respective token in the sequence
|
| 253 |
+
// (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode)
|
| 254 |
// - seq_id : the sequence to which the respective token belongs
|
| 255 |
// (if set to NULL, the sequence ID will be assumed to be 0)
|
| 256 |
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
|
| 257 |
+
// (if set to NULL:
|
| 258 |
+
// - if embeddings: all tokens are output
|
| 259 |
+
// - if not: only the last token is output
|
| 260 |
+
// )
|
| 261 |
//
|
| 262 |
typedef struct llama_batch {
|
| 263 |
int32_t n_tokens;
|
|
|
|
| 265 |
llama_token * token;
|
| 266 |
float * embd;
|
| 267 |
llama_pos * pos;
|
| 268 |
+
int32_t * n_seq_id;
|
| 269 |
+
llama_seq_id ** seq_id;
|
| 270 |
int8_t * logits; // TODO: rename this to "output"
|
| 271 |
} llama_batch;
|
| 272 |
|
|
|
|
| 964 |
// Get the number of threads used for prompt and batch processing (multiple token).
|
| 965 |
LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
|
| 966 |
|
| 967 |
+
// Set whether the context outputs embeddings or not
|
| 968 |
+
// TODO: rename to avoid confusion with llama_get_embeddings()
|
| 969 |
LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
|
| 970 |
|
| 971 |
// Set whether to use causal attention or not
|