ggerganov committed on
Commit fc04dc0 · 1 Parent(s): 23e1986

talk-llama : sync llama.cpp

examples/talk-llama/llama-arch.cpp CHANGED
@@ -20,6 +20,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_BERT, "bert" },
     { LLM_ARCH_NOMIC_BERT, "nomic-bert" },
     { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
+    { LLM_ARCH_NEO_BERT, "neo-bert" },
     { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
     { LLM_ARCH_BLOOM, "bloom" },
     { LLM_ARCH_STABLELM, "stablelm" },
@@ -72,6 +73,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
     { LLM_ARCH_PLM, "plm" },
     { LLM_ARCH_BAILINGMOE, "bailingmoe" },
+    { LLM_ARCH_DOTS1, "dots1" },
+    { LLM_ARCH_ARCEE, "arcee" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
@@ -243,6 +246,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_ARCEE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_LLAMA4,
         {
@@ -494,6 +515,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_NEO_BERT,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
+            { LLM_TENSOR_CLS, "cls" },
+            { LLM_TENSOR_CLS_OUT, "cls.output" },
+        },
+    },
     {
         LLM_ARCH_JINA_BERT_V2,
         {
@@ -1555,6 +1591,34 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
         },
     },
+    {
+        LLM_ARCH_DOTS1,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
+        }
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
examples/talk-llama/llama-arch.h CHANGED
@@ -24,6 +24,7 @@ enum llm_arch {
     LLM_ARCH_BERT,
     LLM_ARCH_NOMIC_BERT,
     LLM_ARCH_NOMIC_BERT_MOE,
+    LLM_ARCH_NEO_BERT,
     LLM_ARCH_JINA_BERT_V2,
     LLM_ARCH_BLOOM,
     LLM_ARCH_STABLELM,
@@ -76,6 +77,8 @@ enum llm_arch {
     LLM_ARCH_WAVTOKENIZER_DEC,
     LLM_ARCH_PLM,
     LLM_ARCH_BAILINGMOE,
+    LLM_ARCH_DOTS1,
+    LLM_ARCH_ARCEE,
     LLM_ARCH_UNKNOWN,
 };
 
examples/talk-llama/llama-batch.cpp CHANGED
@@ -1,8 +1,14 @@
1
  #include "llama-batch.h"
2
 
 
 
 
 
 
3
  #include <cassert>
4
  #include <cstring>
5
  #include <algorithm>
 
6
 
7
  llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
8
  // clear empty sequences
@@ -105,12 +111,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
105
  ubatch.seq_id = batch->seq_id + seq.offset;
106
  }
107
  }
108
- if (logits_all) {
109
- for (size_t i = 0; i < length; ++i) {
110
- ubatch.output[ubatch.n_tokens + i] = 1;
111
- out_ids.push_back(ids[seq.offset + i]);
112
- }
113
- } else if (batch->logits) {
114
  if (ubatch.equal_seqs) {
115
  for (size_t i = 0; i < length; ++i) {
116
  size_t id = ids[seq.offset + i];
@@ -197,11 +198,10 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
197
  return ubatch;
198
  }
199
 
200
- llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
201
  GGML_ASSERT(batch.n_tokens >= 0);
202
  this->batch = &batch;
203
  this->n_embd = n_embd;
204
- this->logits_all = logits_all;
205
 
206
  n_tokens = batch.n_tokens;
207
  ids.resize(n_tokens);
@@ -285,17 +285,56 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
285
  );
286
  }
287
 
288
- llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
289
- batch = in_batch;
290
  GGML_ASSERT(batch.n_tokens > 0);
291
- if (!batch.pos) {
292
- assert(p0 >= 0);
293
- pos.resize(batch.n_tokens);
294
- for (int32_t i = 0; i < batch.n_tokens; i++) {
295
- pos[i] = p0 + i;
296
  }
297
- batch.pos = pos.data();
298
  }
 
 
 
 
 
299
  if (!batch.n_seq_id) {
300
  n_seq_id.resize(batch.n_tokens);
301
  for (int32_t i = 0; i < batch.n_tokens; i++) {
@@ -303,6 +342,7 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
303
  }
304
  batch.n_seq_id = n_seq_id.data();
305
  }
 
306
  if (!batch.seq_id) {
307
  seq_id.resize(batch.n_tokens + 1);
308
  seq_id[batch.n_tokens] = NULL;
@@ -311,10 +351,221 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
311
  }
312
  batch.seq_id = seq_id.data();
313
  }
314
  if (!batch.logits) {
315
- logits.resize(batch.n_tokens);
316
- logits[logits.size() - 1] = true;
317
- batch.logits = logits.data();
318
  }
319
  }
320
 
 
1
  #include "llama-batch.h"
2
 
3
+ #include "llama-impl.h"
4
+ #include "llama-cparams.h"
5
+ #include "llama-vocab.h"
6
+ #include "llama-memory.h"
7
+
8
  #include <cassert>
9
  #include <cstring>
10
  #include <algorithm>
11
+ #include <sstream>
12
 
13
  llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
14
  // clear empty sequences
 
111
  ubatch.seq_id = batch->seq_id + seq.offset;
112
  }
113
  }
114
+ if (batch->logits) {
 
 
 
 
 
115
  if (ubatch.equal_seqs) {
116
  for (size_t i = 0; i < length; ++i) {
117
  size_t id = ids[seq.offset + i];
 
198
  return ubatch;
199
  }
200
 
201
+ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
202
  GGML_ASSERT(batch.n_tokens >= 0);
203
  this->batch = &batch;
204
  this->n_embd = n_embd;
 
205
 
206
  n_tokens = batch.n_tokens;
207
  ids.resize(n_tokens);
 
285
  );
286
  }
287
 
288
+ llama_batch_allocr::llama_batch_allocr() {
289
+ const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
290
+ debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
291
+
292
+ seq_pos.resize(LLAMA_MAX_SEQ);
293
+ seq_cpl.resize(LLAMA_MAX_SEQ);
294
+ for (auto & cur : seq_cpl) {
295
+ cur.resize(LLAMA_MAX_SEQ);
296
+ }
297
+ }
298
+
299
+ bool llama_batch_allocr::init(
300
+ const llama_batch & batch_inp,
301
+ const llama_vocab & vocab,
302
+ const llama_memory_i * memory,
303
+ bool embd_all) {
304
+ clear();
305
+
306
+ batch = batch_inp;
307
+
308
  GGML_ASSERT(batch.n_tokens > 0);
309
+
310
+ //
311
+ // validate input batch
312
+ //
313
+
314
+ if (batch.token) {
315
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
316
+ if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
317
+ LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
318
+ return false;
319
+ }
320
+ }
321
+ }
322
+
323
+ if (batch.seq_id) {
324
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
325
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
326
+ if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
327
+ LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
328
+ return false;
329
+ }
330
+ }
331
  }
 
332
  }
333
+
334
+ //
335
+ // auto-generate missing fields
336
+ //
337
+
338
  if (!batch.n_seq_id) {
339
  n_seq_id.resize(batch.n_tokens);
340
  for (int32_t i = 0; i < batch.n_tokens; i++) {
 
342
  }
343
  batch.n_seq_id = n_seq_id.data();
344
  }
345
+
346
  if (!batch.seq_id) {
347
  seq_id.resize(batch.n_tokens + 1);
348
  seq_id[batch.n_tokens] = NULL;
 
351
  }
352
  batch.seq_id = seq_id.data();
353
  }
354
+
355
+ if (!batch.pos) {
356
+ pos.resize(batch.n_tokens);
357
+
358
+ // initialize the starting position for each sequence based on the positions in the memory
359
+ llama_pos p0[LLAMA_MAX_SEQ];
360
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
361
+ if (!memory) {
362
+ p0[s] = 0;
363
+ } else {
364
+ p0[s] = memory->seq_pos_max(s) + 1;
365
+ }
366
+ }
367
+
368
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
369
+ const llama_seq_id seq_id = batch.seq_id[i][0];
370
+
371
+ pos[i] = p0[seq_id];
372
+
373
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
374
+ p0[batch.seq_id[i][s]] = pos[i] + 1;
375
+ }
376
+ }
377
+
378
+ batch.pos = pos.data();
379
+ }
380
+
381
  if (!batch.logits) {
382
+ if (embd_all) {
383
+ // return the output for all tokens
384
+ output.resize(batch.n_tokens, true);
385
+ } else {
386
+ // return the output only for the last token
387
+ output.resize(batch.n_tokens, false);
388
+ output[output.size() - 1] = true;
389
+ }
390
+
391
+ batch.logits = output.data();
392
+ } else if (embd_all) {
393
+ bool warn = false;
394
+
395
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
396
+ if (batch.logits[i] == 0) {
397
+ warn = true;
398
+ }
399
+ }
400
+
401
+ if (warn) {
402
+ LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
403
+
404
+ output.resize(batch.n_tokens, true);
405
+ batch.logits = output.data();
406
+ }
407
+ }
408
+
409
+ //
410
+ // compute stats
411
+ //
412
+
413
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
414
+ n_outputs += batch.logits[i] != 0;
415
+ }
416
+
417
+ // determine coupled sequences
418
+ // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
419
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
420
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
421
+ seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
422
+
423
+ if (s > 0) {
424
+ const llama_seq_id s0 = batch.seq_id[i][0];
425
+ const llama_seq_id s1 = batch.seq_id[i][s];
426
+
427
+ // mark that sequence s1 is coupled to s0
428
+ seq_cpl[s1][s0] = true;
429
+
430
+ // note: the other way around is not necessary for now
431
+ //seq_cpl[s0][s1] = true;
432
+ }
433
+ }
434
+ }
435
+
436
+ if (debug > 0) {
437
+ LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
438
+ LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
439
+ LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
440
+ LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
441
+ LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
442
+ LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) batch.n_seq_id);
443
+ LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) batch.seq_id);
444
+ LLAMA_LOG_DEBUG("%s: logits = %p\n", __func__, (void *) batch.logits);
445
+ LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
446
+
447
+ if (debug > 1) {
448
+ int seq_id_max = 0;
449
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
450
+ for (int s = 0; s < batch.n_seq_id[i]; ++s) {
451
+ for (int s = 0; s < batch.n_seq_id[i]; ++s) {
452
+ seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]);
453
+ }
454
+ }
455
+ }
456
+ ++seq_id_max;
457
+
458
+ LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
459
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
460
+ std::vector<int8_t> seq_id(seq_id_max);
461
+
462
+ for (int s = 0; s < batch.n_seq_id[i]; ++s) {
463
+ seq_id[batch.seq_id[i][s]] = 1;
464
+ }
465
+
466
+ std::stringstream ss;
467
+ for (int s = 0; s < seq_id_max; ++s) {
468
+ if (seq_id[s]) {
469
+ ss << s%10;
470
+ } else {
471
+ ss << ".";
472
+ }
473
+ }
474
+
475
+ LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
476
+ __func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(),
477
+ batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
478
+ }
479
+ LLAMA_LOG_DEBUG("%s: ]\n", __func__);
480
+
481
+ LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
482
+ for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
483
+ if (seq_pos[s0].empty()) {
484
+ continue;
485
+ }
486
+
487
+ std::stringstream ss;
488
+ for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
489
+ if (seq_cpl[s0][s1]) {
490
+ ss << s1 << " ";
491
+ }
492
+ }
493
+
494
+ LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
495
+ __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
496
+ }
497
+ LLAMA_LOG_DEBUG("%s: ]\n", __func__);
498
+ }
499
+ }
500
+
501
+ //
502
+ // consistency checks
503
+ //
504
+
505
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
506
+ if (seq_pos[s].empty()) {
507
+ continue;
508
+ }
509
+
510
+ if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
511
+ LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
512
+ return false;
513
+ }
514
+
515
+ if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
516
+ LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
517
+ return false;
518
+ }
519
+ }
520
+
521
+ if (memory) {
522
+ for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
523
+ for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
524
+ if (seq_cpl[s0][s1]) {
525
+ if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
526
+ memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
527
+ LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
528
+ return false;
529
+ }
530
+ }
531
+ }
532
+ }
533
+ }
534
+
535
+ return true;
536
+ }
537
+
538
+ const llama_batch & llama_batch_allocr::get_batch() const {
539
+ return batch;
540
+ }
541
+
542
+ uint32_t llama_batch_allocr::get_n_outputs() const {
543
+ return n_outputs;
544
+ }
545
+
546
+ llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
547
+ return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
548
+ }
549
+
550
+ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
551
+ return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
552
+ }
553
+
554
+ void llama_batch_allocr::clear() {
555
+ n_outputs = 0;
556
+
557
+ batch = {};
558
+ pos.clear();
559
+ n_seq_id.clear();
560
+ seq_id.clear();
561
+ output.clear();
562
+
563
+ for (auto & cur : seq_pos) {
564
+ cur.clear();
565
+ }
566
+
567
+ for (auto & cur : seq_cpl) {
568
+ std::fill(cur.begin(), cur.end(), false);
569
  }
570
  }
571
 
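
The new llama_batch_allocr::init() above auto-generates token positions when batch.pos is missing: each sequence continues from the last position stored in the memory (or from 0 when no memory is provided), and every sequence a token belongs to advances past that token. A minimal standalone sketch of that fill logic follows; the types are simplified stand-ins rather than the real llama_batch / llama_memory_i interfaces:

// Standalone sketch of the per-sequence position auto-fill performed by
// llama_batch_allocr::init() above; simplified stand-in types.
#include <cstddef>
#include <cstdint>
#include <vector>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

// seq_ids_per_token[i] lists the sequences token i belongs to;
// p0[s] is the next free position of sequence s (memory->seq_pos_max(s) + 1
// in the real code, or 0 when there is no memory module).
static std::vector<llama_pos> fill_positions(
        const std::vector<std::vector<llama_seq_id>> & seq_ids_per_token,
        std::vector<llama_pos> p0) {
    std::vector<llama_pos> pos(seq_ids_per_token.size());

    for (std::size_t i = 0; i < seq_ids_per_token.size(); ++i) {
        // the token takes the next position of its first sequence ...
        pos[i] = p0[seq_ids_per_token[i][0]];

        // ... and every sequence it belongs to advances past that position
        for (llama_seq_id s : seq_ids_per_token[i]) {
            p0[s] = pos[i] + 1;
        }
    }

    return pos;
}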
examples/talk-llama/llama-batch.h CHANGED
@@ -4,6 +4,7 @@
 
 #include <array>
 #include <vector>
+#include <set>
 
 // very similar to llama_batch,
 // but has more metadata about sequences
@@ -18,8 +19,8 @@ struct llama_ubatch {
     llama_token * token; // [n_tokens]
     float * embd; // [n_embd, n_tokens]
     llama_pos * pos; // [n_tokens]
-    int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
-    llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id;
+    int32_t * n_seq_id; // [n_seqs]
+    llama_seq_id ** seq_id; // [n_seqs]
     int8_t * output; // [n_tokens]
 };
 
@@ -39,8 +40,6 @@ struct llama_sbatch {
 
     size_t n_embd;
 
-    bool logits_all; // TODO: remove once lctx.logits_all is removed too
-
     // sorted indices into the batch
     std::vector<int64_t> ids;
     // batch indices of the output
@@ -76,19 +75,45 @@ struct llama_sbatch {
     llama_ubatch split_seq(size_t n_ubatch);
 
     llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
+    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
 };
 
-// temporary allocate memory for the input batch if needed
-struct llama_batch_allocr {
-    struct llama_batch batch;
+// a helper for sanitizing and fulfilling a batch
+class llama_batch_allocr {
+public:
+    llama_batch_allocr();
+
+    // sanitize and auto-gen missing data in the input batch
+    // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
+    bool init(
+            const llama_batch & batch_inp,
+            const llama_vocab & vocab,
+            const llama_memory_i * memory,
+            bool embd_all);
+
+    const llama_batch & get_batch() const;
+
+    uint32_t get_n_outputs() const;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const;
+
+private:
+    void clear();
+
+    llama_batch batch;
+
+    uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
+
     std::vector<llama_pos> pos;
     std::vector<int32_t> n_seq_id;
    std::vector<llama_seq_id *> seq_id;
-    std::vector<int8_t> logits;
+    std::vector<int8_t> output;
+
+    std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
+    std::vector<std::vector<bool>> seq_cpl;   // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
 
-    // optionally fulfill the batch returned by llama_batch_get_one
-    llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
+    int debug;
 };
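
The header above replaces the old llama_batch_allocr struct with a small class: callers run init() once per incoming batch and then read the sanitized batch and output count back. A rough sketch of that call sequence, mirroring llama_context::decode() in the llama-context.cpp diff below; it assumes the internal llama.cpp headers and is not a standalone program:

// Sketch of the intended call sequence (mirrors llama_context::decode() in
// the llama-context.cpp diff below); builds only inside the llama.cpp tree.
#include "llama-batch.h"
#include "llama-memory.h"
#include "llama-vocab.h"

#include <cstdint>

static bool prepare_batch(
        llama_batch_allocr   & allocr,
        const llama_batch    & batch_inp,
        const llama_vocab    & vocab,
        const llama_memory_i * memory,     // optional: enables continuity checks
        bool                   embd_all) { // true -> every token is an output
    // sanitize the input and auto-generate pos/n_seq_id/seq_id/logits if missing
    if (!allocr.init(batch_inp, vocab, memory, embd_all)) {
        return false;
    }

    const llama_batch & batch     = allocr.get_batch();     // fully populated view
    const uint32_t      n_outputs = allocr.get_n_outputs(); // tokens marked for output

    (void) batch;
    (void) n_outputs;

    return true;
}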
examples/talk-llama/llama-chat.cpp CHANGED
@@ -183,6 +183,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_BAILING;
     } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
         return LLM_CHAT_TEMPLATE_LLAMA4;
+    } else if (tmpl_contains("<|endofuserprompt|>")) {
+        return LLM_CHAT_TEMPLATE_DOTS1;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -643,6 +645,21 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "Assistant:";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_DOTS1) {
+        // dots.llm1.inst (DOTS1)
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|system|>" << message->content << "<|endofsystem|>";
+            } else if (role == "user") {
+                ss << "<|userprompt|>" << message->content << "<|endofuserprompt|>";
+            } else {
+                ss << "<|response|>" << message->content << "<|endofresponse|>";
+            }
+        }
+        if (add_ass) {
+            ss << "<|response|>";
+        }
     } else {
         // template not supported
         return -1;
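
The new DOTS1 branch wraps each turn in dedicated role tags and, when add_ass is set, ends the prompt with an open <|response|> tag. A small standalone illustration of the resulting prompt string (the tags are taken from the diff above; the example messages are invented):

// Standalone illustration of the prompt produced by the new
// LLM_CHAT_TEMPLATE_DOTS1 branch above; example messages are invented.
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

int main() {
    const std::vector<std::pair<std::string, std::string>> chat = {
        { "system",    "You are a helpful assistant." },
        { "user",      "Hello!"                       },
        { "assistant", "Hi, how can I help?"          },
    };

    std::ostringstream ss;
    for (const auto & [role, content] : chat) {
        if (role == "system") {
            ss << "<|system|>" << content << "<|endofsystem|>";
        } else if (role == "user") {
            ss << "<|userprompt|>" << content << "<|endofuserprompt|>";
        } else {
            ss << "<|response|>" << content << "<|endofresponse|>";
        }
    }
    ss << "<|response|>"; // add_ass: prime the next assistant turn

    std::cout << ss.str() << "\n";
    return 0;
}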
examples/talk-llama/llama-chat.h CHANGED
@@ -43,6 +43,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_BAILING,
     LLM_CHAT_TEMPLATE_LLAMA4,
     LLM_CHAT_TEMPLATE_SMOLVLM,
+    LLM_CHAT_TEMPLATE_DOTS1,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 
examples/talk-llama/llama-context.cpp CHANGED
@@ -1,6 +1,7 @@
1
  #include "llama-context.h"
2
 
3
  #include "llama-impl.h"
 
4
  #include "llama-io.h"
5
  #include "llama-memory.h"
6
  #include "llama-mmap.h"
@@ -18,7 +19,8 @@
18
  llama_context::llama_context(
19
  const llama_model & model,
20
  llama_context_params params) :
21
- model(model) {
 
22
  LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
23
 
24
  t_start_us = model.t_start_us;
@@ -27,8 +29,8 @@ llama_context::llama_context(
27
  const auto & hparams = model.hparams;
28
 
29
  cparams.n_seq_max = std::max(1u, params.n_seq_max);
30
- if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
31
- throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
32
  }
33
 
34
  cparams.n_threads = params.n_threads;
@@ -494,7 +496,7 @@ float * llama_context::get_logits() {
494
  }
495
 
496
  float * llama_context::get_logits_ith(int32_t i) {
497
- int32_t j = -1;
498
 
499
  try {
500
  if (logits == nullptr) {
@@ -517,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
517
  }
518
  if (j >= n_outputs) {
519
  // This should not happen
520
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
521
  }
522
 
523
  return logits + j*model.vocab.n_tokens();
@@ -536,7 +538,7 @@ float * llama_context::get_embeddings() {
536
  }
537
 
538
  float * llama_context::get_embeddings_ith(int32_t i) {
539
- int32_t j = -1;
540
 
541
  try {
542
  if (embd == nullptr) {
@@ -559,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
559
  }
560
  if (j >= n_outputs) {
561
  // This should not happen
562
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
563
  }
564
 
565
  return embd + j*model.hparams.n_embd;
@@ -719,52 +721,41 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
719
  return res;
720
  }
721
 
722
- int llama_context::encode(llama_batch & inp_batch) {
723
- if (inp_batch.n_tokens == 0) {
724
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
725
  return -1;
726
  }
727
 
728
- // temporary allocate memory for the input batch if needed
729
  // note: during encode, we always pass the full sequence starting from pos = 0
730
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
 
 
 
731
 
732
- const llama_batch & batch = batch_allocr.batch;
733
- const int32_t n_tokens = batch.n_tokens;
734
 
735
- const auto & hparams = model.hparams;
736
 
737
  GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
738
 
739
- // TODO: move the validation to the llama_batch_allocr
740
- if (batch.token) {
741
- for (int32_t i = 0; i < n_tokens; ++i) {
742
- if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
743
- LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
744
- return -1;
745
- }
746
-
747
- if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
748
- LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
749
- throw -1;
750
- }
751
- }
752
- }
753
-
754
  // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
755
- GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
756
 
757
  if (t_compute_start_us == 0) {
758
  t_compute_start_us = ggml_time_us();
759
  }
760
 
 
761
  embd_seq.clear();
762
 
763
  n_queued_tokens += n_tokens;
764
 
 
 
765
  const int64_t n_embd = hparams.n_embd;
766
 
767
- llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
768
 
769
  const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
770
 
@@ -774,7 +765,7 @@ int llama_context::encode(llama_batch & inp_batch) {
774
  return -2;
775
  };
776
 
777
- for (int32_t i = 0; i < n_tokens; ++i) {
778
  output_ids[i] = i;
779
  }
780
 
@@ -830,7 +821,8 @@ int llama_context::encode(llama_batch & inp_batch) {
830
 
831
  GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
832
 
833
- for (int32_t i = 0; i < n_tokens; i++) {
 
834
  const llama_seq_id seq_id = ubatch.seq_id[i][0];
835
  if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
836
  continue;
@@ -845,6 +837,7 @@ int llama_context::encode(llama_batch & inp_batch) {
845
  auto & embd_seq_out = embd_seq;
846
  const uint32_t n_cls_out = hparams.n_cls_out;
847
 
 
848
  for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
849
  const llama_seq_id seq_id = ubatch.seq_id[s][0];
850
  if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
@@ -878,10 +871,10 @@ int llama_context::encode(llama_batch & inp_batch) {
878
 
879
  // remember the sequence ids used during the encoding - needed for cross attention later
880
  cross.seq_ids_enc.resize(n_tokens);
881
- for (int32_t i = 0; i < n_tokens; i++) {
882
  cross.seq_ids_enc[i].clear();
883
- for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
884
- llama_seq_id seq_id = ubatch.seq_id[i][s];
885
  cross.seq_ids_enc[i].insert(seq_id);
886
  }
887
  }
@@ -890,51 +883,45 @@ int llama_context::encode(llama_batch & inp_batch) {
890
  return 0;
891
  }
892
 
893
- int llama_context::decode(llama_batch & inp_batch) {
894
  if (!memory) {
895
  LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
896
- return encode(inp_batch);
897
  }
898
 
899
- if (inp_batch.n_tokens == 0) {
900
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
901
  return -1;
902
  }
903
 
904
- if (!inp_batch.pos) {
905
- if (inp_batch.seq_id) {
906
- LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
907
- return -1;
908
- }
909
- }
910
 
911
- // temporary allocate memory for the input batch if needed
912
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
 
 
913
 
914
- const llama_batch & batch = batch_allocr.batch;
915
 
916
  const auto & vocab = model.vocab;
917
  const auto & hparams = model.hparams;
918
 
919
  const int32_t n_vocab = vocab.n_tokens();
 
920
 
921
- const int64_t n_tokens_all = batch.n_tokens;
922
- const int64_t n_embd = hparams.n_embd;
923
 
924
  GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
925
 
926
- // TODO: move the validation to the llama_batch_allocr
927
- if (batch.token) {
928
- for (int64_t i = 0; i < n_tokens_all; ++i) {
929
- if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
930
- LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
931
- return -1;
932
- }
933
 
934
- if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
935
- LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
936
- return -1;
937
- }
 
 
938
  }
939
  }
940
 
@@ -947,25 +934,9 @@ int llama_context::decode(llama_batch & inp_batch) {
947
  }
948
  n_queued_tokens += n_tokens_all;
949
 
950
- // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
951
- const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
952
-
953
  embd_seq.clear();
954
 
955
- int64_t n_outputs_all = 0;
956
-
957
- // count outputs
958
- if (batch.logits && !embd_pooled) {
959
- for (uint32_t i = 0; i < n_tokens_all; ++i) {
960
- n_outputs_all += batch.logits[i] != 0;
961
- }
962
- } else if (embd_pooled) {
963
- n_outputs_all = n_tokens_all;
964
- } else {
965
- // keep last output only
966
- n_outputs_all = 1;
967
- }
968
-
969
  bool did_optimize = false;
970
 
971
  // handle any pending defrags/shifts
@@ -974,7 +945,7 @@ int llama_context::decode(llama_batch & inp_batch) {
974
  llama_memory_state_ptr mstate;
975
 
976
  while (true) {
977
- mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
978
  if (!mstate) {
979
  return -2;
980
  }
@@ -1018,7 +989,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1018
 
1019
  // reserve output buffer
1020
  if (output_reserve(n_outputs_all) < n_outputs_all) {
1021
- LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
1022
  return -2;
1023
  };
1024
 
@@ -1027,7 +998,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1027
  do {
1028
  const auto & ubatch = mstate->get_ubatch();
1029
 
1030
- // count the outputs in this u_batch
1031
  {
1032
  int32_t n_outputs_new = 0;
1033
 
@@ -1052,18 +1023,19 @@ int llama_context::decode(llama_batch & inp_batch) {
1052
 
1053
  if (!res) {
1054
  // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
1055
- llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
1056
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
1057
  pos_min[s] = std::numeric_limits<llama_pos>::max();
1058
  }
1059
 
 
1060
  for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
1061
  const auto & seq_id = ubatch.seq_id[i][0];
1062
 
1063
  pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
1064
  }
1065
 
1066
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
1067
  if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
1068
  continue;
1069
  }
@@ -1086,7 +1058,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1086
  // ggml_graph_dump_dot(gf, NULL, "llama.dot");
1087
  //}
1088
 
1089
- auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
1090
  auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
1091
 
1092
  if (t_embd && res->get_embd_pooled()) {
@@ -1170,14 +1142,14 @@ int llama_context::decode(llama_batch & inp_batch) {
1170
  n_outputs = n_outputs_all;
1171
 
1172
  // set output mappings
1173
- {
1174
  bool sorted_output = true;
1175
 
1176
  auto & out_ids = mstate->out_ids();
1177
 
1178
- GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
1179
 
1180
- for (int64_t i = 0; i < n_outputs_all; ++i) {
1181
  int64_t out_id = out_ids[i];
1182
  output_ids[out_id] = i;
1183
  if (out_id != i) {
@@ -1189,20 +1161,22 @@ int llama_context::decode(llama_batch & inp_batch) {
1189
  // note: this is mostly relevant for recurrent models atm
1190
  if (!sorted_output) {
1191
  const uint32_t n_vocab = model.vocab.n_tokens();
1192
- const uint32_t n_embd = model.hparams.n_embd;
1193
 
1194
  GGML_ASSERT((size_t) n_outputs == out_ids.size());
1195
 
1196
  // TODO: is there something more efficient which also minimizes swaps?
1197
  // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
1198
- for (int32_t i = 0; i < n_outputs - 1; ++i) {
1199
- int32_t j_min = i;
1200
- for (int32_t j = i + 1; j < n_outputs; ++j) {
1201
  if (out_ids[j] < out_ids[j_min]) {
1202
  j_min = j;
1203
  }
1204
  }
1205
- if (j_min == i) { continue; }
 
 
1206
  std::swap(out_ids[i], out_ids[j_min]);
1207
  if (logits_size > 0) {
1208
  for (uint32_t k = 0; k < n_vocab; k++) {
@@ -1215,8 +1189,10 @@ int llama_context::decode(llama_batch & inp_batch) {
1215
  }
1216
  }
1217
  }
 
1218
  std::fill(output_ids.begin(), output_ids.end(), -1);
1219
- for (int32_t i = 0; i < n_outputs; ++i) {
 
1220
  output_ids[out_ids[i]] = i;
1221
  }
1222
  }
@@ -1236,7 +1212,7 @@ int llama_context::decode(llama_batch & inp_batch) {
1236
  // output
1237
  //
1238
 
1239
- int32_t llama_context::output_reserve(int32_t n_outputs) {
1240
  const auto & hparams = model.hparams;
1241
  const auto & vocab = model.vocab;
1242
 
@@ -1246,9 +1222,8 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
1246
  const auto n_vocab = vocab.n_tokens();
1247
  const auto n_embd = hparams.n_embd;
1248
 
1249
- // TODO: use a per-batch flag for logits presence instead
1250
- bool has_logits = !cparams.embeddings;
1251
- bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
1252
 
1253
  // TODO: hacky enc-dec support
1254
  if (model.arch == LLM_ARCH_T5) {
@@ -1302,8 +1277,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
1302
  // set all ids as invalid (negative)
1303
  std::fill(output_ids.begin(), output_ids.end(), -1);
1304
 
1305
- this->n_outputs = 0;
1306
- this->n_outputs_max = n_outputs_max;
1307
 
1308
  return n_outputs_max;
1309
  }
@@ -1332,7 +1306,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
1332
  LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
1333
 
1334
  if (n_tokens % n_seqs != 0) {
1335
- n_tokens = (n_tokens / n_seqs) * n_seqs;
1336
  n_outputs = std::min(n_outputs, n_tokens);
1337
 
1338
  LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
@@ -1794,14 +1768,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
1794
 
1795
  std::vector<int32_t> w_output_pos;
1796
 
1797
- GGML_ASSERT(n_outputs <= n_outputs_max);
1798
-
1799
  w_output_pos.resize(n_outputs);
1800
 
1801
  // build a more compact representation of the output ids
1802
  for (size_t i = 0; i < n_batch(); ++i) {
1803
  // map an output id to a position in the batch
1804
- int32_t pos = output_ids[i];
1805
  if (pos >= 0) {
1806
  GGML_ASSERT(pos < n_outputs);
1807
  w_output_pos[pos] = i;
@@ -2071,14 +2043,11 @@ void llama_context::opt_epoch_iter(
2071
 
2072
  n_queued_tokens += n_tokens_all;
2073
 
2074
- // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
2075
- const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
2076
-
2077
  embd_seq.clear();
2078
 
2079
- int64_t n_outputs_all = n_tokens_all;
2080
 
2081
- auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
2082
  if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
2083
  LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
2084
  break;
@@ -2086,7 +2055,7 @@ void llama_context::opt_epoch_iter(
2086
 
2087
  // reserve output buffer
2088
  if (output_reserve(n_outputs_all) < n_outputs_all) {
2089
- LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
2090
  GGML_ABORT("TODO: handle this error");
2091
  };
2092
 
 
1
  #include "llama-context.h"
2
 
3
  #include "llama-impl.h"
4
+ #include "llama-batch.h"
5
  #include "llama-io.h"
6
  #include "llama-memory.h"
7
  #include "llama-mmap.h"
 
19
  llama_context::llama_context(
20
  const llama_model & model,
21
  llama_context_params params) :
22
+ model(model),
23
+ batch_allocr(std::make_unique<llama_batch_allocr>()) {
24
  LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
25
 
26
  t_start_us = model.t_start_us;
 
29
  const auto & hparams = model.hparams;
30
 
31
  cparams.n_seq_max = std::max(1u, params.n_seq_max);
32
+ if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
33
+ throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
34
  }
35
 
36
  cparams.n_threads = params.n_threads;
 
496
  }
497
 
498
  float * llama_context::get_logits_ith(int32_t i) {
499
+ int64_t j = -1;
500
 
501
  try {
502
  if (logits == nullptr) {
 
519
  }
520
  if (j >= n_outputs) {
521
  // This should not happen
522
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
523
  }
524
 
525
  return logits + j*model.vocab.n_tokens();
 
538
  }
539
 
540
  float * llama_context::get_embeddings_ith(int32_t i) {
541
+ int64_t j = -1;
542
 
543
  try {
544
  if (embd == nullptr) {
 
561
  }
562
  if (j >= n_outputs) {
563
  // This should not happen
564
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
565
  }
566
 
567
  return embd + j*model.hparams.n_embd;
 
721
  return res;
722
  }
723
 
724
+ int llama_context::encode(const llama_batch & batch_inp) {
725
+ if (batch_inp.n_tokens == 0) {
726
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
727
  return -1;
728
  }
729
 
 
730
  // note: during encode, we always pass the full sequence starting from pos = 0
731
+ if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
732
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
733
+ return -1;
734
+ }
735
 
736
+ const llama_batch & batch = batch_allocr->get_batch();
 
737
 
738
+ const uint32_t n_tokens = batch.n_tokens;
739
 
740
  GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
741
 
 
 
 
 
742
  // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
743
+ GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
744
 
745
  if (t_compute_start_us == 0) {
746
  t_compute_start_us = ggml_time_us();
747
  }
748
 
749
+ // TODO: this clear of the buffer can easily be forgotten - need something better
750
  embd_seq.clear();
751
 
752
  n_queued_tokens += n_tokens;
753
 
754
+ const auto & hparams = model.hparams;
755
+
756
  const int64_t n_embd = hparams.n_embd;
757
 
758
+ llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
759
 
760
  const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
761
 
 
765
  return -2;
766
  };
767
 
768
+ for (uint32_t i = 0; i < n_tokens; ++i) {
769
  output_ids[i] = i;
770
  }
771
 
 
821
 
822
  GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
823
 
824
+ // TODO: fix indexing [UBATCH_IDX]
825
+ for (uint32_t i = 0; i < n_tokens; i++) {
826
  const llama_seq_id seq_id = ubatch.seq_id[i][0];
827
  if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
828
  continue;
 
837
  auto & embd_seq_out = embd_seq;
838
  const uint32_t n_cls_out = hparams.n_cls_out;
839
 
840
+ // TODO: fix indexing [UBATCH_IDX]
841
  for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
842
  const llama_seq_id seq_id = ubatch.seq_id[s][0];
843
  if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
 
871
 
872
  // remember the sequence ids used during the encoding - needed for cross attention later
873
  cross.seq_ids_enc.resize(n_tokens);
874
+ for (uint32_t i = 0; i < n_tokens; i++) {
875
  cross.seq_ids_enc[i].clear();
876
+ for (int s = 0; s < batch.n_seq_id[i]; s++) {
877
+ llama_seq_id seq_id = batch.seq_id[i][s];
878
  cross.seq_ids_enc[i].insert(seq_id);
879
  }
880
  }
 
883
  return 0;
884
  }
885
 
886
+ int llama_context::decode(const llama_batch & batch_inp) {
887
  if (!memory) {
888
  LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
889
+ return encode(batch_inp);
890
  }
891
 
892
+ if (batch_inp.n_tokens == 0) {
893
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
894
  return -1;
895
  }
896
 
897
+ // when computing embeddings, all tokens are output
898
+ const bool embd_all = cparams.embeddings;
899
 
900
+ if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
901
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
902
+ return -1;
903
+ }
904
 
905
+ const llama_batch & batch = batch_allocr->get_batch();
906
 
907
  const auto & vocab = model.vocab;
908
  const auto & hparams = model.hparams;
909
 
910
  const int32_t n_vocab = vocab.n_tokens();
911
+ const int64_t n_embd = hparams.n_embd;
912
 
913
+ const uint32_t n_tokens_all = batch.n_tokens;
 
914
 
915
  GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
916
 
917
+ const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
918
 
919
+ if (embd_all) {
920
+ // require that all tokens are output
921
+ if (n_outputs_all != n_tokens_all) {
922
+ LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
923
+ __func__, n_outputs_all, n_tokens_all);
924
+ return -1;
925
  }
926
  }
927
 
 
934
  }
935
  n_queued_tokens += n_tokens_all;
936
 
937
+ // TODO: this clear of the buffer can easily be forgotten - need something better
 
 
938
  embd_seq.clear();
939
 
 
 
 
 
 
940
  bool did_optimize = false;
941
 
942
  // handle any pending defrags/shifts
 
945
  llama_memory_state_ptr mstate;
946
 
947
  while (true) {
948
+ mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
949
  if (!mstate) {
950
  return -2;
951
  }
 
989
 
990
  // reserve output buffer
991
  if (output_reserve(n_outputs_all) < n_outputs_all) {
992
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
993
  return -2;
994
  };
995
 
 
998
  do {
999
  const auto & ubatch = mstate->get_ubatch();
1000
 
1001
+ // count the outputs in this ubatch
1002
  {
1003
  int32_t n_outputs_new = 0;
1004
 
 
1023
 
1024
  if (!res) {
1025
  // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
1026
+ llama_pos pos_min[LLAMA_MAX_SEQ];
1027
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
1028
  pos_min[s] = std::numeric_limits<llama_pos>::max();
1029
  }
1030
 
1031
+ // TODO: fix sequence indexing
1032
  for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
1033
  const auto & seq_id = ubatch.seq_id[i][0];
1034
 
1035
  pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
1036
  }
1037
 
1038
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
1039
  if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
1040
  continue;
1041
  }
 
1058
  // ggml_graph_dump_dot(gf, NULL, "llama.dot");
1059
  //}
1060
 
1061
+ auto * t_logits = res->get_logits();
1062
  auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
1063
 
1064
  if (t_embd && res->get_embd_pooled()) {
 
1142
  n_outputs = n_outputs_all;
1143
 
1144
  // set output mappings
1145
+ if (n_outputs > 0) {
1146
  bool sorted_output = true;
1147
 
1148
  auto & out_ids = mstate->out_ids();
1149
 
1150
+ GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
1151
 
1152
+ for (int64_t i = 0; i < n_outputs; ++i) {
1153
  int64_t out_id = out_ids[i];
1154
  output_ids[out_id] = i;
1155
  if (out_id != i) {
 
1161
  // note: this is mostly relevant for recurrent models atm
1162
  if (!sorted_output) {
1163
  const uint32_t n_vocab = model.vocab.n_tokens();
1164
+ const uint64_t n_embd = model.hparams.n_embd;
1165
 
1166
  GGML_ASSERT((size_t) n_outputs == out_ids.size());
1167
 
1168
  // TODO: is there something more efficient which also minimizes swaps?
1169
  // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
1170
+ for (uint32_t i = 0; i < n_outputs - 1; ++i) {
1171
+ uint32_t j_min = i;
1172
+ for (uint32_t j = i + 1; j < n_outputs; ++j) {
1173
  if (out_ids[j] < out_ids[j_min]) {
1174
  j_min = j;
1175
  }
1176
  }
1177
+ if (j_min == i) {
1178
+ continue;
1179
+ }
1180
  std::swap(out_ids[i], out_ids[j_min]);
1181
  if (logits_size > 0) {
1182
  for (uint32_t k = 0; k < n_vocab; k++) {
 
1189
  }
1190
  }
1191
  }
1192
+
1193
  std::fill(output_ids.begin(), output_ids.end(), -1);
1194
+
1195
+ for (uint32_t i = 0; i < n_outputs; ++i) {
1196
  output_ids[out_ids[i]] = i;
1197
  }
1198
  }
 
1212
  // output
1213
  //
1214
 
1215
+ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1216
  const auto & hparams = model.hparams;
1217
  const auto & vocab = model.vocab;
1218
 
 
1222
  const auto n_vocab = vocab.n_tokens();
1223
  const auto n_embd = hparams.n_embd;
1224
 
1225
+ bool has_logits = true;
1226
+ bool has_embd = cparams.embeddings;
 
1227
 
1228
  // TODO: hacky enc-dec support
1229
  if (model.arch == LLM_ARCH_T5) {
 
1277
  // set all ids as invalid (negative)
1278
  std::fill(output_ids.begin(), output_ids.end(), -1);
1279
 
1280
+ this->n_outputs = 0;
 
1281
 
1282
  return n_outputs_max;
1283
  }
 
1306
  LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
1307
 
1308
  if (n_tokens % n_seqs != 0) {
1309
+ n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
1310
  n_outputs = std::min(n_outputs, n_tokens);
1311
 
1312
  LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
 
1768
 
1769
  std::vector<int32_t> w_output_pos;
1770
 
 
 
1771
  w_output_pos.resize(n_outputs);
1772
 
1773
  // build a more compact representation of the output ids
1774
  for (size_t i = 0; i < n_batch(); ++i) {
1775
  // map an output id to a position in the batch
1776
+ int64_t pos = output_ids[i];
1777
  if (pos >= 0) {
1778
  GGML_ASSERT(pos < n_outputs);
1779
  w_output_pos[pos] = i;
 
2043
 
2044
  n_queued_tokens += n_tokens_all;
2045
 
 
 
 
2046
  embd_seq.clear();
2047
 
2048
+ uint32_t n_outputs_all = n_tokens_all;
2049
 
2050
+ auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
2051
  if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
2052
  LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
2053
  break;
 
2055
 
2056
  // reserve output buffer
2057
  if (output_reserve(n_outputs_all) < n_outputs_all) {
2058
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
2059
  GGML_ABORT("TODO: handle this error");
2060
  };
2061
 
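
Among the changes above, graph_reserve() now pads n_tokens up to the next multiple of n_seqs instead of truncating it down, so the reserved graph can never end up smaller than the requested token count. A quick standalone check of the two formulas (the helper names here are illustrative, not llama.cpp API):

// Standalone comparison of the old (round-down) and new (round-up) padding
// of n_tokens to a multiple of n_seqs in graph_reserve().
#include <cassert>
#include <cstdint>

static uint32_t round_down_to_multiple(uint32_t n_tokens, uint32_t n_seqs) {
    return (n_tokens / n_seqs) * n_seqs;                  // old behavior
}

static uint32_t round_up_to_multiple(uint32_t n_tokens, uint32_t n_seqs) {
    return ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // new behavior
}

int main() {
    // 10 tokens across 4 sequences: the old code shrank the reservation to
    // 8 tokens, the new code grows it to 12 so nothing is dropped.
    assert(round_down_to_multiple(10, 4) ==  8);
    assert(round_up_to_multiple  (10, 4) == 12);
    return 0;
}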
examples/talk-llama/llama-context.h CHANGED
@@ -1,7 +1,6 @@
 #pragma once
 
 #include "llama.h"
-#include "llama-batch.h"
 #include "llama-cparams.h"
 #include "llama-graph.h"
 #include "llama-adapter.h"
@@ -13,6 +12,7 @@
 #include <vector>
 
 struct llama_model;
+class llama_batch_allocr;
 
 class llama_io_read_i;
 class llama_io_write_i;
@@ -102,8 +102,8 @@ struct llama_context {
             llama_memory_state_i * mstate,
             ggml_status & ret);
 
-    int encode(llama_batch & inp_batch);
-    int decode(llama_batch & inp_batch);
+    int encode(const llama_batch & batch_inp);
+    int decode(const llama_batch & batch_inp);
 
     //
     // state save/load
@@ -181,7 +181,7 @@ private:
 
     // Make sure enough space is available for outputs.
     // Returns max number of outputs for which space was reserved.
-    int32_t output_reserve(int32_t n_outputs);
+    uint32_t output_reserve(int32_t n_outputs);
 
     //
     // graph
@@ -246,8 +246,10 @@ private:
     // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
     std::map<llama_seq_id, std::vector<float>> embd_seq;
 
-    int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
-    int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers
+    // reuse the batch_allocr to avoid unnecessary memory allocations
+    std::unique_ptr<llama_batch_allocr> batch_allocr;
+
+    uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
 
     std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
 
examples/talk-llama/llama-cparams.cpp CHANGED
@@ -1,5 +1,5 @@
 #include "llama-cparams.h"
 
 size_t llama_max_parallel_sequences(void) {
-    return LLAMA_MAX_PARALLEL_SEQUENCES;
+    return LLAMA_MAX_SEQ;
 }
examples/talk-llama/llama-cparams.h CHANGED
@@ -4,7 +4,7 @@
 
 #include <cstdint>
 
-#define LLAMA_MAX_PARALLEL_SEQUENCES 64
+#define LLAMA_MAX_SEQ 64
 
 struct llama_cparams {
     uint32_t n_ctx; // context size used during inference
examples/talk-llama/llama-graph.cpp CHANGED
@@ -139,6 +139,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
139
 
140
  std::vector<uint64_t> sum(n_tokens, 0);
141
 
 
142
  for (int s = 0; s < n_seqs; ++s) {
143
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
144
 
@@ -156,6 +157,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
156
  }
157
  }
158
 
 
159
  for (int s = 0; s < n_seqs; ++s) {
160
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
161
 
@@ -180,6 +182,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
180
  uint32_t * data = (uint32_t *) cls->data;
181
  memset(cls->data, 0, n_tokens * ggml_element_size(cls));
182
 
 
183
  for (int s = 0; s < n_seqs; ++s) {
184
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
185
 
@@ -210,6 +213,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
210
  std::vector<int> last_pos(n_tokens, -1);
211
  std::vector<int> last_row(n_tokens, -1);
212
 
 
213
  for (int s = 0; s < n_seqs; ++s) {
214
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
215
 
@@ -250,22 +254,6 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
250
  }
251
  }
252
 
253
- void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
254
- GGML_UNUSED(ubatch);
255
-
256
- const int64_t n_kv = kv_state->get_n_kv();
257
-
258
- if (s_mask) {
259
- GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
260
- float * data = (float *) s_mask->data;
261
-
262
- // clear unused states
263
- for (int i = 0; i < n_kv; ++i) {
264
- data[i] = kv_state->s_mask(i);
265
- }
266
- }
267
- }
268
-
269
  void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
270
  GGML_UNUSED(ubatch);
271
 
@@ -299,6 +287,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
299
  const int32_t ti = s0*n_seq_tokens + i;
300
  float f = -INFINITY;
301
 
 
302
  for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
303
  if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
304
  if (hparams.use_alibi) {
@@ -338,6 +327,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
338
  const int32_t ti = s0*n_seq_tokens + i;
339
  float f = -INFINITY;
340
 
 
341
  for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
342
  if (ubatch->seq_id[s0][s] == seq_id) {
343
  if (hparams.use_alibi) {
@@ -393,6 +383,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
393
  for (int j = 0; j < n_tokens; ++j) {
394
  for (int i = 0; i < n_enc; ++i) {
395
  float f = -INFINITY;
 
396
  for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
397
  const llama_seq_id seq_id = ubatch->seq_id[j][s];
398
  if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
@@ -650,6 +641,7 @@ ggml_tensor * llm_graph_context::build_ffn(
650
  {
651
  // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
652
  int64_t split_point = cur->ne[0] / 2;
 
653
  ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
654
  ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
655
 
@@ -663,7 +655,7 @@ ggml_tensor * llm_graph_context::build_ffn(
663
  {
664
  // Split into two equal parts
665
  int64_t split_point = cur->ne[0] / 2;
666
- // TODO: these conts should not be needed
667
  ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
668
  ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
669
 
@@ -986,23 +978,6 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
986
  return cur;
987
  }
988
 
989
- ggml_tensor * llm_graph_context::build_inp_s_mask() const {
990
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
991
-
992
- auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
993
-
994
- const auto n_kv = kv_state->get_n_kv();
995
-
996
- auto & cur = inp->s_mask;
997
-
998
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
999
- ggml_set_input(cur);
1000
-
1001
- res->add_input(std::move(inp));
1002
-
1003
- return cur;
1004
- }
1005
-
1006
  ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
1007
  auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
1008
 
@@ -1455,43 +1430,53 @@ ggml_tensor * llm_graph_context::build_attn(
1455
  return cur;
1456
  }
1457
 
1458
- ggml_tensor * llm_graph_context::build_copy_mask_state(
1459
  ggml_cgraph * gf,
1460
  ggml_tensor * s,
1461
  ggml_tensor * state_copy,
1462
- ggml_tensor * state_mask,
1463
- int32_t n_state,
1464
- int32_t n_seqs) const {
1465
  const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1466
 
1467
  const auto n_kv = kv_state->get_n_kv();
1468
  const auto kv_head = kv_state->get_head();
1469
 
1470
- ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
1471
 
1472
- // copy states
1473
- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1474
- // this shrinks the tensors's ne[1] to n_kv
1475
- states = ggml_get_rows(ctx0, states, state_copy);
1476
 
1477
- // clear states of sequences which are starting at the beginning of this batch
1478
- // FIXME: zero-out NANs?
1479
- states = ggml_mul(ctx0, states, state_mask);
1480
 
1481
- // copy states which won't be changed further (between n_seqs and n_kv)
 
1482
  ggml_build_forward_expand(gf,
1483
  ggml_cpy(ctx0,
1484
- ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)),
1485
- ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
1486
 
1487
- // the part of the states that will be used and modified
1488
- return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
1489
  }
1490
 
1491
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1492
  ggml_cgraph * gf,
1493
  ggml_tensor * state_copy,
1494
- ggml_tensor * state_mask,
1495
  const llama_ubatch & ubatch,
1496
  int il) const {
1497
  const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -1502,8 +1487,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1502
 
1503
  ggml_tensor * token_shift_all = kv_state->get_k_l(il);
1504
 
1505
- ggml_tensor * token_shift = build_copy_mask_state(
1506
- gf, token_shift_all, state_copy, state_mask,
1507
  hparams.n_embd_k_s(), n_seqs);
1508
 
1509
  token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
@@ -1578,23 +1563,30 @@ void llm_graph_context::build_pooling(
1578
  ggml_tensor * inp_cls = build_inp_cls();
1579
  inp = ggml_get_rows(ctx0, inp, inp_cls);
1580
 
1581
- if (cls != nullptr && cls_b != nullptr) {
1582
  // classification head
1583
  // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1584
- cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
1585
  cur = ggml_tanh(ctx0, cur);
1586
 
1587
  // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
1588
  // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
1589
  if (cls_out) {
1590
- GGML_ASSERT(cls_out_b != nullptr);
1591
- cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
1592
  }
1593
  } else if (cls_out) {
1594
  // Single layer classification head (direct projection)
1595
  // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
1596
- GGML_ASSERT(cls_out_b != nullptr);
1597
- cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
1598
  } else {
1599
  GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
1600
  }
 
139
 
140
  std::vector<uint64_t> sum(n_tokens, 0);
141
 
142
+ // TODO: fix indexing [UBATCH_IDX]
143
  for (int s = 0; s < n_seqs; ++s) {
144
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
145
 
 
157
  }
158
  }
159
 
160
+ // TODO: fix indexing [UBATCH_IDX]
161
  for (int s = 0; s < n_seqs; ++s) {
162
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
163
 
 
182
  uint32_t * data = (uint32_t *) cls->data;
183
  memset(cls->data, 0, n_tokens * ggml_element_size(cls));
184
 
185
+ // TODO: fix indexing [UBATCH_IDX]
186
  for (int s = 0; s < n_seqs; ++s) {
187
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
188
 
 
213
  std::vector<int> last_pos(n_tokens, -1);
214
  std::vector<int> last_row(n_tokens, -1);
215
 
216
+ // TODO: fix indexing [UBATCH_IDX]
217
  for (int s = 0; s < n_seqs; ++s) {
218
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
219
 
 
254
  }
255
  }
256
257
  void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
258
  GGML_UNUSED(ubatch);
259
 
 
287
  const int32_t ti = s0*n_seq_tokens + i;
288
  float f = -INFINITY;
289
 
290
+ // TODO: fix indexing [UBATCH_IDX]
291
  for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
292
  if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
293
  if (hparams.use_alibi) {
 
327
  const int32_t ti = s0*n_seq_tokens + i;
328
  float f = -INFINITY;
329
 
330
+ // TODO: fix indexing [UBATCH_IDX]
331
  for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
332
  if (ubatch->seq_id[s0][s] == seq_id) {
333
  if (hparams.use_alibi) {
 
383
  for (int j = 0; j < n_tokens; ++j) {
384
  for (int i = 0; i < n_enc; ++i) {
385
  float f = -INFINITY;
386
+ // TODO: fix indexing [UBATCH_IDX]
387
  for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
388
  const llama_seq_id seq_id = ubatch->seq_id[j][s];
389
  if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
 
641
  {
642
  // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
643
  int64_t split_point = cur->ne[0] / 2;
644
+ // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
645
  ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
646
  ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
647
 
 
655
  {
656
  // Split into two equal parts
657
  int64_t split_point = cur->ne[0] / 2;
658
+ // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
659
  ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
660
  ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
661
 
 
978
  return cur;
979
  }
980
981
  ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
982
  auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
983
 
 
1430
  return cur;
1431
  }
1432
 
1433
+ ggml_tensor * llm_graph_context::build_recurrent_state(
1434
  ggml_cgraph * gf,
1435
  ggml_tensor * s,
1436
  ggml_tensor * state_copy,
1437
+ int32_t state_size,
1438
+ int32_t n_seqs,
1439
+ bool avoid_copies) const {
1440
  const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1441
 
1442
  const auto n_kv = kv_state->get_n_kv();
1443
  const auto kv_head = kv_state->get_head();
1444
+ const auto rs_zero = kv_state->get_rs_z();
1445
+
1446
+ ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
1447
 
1448
+ // Clear a single state which will then be copied to the other cleared states.
1449
+ // Note that this is a no-op when the view is zero-sized.
1450
+ ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
1451
+ ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
1452
 
1453
+ ggml_tensor * output_states;
1454
 
1455
+ if (!avoid_copies) {
1456
+ // copy states
1457
+ // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1458
+ // {state_size, kv_size} -> {state_size, n_seqs}
1459
+ output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
1460
+ ggml_build_forward_expand(gf, output_states);
1461
+ } else {
1462
+ // FIXME: make the gathering operation happen before the copy below
1463
+ // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
1464
+ output_states = states;
1465
+ }
1466
 
1467
+ // copy extra states which won't be changed further (between n_seqs and n_kv)
1468
+ ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
1469
  ggml_build_forward_expand(gf,
1470
  ggml_cpy(ctx0,
1471
+ states_extra,
1472
+ ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
1473
 
1474
+ return output_states;
 
1475
  }
1476
 
1477
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1478
  ggml_cgraph * gf,
1479
  ggml_tensor * state_copy,
 
1480
  const llama_ubatch & ubatch,
1481
  int il) const {
1482
  const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
1487
 
1488
  ggml_tensor * token_shift_all = kv_state->get_k_l(il);
1489
 
1490
+ ggml_tensor * token_shift = build_recurrent_state(
1491
+ gf, token_shift_all, state_copy,
1492
  hparams.n_embd_k_s(), n_seqs);
1493
 
1494
  token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
 
1563
  ggml_tensor * inp_cls = build_inp_cls();
1564
  inp = ggml_get_rows(ctx0, inp, inp_cls);
1565
 
1566
+ if (cls) {
1567
  // classification head
1568
  // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1569
+ cur = ggml_mul_mat(ctx0, cls, inp);
1570
+ if (cls_b) {
1571
+ cur = ggml_add(ctx0, cur, cls_b);
1572
+ }
1573
  cur = ggml_tanh(ctx0, cur);
1574
 
1575
  // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
1576
  // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
1577
  if (cls_out) {
1578
+ cur = ggml_mul_mat(ctx0, cls_out, cur);
1579
+ if (cls_out_b) {
1580
+ cur = ggml_add(ctx0, cur, cls_out_b);
1581
+ }
1582
  }
1583
  } else if (cls_out) {
1584
  // Single layer classification head (direct projection)
1585
  // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
1586
+ cur = ggml_mul_mat(ctx0, cls_out, inp);
1587
+ if (cls_out_b) {
1588
+ cur = ggml_add(ctx0, cur, cls_out_b);
1589
+ }
1590
  } else {
1591
  GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
1592
  }
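Note on the pooling change above: both classification heads now add their bias only when the corresponding tensor exists, instead of asserting on it. A minimal sketch of that bias-optional pattern in isolation (only ggml.h is assumed; the helper name apply_head is illustrative and not part of the commit):

#include "ggml.h"

// Multiply by a classification matrix and add the bias only if it is present,
// mirroring the relaxed cls/cls_out handling in build_pooling().
static ggml_tensor * apply_head(ggml_context * ctx, ggml_tensor * w, ggml_tensor * b, ggml_tensor * x) {
    ggml_tensor * y = ggml_mul_mat(ctx, w, x);
    if (b) {
        y = ggml_add(ctx, y, b);
    }
    return y;
}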
examples/talk-llama/llama-graph.h CHANGED
@@ -200,18 +200,6 @@ public:
200
  const llama_kv_cache_recurrent_state * kv_state;
201
  };
202
 
203
- class llm_graph_input_s_mask : public llm_graph_input_i {
204
- public:
205
- llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
206
- virtual ~llm_graph_input_s_mask() = default;
207
-
208
- void set_input(const llama_ubatch * ubatch) override;
209
-
210
- ggml_tensor * s_mask; // F32 [1, n_kv]
211
-
212
- const llama_kv_cache_recurrent_state * kv_state;
213
- };
214
-
215
  class llm_graph_input_cross_embd : public llm_graph_input_i {
216
  public:
217
  llm_graph_input_cross_embd(
@@ -390,7 +378,7 @@ struct llm_graph_params {
390
  const llama_memory_state_i * mstate;
391
  const llama_cross * cross;
392
 
393
- int32_t n_outputs;
394
 
395
  const llm_graph_cb & cb;
396
  };
@@ -424,8 +412,8 @@ struct llm_graph_context {
424
  const float norm_eps;
425
  const float norm_rms_eps;
426
 
427
- const int32_t n_tokens;
428
- const int32_t n_outputs;
429
  const int32_t n_ctx_orig; // yarn
430
 
431
  const enum llama_pooling_type pooling_type;
@@ -521,7 +509,6 @@ struct llm_graph_context {
521
  ggml_tensor * build_inp_mean() const;
522
  ggml_tensor * build_inp_cls() const;
523
  ggml_tensor * build_inp_s_copy() const;
524
- ggml_tensor * build_inp_s_mask() const;
525
 
526
  ggml_tensor * build_inp_cross_embd() const;
527
  ggml_tensor * build_inp_pos_bucket_enc() const;
@@ -606,18 +593,17 @@ struct llm_graph_context {
606
  // recurrent
607
  //
608
 
609
- ggml_tensor * build_copy_mask_state(
610
  ggml_cgraph * gf,
611
  ggml_tensor * s,
612
  ggml_tensor * state_copy,
613
- ggml_tensor * state_mask,
614
- int32_t n_state,
615
- int32_t n_seqs) const;
616
 
617
  ggml_tensor * build_rwkv_token_shift_load(
618
  ggml_cgraph * gf,
619
  ggml_tensor * state_copy,
620
- ggml_tensor * state_mask,
621
  const llama_ubatch & ubatch,
622
  int il) const;
623
 
 
200
  const llama_kv_cache_recurrent_state * kv_state;
201
  };
202
203
  class llm_graph_input_cross_embd : public llm_graph_input_i {
204
  public:
205
  llm_graph_input_cross_embd(
 
378
  const llama_memory_state_i * mstate;
379
  const llama_cross * cross;
380
 
381
+ uint32_t n_outputs;
382
 
383
  const llm_graph_cb & cb;
384
  };
 
412
  const float norm_eps;
413
  const float norm_rms_eps;
414
 
415
+ const int64_t n_tokens;
416
+ const int64_t n_outputs;
417
  const int32_t n_ctx_orig; // yarn
418
 
419
  const enum llama_pooling_type pooling_type;
 
509
  ggml_tensor * build_inp_mean() const;
510
  ggml_tensor * build_inp_cls() const;
511
  ggml_tensor * build_inp_s_copy() const;
 
512
 
513
  ggml_tensor * build_inp_cross_embd() const;
514
  ggml_tensor * build_inp_pos_bucket_enc() const;
 
593
  // recurrent
594
  //
595
 
596
+ ggml_tensor * build_recurrent_state(
597
  ggml_cgraph * gf,
598
  ggml_tensor * s,
599
  ggml_tensor * state_copy,
600
+ int32_t state_size,
601
+ int32_t n_seqs,
602
+ bool avoid_copies = false) const;
603
 
604
  ggml_tensor * build_rwkv_token_shift_load(
605
  ggml_cgraph * gf,
606
  ggml_tensor * state_copy,
 
607
  const llama_ubatch & ubatch,
608
  int il) const;
609
 
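The declaration above replaces build_copy_mask_state(): the state_mask argument is gone and an optional avoid_copies flag (default false) is added. A hedged caller-side sketch, based on the build_rwkv_token_shift_load() implementation shown in llama-graph.cpp above (gf, kv_state, hparams, n_seqs and il are assumed to be in scope of an llm_graph_context member; this is not code from the commit):

// sketch only
ggml_tensor * state_copy  = build_inp_s_copy();
ggml_tensor * token_shift = build_recurrent_state(
        gf, kv_state->get_k_l(il), state_copy,
        hparams.n_embd_k_s(), n_seqs);
// with avoid_copies = true the caller is expected to gather the rows itself,
// skipping the ggml_get_rows() over state_copy inside the helper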
examples/talk-llama/llama-kv-cache-recurrent.cpp CHANGED
@@ -359,18 +359,16 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
359
  return result;
360
  }
361
 
362
- llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
363
- GGML_UNUSED(embd_pooled);
364
-
365
- auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
366
 
367
  std::vector<llama_ubatch> ubatches;
368
 
369
  while (sbatch.n_tokens > 0) {
370
  llama_ubatch ubatch;
371
 
372
- if (embd_pooled) {
373
- // Pooled embeddings cannot be split across ubatches (yet)
374
  ubatch = sbatch.split_seq(n_ubatch);
375
  } else {
376
  ubatch = sbatch.split_equal(n_ubatch);
@@ -406,21 +404,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
406
 
407
  bool success = true;
408
 
409
- // TODO: here we have to verify that all ubatches can fit in the cells
410
- // however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
411
- // during the compute of each ubatch. to reproduce, uncomment the following loop and run:
412
- //
413
- // $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
414
- //
415
- // recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
416
- //
417
- GGML_UNUSED(ubatches);
418
- //for (const auto & ubatch : ubatches) {
419
- // if (!find_slot(ubatch)) {
420
- // success = false;
421
- // break;
422
- // }
423
- //}
424
 
425
  // restore the original state
426
  cells = std::move(org_cells);
@@ -431,14 +420,13 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
431
  }
432
 
433
  bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
434
- const uint32_t n_tokens = ubatch.n_tokens;
435
- const uint32_t n_seqs = ubatch.n_seqs;
436
 
437
  const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
438
 
439
  // if we have enough unused cells before the current head ->
440
  // better to start searching from the beginning of the cache, hoping to fill it
441
- if (head > used + 2*n_tokens) {
442
  head = 0;
443
  }
444
 
@@ -534,16 +522,16 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
534
  empty_cell.src = orig_cell.src;
535
  orig_cell.seq_id.erase(seq_id);
536
  empty_cell.seq_id.insert(seq_id); // will be overwritten
 
537
  }
538
  seq_meta.tail = next_empty_cell;
539
  // find next empty cell
540
  if (s + 1 < n_seqs) {
541
- next_empty_cell += 1;
542
  for (uint32_t i = 0; i < size; ++i) {
 
543
  if (next_empty_cell >= size) { next_empty_cell -= size; }
544
  kv_cell & cell = cells[next_empty_cell];
545
  if (cell.is_empty()) { break; }
546
- next_empty_cell += 1;
547
  }
548
  }
549
  }
@@ -553,8 +541,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
553
 
554
  // gather and re-order
555
  for (uint32_t s = 0; s < n_seqs; ++s) {
556
- int32_t dst_id = s + min;
557
- int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
558
  if (dst_id != src_id) {
559
  kv_cell & dst_cell = cells[dst_id];
560
  kv_cell & src_cell = cells[src_id];
@@ -563,12 +551,14 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
563
  std::swap(dst_cell.src, src_cell.src);
564
  std::swap(dst_cell.seq_id, src_cell.seq_id);
565
 
566
- // swap tails (assuming they NEVER overlap)
567
- for (const llama_seq_id seq_id : src_cell.seq_id) {
568
- cells[seq_id].tail = src_id;
569
- }
570
- for (const llama_seq_id seq_id : dst_cell.seq_id) {
571
- cells[seq_id].tail = dst_id;
572
  }
573
  }
574
  }
@@ -576,7 +566,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
576
  // update the pos of the used seqs
577
  for (uint32_t s = 0; s < n_seqs; ++s) {
578
  const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
579
- int32_t cell_id = s + min;
580
  kv_cell & cell = cells[cell_id];
581
 
582
  if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
@@ -594,6 +584,38 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
594
  }
595
  }
596
597
  // allow getting the range of used cells, from head to head + n
598
  head = min;
599
  n = max - min + 1;
@@ -605,47 +627,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
605
  }
606
 
607
  bool llama_kv_cache_recurrent::get_can_shift() const {
608
- return false;
609
- }
610
-
611
- int32_t llama_kv_cache_recurrent::s_copy(int i) const {
612
- const uint32_t cell_id = i + head;
613
-
614
- //////////////////////////////////////////////
615
- // TODO: this should not mutate the KV cache !
616
- kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
617
-
618
- // prevent out-of-bound sources
619
- if (cell.src < 0 || (uint32_t) cell.src >= size) {
620
- cell.src = cell_id;
621
- }
622
-
623
- int32_t res = cell.src;
624
-
625
- // TODO: do not mutate the KV cache
626
- // ensure copy only happens once
627
- if (cell.src != (int32_t) cell_id) {
628
- cell.src = cell_id;
629
- }
630
-
631
- return res;
632
- }
633
-
634
- float llama_kv_cache_recurrent::s_mask(int i) const {
635
- const uint32_t cell_id = i + head;
636
-
637
- //////////////////////////////////////////////
638
- // TODO: this should not mutate the KV cache !
639
- kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
640
-
641
- float res = (float) (cell.src >= 0);
642
-
643
- // only clear once
644
- if (cell.src < 0) {
645
- cell.src = cell_id;
646
- }
647
-
648
- return res;
649
  }
650
 
651
  size_t llama_kv_cache_recurrent::total_size() const {
@@ -1111,6 +1094,10 @@ uint32_t llama_kv_cache_recurrent_state::get_head() const {
1111
  return is_full ? 0 : kv->head;
1112
  }
1113
 
 
 
 
 
1114
  uint32_t llama_kv_cache_recurrent_state::get_size() const {
1115
  return kv->size;
1116
  }
@@ -1124,9 +1111,5 @@ ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
1124
  }
1125
 
1126
  int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
1127
- return kv->s_copy(i);
1128
- }
1129
-
1130
- float llama_kv_cache_recurrent_state::s_mask(int i) const {
1131
- return kv->s_mask(i);
1132
  }
 
359
  return result;
360
  }
361
 
362
+ llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
363
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
364
 
365
  std::vector<llama_ubatch> ubatches;
366
 
367
  while (sbatch.n_tokens > 0) {
368
  llama_ubatch ubatch;
369
 
370
+ if (embd_all) {
371
+ // if all tokens are output, split by sequence
372
  ubatch = sbatch.split_seq(n_ubatch);
373
  } else {
374
  ubatch = sbatch.split_equal(n_ubatch);
 
404
 
405
  bool success = true;
406
 
407
+ for (const auto & ubatch : ubatches) {
408
+ if (!find_slot(ubatch)) {
409
+ success = false;
410
+ break;
411
+ }
412
+ }
413
 
414
  // restore the original state
415
  cells = std::move(org_cells);
 
420
  }
421
 
422
  bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
423
+ const uint32_t n_seqs = ubatch.n_seqs;
 
424
 
425
  const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
426
 
427
  // if we have enough unused cells before the current head ->
428
  // better to start searching from the beginning of the cache, hoping to fill it
429
+ if (head > used + 2*n_seqs) {
430
  head = 0;
431
  }
432
 
 
522
  empty_cell.src = orig_cell.src;
523
  orig_cell.seq_id.erase(seq_id);
524
  empty_cell.seq_id.insert(seq_id); // will be overwritten
525
+ GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
526
  }
527
  seq_meta.tail = next_empty_cell;
528
  // find next empty cell
529
  if (s + 1 < n_seqs) {
 
530
  for (uint32_t i = 0; i < size; ++i) {
531
+ next_empty_cell += 1;
532
  if (next_empty_cell >= size) { next_empty_cell -= size; }
533
  kv_cell & cell = cells[next_empty_cell];
534
  if (cell.is_empty()) { break; }
 
535
  }
536
  }
537
  }
 
541
 
542
  // gather and re-order
543
  for (uint32_t s = 0; s < n_seqs; ++s) {
544
+ const int32_t dst_id = s + min;
545
+ const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
546
  if (dst_id != src_id) {
547
  kv_cell & dst_cell = cells[dst_id];
548
  kv_cell & src_cell = cells[src_id];
 
551
  std::swap(dst_cell.src, src_cell.src);
552
  std::swap(dst_cell.seq_id, src_cell.seq_id);
553
 
554
+ // swap tails
555
+ for (uint32_t i = 0; i < size; ++i) {
556
+ int32_t & tail = cells[i].tail;
557
+ if (tail == src_id) {
558
+ tail = dst_id;
559
+ } else if (tail == dst_id) {
560
+ tail = src_id;
561
+ }
562
  }
563
  }
564
  }
 
566
  // update the pos of the used seqs
567
  for (uint32_t s = 0; s < n_seqs; ++s) {
568
  const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
569
+ const int32_t cell_id = s + min;
570
  kv_cell & cell = cells[cell_id];
571
 
572
  if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
 
584
  }
585
  }
586
 
587
+ // Find first cell without src refs, to use as the zero-ed state
588
+ {
589
+ // TODO: bake-in src refcounts in the cell metadata
590
+ std::vector<int32_t> refcounts(size, 0);
591
+ for (size_t i = 0; i < size; ++i) {
592
+ const int32_t src = cells[i].src;
593
+ if (src >= 0) {
594
+ refcounts[src] += 1;
595
+ }
596
+ }
597
+
598
+ rs_z = -1;
599
+ for (int i = min; i <= max; ++i) {
600
+ if (refcounts[i] == 0) {
601
+ rs_z = i;
602
+ break;
603
+ }
604
+ }
605
+
606
+ for (int i = min; i <= max; ++i) {
607
+ if (cells[i].src < 0) {
608
+ GGML_ASSERT(rs_z >= 0);
609
+ cells[i].src0 = rs_z;
610
+ } else {
611
+ // Stage the source ids for all used cells to allow correct seq_* behavior
612
+ // and still make these values available when setting the inputs
613
+ cells[i].src0 = cells[i].src;
614
+ }
615
+ cells[i].src = i; // avoid moving or clearing twice
616
+ }
617
+ }
618
+
619
  // allow getting the range of used cells, from head to head + n
620
  head = min;
621
  n = max - min + 1;
 
627
  }
628
 
629
  bool llama_kv_cache_recurrent::get_can_shift() const {
630
+ // shifting the pos is trivial for recurrent models
631
+ return true;
632
  }
633
 
634
  size_t llama_kv_cache_recurrent::total_size() const {
 
1094
  return is_full ? 0 : kv->head;
1095
  }
1096
 
1097
+ int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
1098
+ return is_full ? 0 : kv->rs_z;
1099
+ }
1100
+
1101
  uint32_t llama_kv_cache_recurrent_state::get_size() const {
1102
  return kv->size;
1103
  }
 
1111
  }
1112
 
1113
  int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
1114
+ return kv->cells[i + kv->head].src0;
1115
  }
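The block added to find_slot() above builds per-cell src refcounts and picks the first cell in [min, max] with no remaining references as the zero-ed state (rs_z). The same scan reduced to a self-contained sketch over plain vectors (illustrative only, not the cache code itself):

#include <cstdint>
#include <vector>

// Return the first index in [mn, mx] whose state no other cell copies from,
// or -1 if every candidate is still referenced (mirrors the rs_z scan above).
static int32_t find_zeroed_state(const std::vector<int32_t> & src, int mn, int mx) {
    std::vector<int32_t> refcounts(src.size(), 0);
    for (int32_t s : src) {
        if (s >= 0) {
            refcounts[s] += 1;
        }
    }
    for (int i = mn; i <= mx; ++i) {
        if (refcounts[i] == 0) {
            return i;
        }
    }
    return -1;
}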
examples/talk-llama/llama-kv-cache-recurrent.h CHANGED
@@ -32,8 +32,7 @@ public:
32
  llama_memory_state_ptr init_batch(
33
  const llama_batch & batch,
34
  uint32_t n_ubatch,
35
- bool embd_pooled,
36
- bool logits_all) override;
37
 
38
  llama_memory_state_ptr init_full() override;
39
 
@@ -57,10 +56,6 @@ public:
57
 
58
  bool get_can_shift() const override;
59
 
60
- // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
61
- int32_t s_copy(int i) const;
62
- float s_mask(int i) const;
63
-
64
  // state write/load
65
 
66
  void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -73,10 +68,14 @@ public:
73
  // computed before each graph build
74
  uint32_t n = 0;
75
76
  // TODO: optimize for recurrent state needs
77
  struct kv_cell {
78
  llama_pos pos = -1;
79
- int32_t src = -1; // used to copy states
 
80
  int32_t tail = -1;
81
 
82
  std::set<llama_seq_id> seq_id;
@@ -157,13 +156,13 @@ public:
157
 
158
  uint32_t get_n_kv() const;
159
  uint32_t get_head() const;
 
160
  uint32_t get_size() const;
161
 
162
  ggml_tensor * get_k_l(int32_t il) const;
163
  ggml_tensor * get_v_l(int32_t il) const;
164
 
165
  int32_t s_copy(int i) const;
166
- float s_mask(int i) const;
167
 
168
  private:
169
  const llama_memory_status status;
 
32
  llama_memory_state_ptr init_batch(
33
  const llama_batch & batch,
34
  uint32_t n_ubatch,
35
+ bool embd_all) override;
 
36
 
37
  llama_memory_state_ptr init_full() override;
38
 
 
56
 
57
  bool get_can_shift() const override;
58
 
59
  // state write/load
60
 
61
  void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
 
68
  // computed before each graph build
69
  uint32_t n = 0;
70
 
71
+ // first zero-ed state
72
+ int32_t rs_z = -1;
73
+
74
  // TODO: optimize for recurrent state needs
75
  struct kv_cell {
76
  llama_pos pos = -1;
77
+ int32_t src = -1; // used to know where states should be copied from
78
+ int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
79
  int32_t tail = -1;
80
 
81
  std::set<llama_seq_id> seq_id;
 
156
 
157
  uint32_t get_n_kv() const;
158
  uint32_t get_head() const;
159
+ int32_t get_rs_z() const;
160
  uint32_t get_size() const;
161
 
162
  ggml_tensor * get_k_l(int32_t il) const;
163
  ggml_tensor * get_v_l(int32_t il) const;
164
 
165
  int32_t s_copy(int i) const;
 
166
 
167
  private:
168
  const llama_memory_status status;
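The new src0 field above stages the copy source once per ubatch, while src is immediately re-pointed at the cell itself so a state is not moved or cleared twice. A toy version of that staging step (plain C++, mirroring the loop added to find_slot(); rs_z is the zero-ed state index found there):

#include <cstdint>
#include <vector>

struct cell { int32_t src = -1; int32_t src0 = -1; };

// Stage copy sources once: remember where to copy from in src0, then point
// src at the cell itself so later passes do not copy or clear again.
static void stage_sources(std::vector<cell> & cells, int mn, int mx, int32_t rs_z) {
    for (int i = mn; i <= mx; ++i) {
        cells[i].src0 = (cells[i].src < 0) ? rs_z : cells[i].src;
        cells[i].src  = i; // avoid moving or clearing twice
    }
}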
examples/talk-llama/llama-kv-cache-unified-iswa.cpp CHANGED
@@ -95,36 +95,69 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
95
  return kv_swa->seq_pos_max(seq_id);
96
  }
97
 
98
- llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
99
- GGML_UNUSED(embd_pooled);
100
 
101
- // TODO: if we fail with split_simple, we should attempt different splitting strategies
102
- // but to do that properly, we first have to refactor the batches to be more flexible
 
103
 
104
- auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
105
 
106
- std::vector<llama_ubatch> ubatches;
 
107
 
108
- while (sbatch.n_tokens > 0) {
109
- auto ubatch = sbatch.split_simple(n_ubatch);
110
 
111
- ubatches.push_back(ubatch);
112
- }
113
 
114
- auto heads_base = kv_base->prepare(ubatches);
115
- if (heads_base.empty()) {
116
- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
117
- }
118
 
119
- auto heads_swa = kv_swa->prepare(ubatches);
120
- if (heads_swa.empty()) {
121
- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
122
- }
123
 
124
- assert(heads_base.size() == heads_swa.size());
125
 
126
- return std::make_unique<llama_kv_cache_unified_iswa_state>(
127
- this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
128
  }
129
 
130
  llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
 
95
  return kv_swa->seq_pos_max(seq_id);
96
  }
97
 
98
+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
99
+ GGML_UNUSED(embd_all);
100
 
101
+ // first try simple split
102
+ do {
103
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
104
 
105
+ std::vector<llama_ubatch> ubatches;
106
 
107
+ while (sbatch.n_tokens > 0) {
108
+ auto ubatch = sbatch.split_simple(n_ubatch);
109
 
110
+ ubatches.push_back(ubatch);
111
+ }
112
 
113
+ auto heads_base = kv_base->prepare(ubatches);
114
+ if (heads_base.empty()) {
115
+ break;
116
+ }
117
 
118
+ auto heads_swa = kv_swa->prepare(ubatches);
119
+ if (heads_swa.empty()) {
120
+ break;
121
+ }
122
 
123
+ assert(heads_base.size() == heads_swa.size());
124
+
125
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(
126
+ this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
127
+ } while (false);
128
+
129
+ // if it fails, try equal split
130
+ do {
131
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
132
+
133
+ std::vector<llama_ubatch> ubatches;
134
 
135
+ while (sbatch.n_tokens > 0) {
136
+ auto ubatch = sbatch.split_equal(n_ubatch);
137
+
138
+ ubatches.push_back(ubatch);
139
+ }
140
+
141
+ auto heads_base = kv_base->prepare(ubatches);
142
+ if (heads_base.empty()) {
143
+ break;
144
+ }
145
+
146
+ auto heads_swa = kv_swa->prepare(ubatches);
147
+ if (heads_swa.empty()) {
148
+ break;
149
+ }
150
+
151
+ assert(heads_base.size() == heads_swa.size());
152
+
153
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(
154
+ this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
155
+ } while (false);
156
+
157
+ // TODO: if we fail again, we should attempt different splitting strategies
158
+ // but to do that properly, we first have to refactor the batches to be more flexible
159
 
160
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 
161
  }
162
 
163
  llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
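init_batch() above tries the simple split first and, only if prepare() fails for either cache, falls through to an equal split; each attempt sits in a do { ... } while (false) block so a break abandons it cleanly. The control-flow idiom on its own (hypothetical try_* callbacks, not the real API):

#include <memory>

struct prepared_state { /* ... */ };

using attempt_fn = std::unique_ptr<prepared_state> (*)();

// Run each strategy in turn; a failed attempt breaks out of its block and
// falls through to the next one, like the two blocks in init_batch() above.
static std::unique_ptr<prepared_state> init_with_fallback(attempt_fn try_simple, attempt_fn try_equal) {
    do {
        auto st = try_simple();
        if (!st) {
            break;
        }
        return st;
    } while (false);

    do {
        auto st = try_equal();
        if (!st) {
            break;
        }
        return st;
    } while (false);

    return nullptr; // both strategies failed
}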
examples/talk-llama/llama-kv-cache-unified-iswa.h CHANGED
@@ -34,8 +34,7 @@ public:
34
  llama_memory_state_ptr init_batch(
35
  const llama_batch & batch,
36
  uint32_t n_ubatch,
37
- bool embd_pooled,
38
- bool logits_all) override;
39
 
40
  llama_memory_state_ptr init_full() override;
41
 
 
34
  llama_memory_state_ptr init_batch(
35
  const llama_batch & batch,
36
  uint32_t n_ubatch,
37
+ bool embd_all) override;
 
38
 
39
  llama_memory_state_ptr init_full() override;
40
 
examples/talk-llama/llama-kv-cache-unified.cpp CHANGED
@@ -127,6 +127,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
127
  ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
128
  ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
129
  }
130
  }
131
 
132
  void llama_kv_cache_unified::clear(bool data) {
@@ -307,24 +310,27 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
307
  llama_memory_state_ptr llama_kv_cache_unified::init_batch(
308
  const llama_batch & batch,
309
  uint32_t n_ubatch,
310
- bool embd_pooled,
311
- bool logits_all) {
312
- GGML_UNUSED(embd_pooled);
313
 
314
- auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
315
 
316
- std::vector<llama_ubatch> ubatches;
317
- while (sbatch.n_tokens > 0) {
318
- ubatches.push_back(sbatch.split_simple(n_ubatch));
319
- }
320
 
321
- auto heads = prepare(ubatches);
322
- if (heads.empty()) {
323
- return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
324
- }
325
 
326
- return std::make_unique<llama_kv_cache_unified_state>(
327
- this, std::move(sbatch), std::move(heads), std::move(ubatches));
328
  }
329
 
330
  llama_memory_state_ptr llama_kv_cache_unified::init_full() {
@@ -512,43 +518,68 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
512
  head_cur = 0;
513
  }
514
 
515
- // otherwise, one cell per token.
516
-
517
  if (n_tokens > cells.size()) {
518
  LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
519
  return -1;
520
  }
521
 
522
- //#define FIND_SLOT_DEBUG 1
523
- #if FIND_SLOT_DEBUG
524
- LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", cells.used_max_p1(), cells.get_used(), head, n_swa);
525
 
526
- // for debugging
527
- {
528
- std::string ss;
529
- if (n_swa > 0) {
530
  for (uint32_t i = 0; i < cells.size(); ++i) {
531
  if (cells.is_empty(i)) {
532
  ss += '.';
533
  } else {
534
- ss += std::to_string(cells.seq_get(i));
 
 
535
  }
536
  if (i%256 == 255) {
 
537
  ss += '\n';
538
  }
539
  }
 
540
  }
541
- LLAMA_LOG_WARN("\n%s\n", ss.c_str());
542
- }
543
 
544
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
545
- if (cells.seq_pos_min(s) < 0) {
546
- continue;
 
 
547
  }
548
 
549
- LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[%d] = %5d, max[%d] = %5d\n", n_swa, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
 
550
  }
551
- #endif
552
 
553
  uint32_t n_tested = 0;
554
 
@@ -559,21 +590,15 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
559
  continue;
560
  }
561
 
562
- // keep track of what the minimum sequence positions would be if we accept the ubatch
563
- llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
564
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
565
- seq_pos_min[s] = cells.seq_pos_min(s);
566
- }
567
-
568
  bool found = true;
569
  for (uint32_t i = 0; i < n_tokens; i++) {
570
- const llama_pos pos = ubatch.pos[i];
571
- const llama_seq_id seq_id = ubatch.seq_id[i][0];
572
 
573
  // can we use this cell? either:
574
  // - the cell is empty
575
  // - the cell is occupied only by one sequence:
576
- // - mask causally, if the sequence is the same as the one we are inserting
577
  // - mask SWA, using current max pos for that sequence in the cache
578
  // always insert in the cell with minimum pos
579
  bool can_use = cells.is_empty(head_cur + i);
@@ -581,21 +606,17 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
581
  if (!can_use && cells.seq_count(head_cur + i) == 1) {
582
  const llama_pos pos_cell = cells.pos_get(head_cur + i);
583
 
584
- // causal mask
585
- if (cells.seq_has(head_cur + i, seq_id)) {
586
- can_use = pos_cell >= pos;
587
- }
 
588
 
589
  if (!can_use) {
590
  const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
591
 
592
  // SWA mask
593
- // note: we insert only in the cell with minimum pos in order to preserve the invariant that
594
- // all positions between [pos_min, pos_max] for each sequence will be present in the cache
595
- // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
596
- if (pos_cell == seq_pos_min[seq_id_cell] &&
597
- is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
598
- seq_pos_min[seq_id_cell]++;
599
  can_use = true;
600
  }
601
  }
@@ -623,18 +644,58 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
623
  }
624
 
625
  void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
626
- for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
627
- if (!cells.is_empty(head_cur + i)) {
628
- cells.rm(head_cur + i);
629
- }
 
 
630
 
631
- cells.pos_set(head_cur + i, ubatch.pos[i]);
632
 
633
- for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
634
- cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
 
 
635
  }
636
  }
637
 
 
 
638
  // move the head at the end of the slot
639
  head = head_cur + ubatch.n_tokens;
640
  }
@@ -731,14 +792,14 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
731
  }
732
 
733
  void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
734
- const int64_t n_tokens = ubatch->n_tokens;
735
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
736
- const int64_t n_seqs = ubatch->n_seqs;
737
 
738
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
739
  float * data = (float *) dst->data;
740
 
741
- const auto n_kv = dst->ne[0];
742
 
743
  // Use only the previous KV cells of the correct sequence for each token of the ubatch.
744
  // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -752,12 +813,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
752
  // xxxxx-----
753
  // xxxxx-----
754
  // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
755
- for (int h = 0; h < 1; ++h) {
756
- for (int s = 0; s < n_seqs; ++s) {
757
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
758
 
759
- for (int j = 0; j < n_seq_tokens; ++j) {
760
- const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
 
 
761
 
762
  for (uint32_t i = 0; i < n_kv; ++i) {
763
  float f = 0.0f;
@@ -787,16 +850,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
787
  f = -INFINITY;
788
  }
789
 
790
- data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
791
  }
792
  }
793
  }
794
 
795
  // mask padded tokens
796
  if (data) {
797
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
798
- for (uint32_t j = 0; j < n_kv; ++j) {
799
- data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
800
  }
801
  }
802
  }
@@ -1447,9 +1510,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1447
  seq_rm(dest_seq_id, -1, -1);
1448
 
1449
  llama_sbatch sbatch;
1450
- llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1451
 
1452
- batch.n_tokens = cell_count;
 
 
1453
 
1454
  for (uint32_t i = 0; i < cell_count; ++i) {
1455
  llama_pos pos;
@@ -1469,18 +1534,18 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1469
  io.read_to(&seq_id, sizeof(seq_id));
1470
  }
1471
 
1472
- batch.pos[i] = pos;
1473
- batch.n_seq_id[i] = n_seq_id;
1474
- batch.seq_id[i] = &dest_seq_id;
1475
  }
1476
 
1477
- const auto head_cur = find_slot(batch);
1478
  if (head_cur < 0) {
1479
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1480
  return false;
1481
  }
1482
 
1483
- apply_ubatch(head_cur, batch);
1484
 
1485
  // keep the head at the old position because we will read the KV data into it in state_read_data()
1486
  head = head_cur;
@@ -1488,8 +1553,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1488
  // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
1489
  // Assume that this is one contiguous block of cells
1490
  GGML_ASSERT(head_cur + cell_count <= cells.size());
1491
- GGML_ASSERT(cells.pos_get(head_cur) == batch.pos[0]);
1492
- GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]);
1493
  GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
1494
  GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
1495
  } else {
@@ -1674,7 +1739,7 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1674
  llama_context * lctx,
1675
  bool do_shift,
1676
  defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
1677
- if (!do_shift && dinfo.empty()) {
1678
  status = LLAMA_MEMORY_STATUS_NO_UPDATE;
1679
  }
1680
  }
 
127
  ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
128
  ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
129
  }
130
+
131
+ const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
132
+ debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
133
  }
134
 
135
  void llama_kv_cache_unified::clear(bool data) {
 
310
  llama_memory_state_ptr llama_kv_cache_unified::init_batch(
311
  const llama_batch & batch,
312
  uint32_t n_ubatch,
313
+ bool embd_all) {
314
+ GGML_UNUSED(embd_all);
 
315
 
316
+ do {
317
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
318
 
319
+ std::vector<llama_ubatch> ubatches;
320
+ while (sbatch.n_tokens > 0) {
321
+ ubatches.push_back(sbatch.split_simple(n_ubatch));
322
+ }
323
 
324
+ auto heads = prepare(ubatches);
325
+ if (heads.empty()) {
326
+ break;
327
+ }
328
+
329
+ return std::make_unique<llama_kv_cache_unified_state>(
330
+ this, std::move(sbatch), std::move(heads), std::move(ubatches));
331
+ } while (false);
332
 
333
+ return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 
334
  }
335
 
336
  llama_memory_state_ptr llama_kv_cache_unified::init_full() {
 
518
  head_cur = 0;
519
  }
520
 
 
 
521
  if (n_tokens > cells.size()) {
522
  LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
523
  return -1;
524
  }
525
 
526
+ if (debug > 0) {
527
+ LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
 
528
 
529
+ if ((debug == 2 && n_swa > 0) || debug > 2) {
530
+ std::string ss;
 
 
531
  for (uint32_t i = 0; i < cells.size(); ++i) {
532
  if (cells.is_empty(i)) {
533
  ss += '.';
534
  } else {
535
+ assert(cells.seq_count(i) >= 1);
536
+
537
+ if (cells.seq_count(i) == 1) {
538
+ ss += std::to_string(cells.seq_get(i));
539
+ } else {
540
+ ss += 'M';
541
+ }
542
  }
543
  if (i%256 == 255) {
544
+ ss += " *";
545
  ss += '\n';
546
  }
547
  }
548
+ LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
549
  }
 
 
550
 
551
+ if ((debug == 2 && n_swa > 0) || debug > 2) {
552
+ std::string ss;
553
+ for (uint32_t i = 0; i < cells.size(); ++i) {
554
+ std::string cur;
555
+ if (cells.is_empty(i)) {
556
+ cur = '.';
557
+ } else {
558
+ cur = std::to_string(cells.pos_get(i));
559
+ }
560
+ const int n = cur.size();
561
+ for (int j = 0; j < 5 - n; ++j) {
562
+ cur += ' ';
563
+ }
564
+ ss += cur;
565
+ if (i%256 == 255) {
566
+ ss += " *";
567
+ }
568
+ if (i%64 == 63) {
569
+ ss += '\n';
570
+ }
571
+ }
572
+ LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
573
  }
574
 
575
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
576
+ if (cells.seq_pos_min(s) < 0) {
577
+ continue;
578
+ }
579
+
580
+ LLAMA_LOG_DEBUG("%s: min[%d] = %5d, max[%d] = %5d\n", __func__, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
581
+ }
582
  }
 
583
 
584
  uint32_t n_tested = 0;
585
 
 
590
  continue;
591
  }
592
 
 
 
 
 
 
 
593
  bool found = true;
594
  for (uint32_t i = 0; i < n_tokens; i++) {
595
+ //const llama_pos pos = ubatch.pos[i];
596
+ //const llama_seq_id seq_id = ubatch.seq_id[i][0];
597
 
598
  // can we use this cell? either:
599
  // - the cell is empty
600
  // - the cell is occupied only by one sequence:
601
+ // - (disabled) mask causally, if the sequence is the same as the one we are inserting
602
  // - mask SWA, using current max pos for that sequence in the cache
603
  // always insert in the cell with minimum pos
604
  bool can_use = cells.is_empty(head_cur + i);
 
606
  if (!can_use && cells.seq_count(head_cur + i) == 1) {
607
  const llama_pos pos_cell = cells.pos_get(head_cur + i);
608
 
609
+ // (disabled) causal mask
610
+ // note: it's better to purge any "future" tokens beforehand
611
+ //if (cells.seq_has(head_cur + i, seq_id)) {
612
+ // can_use = pos_cell >= pos;
613
+ //}
614
 
615
  if (!can_use) {
616
  const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
617
 
618
  // SWA mask
619
+ if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
 
 
620
  can_use = true;
621
  }
622
  }
 
644
  }
645
 
646
  void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
647
+ if (debug > 0) {
648
+ LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
649
+ LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
650
+ LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
651
+ }
652
+
653
+ // keep track of the max sequence position that we would overwrite with this ubatch
654
+ // for non-SWA cache, this would be always empty
655
+ llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
656
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
657
+ seq_pos_max_rm[s] = -1;
658
+ }
659
+
660
+ for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
661
+ for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
662
+ const uint32_t idx = s*ubatch.n_seq_tokens + j;
663
+
664
+ if (!cells.is_empty(head_cur + idx)) {
665
+ assert(cells.seq_count(head_cur + idx) == 1);
666
+
667
+ const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
668
+ const llama_pos pos = cells.pos_get(head_cur + idx);
669
 
670
+ seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
671
 
672
+ cells.rm(head_cur + idx);
673
+ }
674
+
675
+ cells.pos_set(head_cur + idx, ubatch.pos[idx]);
676
+
677
+ // TODO: fix indexing [UBATCH_IDX]
678
+ for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
679
+ cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
680
+ }
681
  }
682
  }
683
 
684
+ // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
685
+ // will be present in the cache. so we have to purge any position which is less than those we would overwrite
686
+ // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
687
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
688
+ if (seq_pos_max_rm[s] == -1) {
689
+ continue;
690
+ }
691
+
692
+ if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
693
+ LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
694
+ __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
695
+
696
+ seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
697
+ }
698
+ }
699
  // move the head at the end of the slot
700
  head = head_cur + ubatch.n_tokens;
701
  }
 
792
  }
793
 
794
  void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
795
+ const uint32_t n_tokens = ubatch->n_tokens;
796
+ const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
797
+ const uint32_t n_seqs = ubatch->n_seqs;
798
 
799
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
800
  float * data = (float *) dst->data;
801
 
802
+ const int64_t n_kv = dst->ne[0];
803
 
804
  // Use only the previous KV cells of the correct sequence for each token of the ubatch.
805
  // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
 
813
  // xxxxx-----
814
  // xxxxx-----
815
  // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
816
+ for (uint32_t h = 0; h < 1; ++h) {
817
+ for (uint32_t s = 0; s < n_seqs; ++s) {
818
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
819
 
820
+ for (uint32_t j = 0; j < n_seq_tokens; ++j) {
821
+ const uint32_t idx = s*n_seq_tokens + j;
822
+
823
+ const llama_pos p1 = ubatch->pos[idx];
824
 
825
  for (uint32_t i = 0; i < n_kv; ++i) {
826
  float f = 0.0f;
 
850
  f = -INFINITY;
851
  }
852
 
853
+ data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
854
  }
855
  }
856
  }
857
 
858
  // mask padded tokens
859
  if (data) {
860
+ for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
861
+ for (uint32_t i = 0; i < n_kv; ++i) {
862
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
863
  }
864
  }
865
  }
 
1510
  seq_rm(dest_seq_id, -1, -1);
1511
 
1512
  llama_sbatch sbatch;
1513
+ llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1514
 
1515
+ ubatch.n_tokens = cell_count;
1516
+ ubatch.n_seq_tokens = cell_count;
1517
+ ubatch.n_seqs = 1;
1518
 
1519
  for (uint32_t i = 0; i < cell_count; ++i) {
1520
  llama_pos pos;
 
1534
  io.read_to(&seq_id, sizeof(seq_id));
1535
  }
1536
 
1537
+ ubatch.pos[i] = pos;
1538
+ ubatch.n_seq_id[i] = n_seq_id;
1539
+ ubatch.seq_id[i] = &dest_seq_id;
1540
  }
1541
 
1542
+ const auto head_cur = find_slot(ubatch);
1543
  if (head_cur < 0) {
1544
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1545
  return false;
1546
  }
1547
 
1548
+ apply_ubatch(head_cur, ubatch);
1549
 
1550
  // keep the head at the old position because we will read the KV data into it in state_read_data()
1551
  head = head_cur;
 
1553
  // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
1554
  // Assume that this is one contiguous block of cells
1555
  GGML_ASSERT(head_cur + cell_count <= cells.size());
1556
+ GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
1557
+ GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
1558
  GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
1559
  GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
1560
  } else {
 
1739
  llama_context * lctx,
1740
  bool do_shift,
1741
  defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
1742
+ if (!do_shift && this->dinfo.empty()) {
1743
  status = LLAMA_MEMORY_STATUS_NO_UPDATE;
1744
  }
1745
  }
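The constructor change above reads LLAMA_KV_CACHE_DEBUG once and stores it in the new debug member, so exporting LLAMA_KV_CACHE_DEBUG=1 (or higher) before running enables the extra find_slot()/apply_ubatch() logging. The getenv/atoi pattern in isolation (generic sketch, not the cache code itself):

#include <cstdlib>

// Read an integer debug level from the environment, defaulting to 0 when the
// variable is unset -- the same pattern used for LLAMA_KV_CACHE_DEBUG above.
static int read_debug_level(const char * name) {
    const char * val = std::getenv(name);
    return val ? std::atoi(val) : 0;
}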
examples/talk-llama/llama-kv-cache-unified.h CHANGED
@@ -59,8 +59,7 @@ public:
59
  llama_memory_state_ptr init_batch(
60
  const llama_batch & batch,
61
  uint32_t n_ubatch,
62
- bool embd_pooled,
63
- bool logits_all) override;
64
 
65
  llama_memory_state_ptr init_full() override;
66
 
@@ -158,6 +157,8 @@ private:
158
  // SWA
159
  const uint32_t n_swa = 0;
160
 
 
 
161
  const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
162
 
163
  std::vector<ggml_context_ptr> ctxs;
 
59
  llama_memory_state_ptr init_batch(
60
  const llama_batch & batch,
61
  uint32_t n_ubatch,
62
+ bool embd_all) override;
 
63
 
64
  llama_memory_state_ptr init_full() override;
65
 
 
157
  // SWA
158
  const uint32_t n_swa = 0;
159
 
160
+ int debug = 0;
161
+
162
  const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
163
 
164
  std::vector<ggml_context_ptr> ctxs;
examples/talk-llama/llama-kv-cells.h CHANGED
@@ -23,7 +23,7 @@ public:
23
 
24
  used.clear();
25
 
26
- for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
27
  seq_pos[s].clear();
28
  }
29
  }
@@ -240,7 +240,7 @@ public:
240
  llama_seq_id seq_get(uint32_t i) const {
241
  assert(seq[i].count() == 1);
242
 
243
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
244
  if (seq[i].test(s)) {
245
  return s;
246
  }
@@ -253,7 +253,7 @@ public:
253
  // return -1 if the sequence is not present
254
  llama_pos seq_pos_min(llama_seq_id seq_id) const {
255
  assert(seq_id >= 0);
256
- assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
257
 
258
  if (seq_pos[seq_id].empty()) {
259
  return -1;
@@ -266,7 +266,7 @@ public:
266
  // return -1 if the sequence is not present
267
  llama_pos seq_pos_max(llama_seq_id seq_id) const {
268
  assert(seq_id >= 0);
269
- assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
270
 
271
  if (seq_pos[seq_id].empty()) {
272
  return -1;
@@ -384,20 +384,20 @@ private:
384
  //
385
  std::vector<llama_pos> shift;
386
 
387
- using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
388
 
389
  // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
390
  std::vector<bits_t> seq;
391
 
392
  // the set seq_pos[s] tells us which positions are currently present for sequence s
393
  // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
394
- std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
395
 
396
  // helper functions for updating `seq_pos`, once cell at a time:
397
 
398
  // remove cell i
399
  void seq_pos_rm(uint32_t i) {
400
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
401
  if (seq[i].test(s)) {
402
  seq_pos[s].erase(pos[i]);
403
  }
@@ -406,7 +406,7 @@ private:
406
 
407
  // add cell i
408
  void seq_pos_add(uint32_t i) {
409
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
410
  if (seq[i].test(s)) {
411
  seq_pos[s].insert(pos[i]);
412
  }
 
23
 
24
  used.clear();
25
 
26
+ for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
27
  seq_pos[s].clear();
28
  }
29
  }
 
240
  llama_seq_id seq_get(uint32_t i) const {
241
  assert(seq[i].count() == 1);
242
 
243
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
244
  if (seq[i].test(s)) {
245
  return s;
246
  }
 
253
  // return -1 if the sequence is not present
254
  llama_pos seq_pos_min(llama_seq_id seq_id) const {
255
  assert(seq_id >= 0);
256
+ assert(seq_id < LLAMA_MAX_SEQ);
257
 
258
  if (seq_pos[seq_id].empty()) {
259
  return -1;
 
266
  // return -1 if the sequence is not present
267
  llama_pos seq_pos_max(llama_seq_id seq_id) const {
268
  assert(seq_id >= 0);
269
+ assert(seq_id < LLAMA_MAX_SEQ);
270
 
271
  if (seq_pos[seq_id].empty()) {
272
  return -1;
 
384
  //
385
  std::vector<llama_pos> shift;
386
 
387
+ using bits_t = std::bitset<LLAMA_MAX_SEQ>;
388
 
389
  // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
390
  std::vector<bits_t> seq;
391
 
392
  // the set seq_pos[s] tells us which positions are currently present for sequence s
393
  // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
394
+ std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
395
 
396
  // helper functions for updating `seq_pos`, once cell at a time:
397
 
398
  // remove cell i
399
  void seq_pos_rm(uint32_t i) {
400
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
401
  if (seq[i].test(s)) {
402
  seq_pos[s].erase(pos[i]);
403
  }
 
406
 
407
  // add cell i
408
  void seq_pos_add(uint32_t i) {
409
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
410
  if (seq[i].test(s)) {
411
  seq_pos[s].insert(pos[i]);
412
  }
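llama-kv-cells.h above keys both the per-cell occupancy bitset and the per-sequence position sets on LLAMA_MAX_SEQ instead of LLAMA_MAX_PARALLEL_SEQUENCES. A standalone sketch of that layout and of the seq_pos bookkeeping (small fixed capacity here; the real code uses LLAMA_MAX_SEQ):

#include <bitset>
#include <set>
#include <vector>

constexpr int MAX_SEQ = 64; // stand-in for LLAMA_MAX_SEQ

struct cells_sketch {
    std::vector<std::bitset<MAX_SEQ>> seq; // seq[i]: which sequences occupy cell i
    std::set<int>  seq_pos[MAX_SEQ];       // seq_pos[s]: positions present for sequence s
    std::vector<int> pos;                  // pos[i]: position stored in cell i

    void seq_pos_add(size_t i) {
        for (int s = 0; s < MAX_SEQ; ++s) {
            if (seq[i].test(s)) {
                seq_pos[s].insert(pos[i]);
            }
        }
    }
    int seq_pos_min(int s) const {
        return seq_pos[s].empty() ? -1 : *seq_pos[s].begin();
    }
    int seq_pos_max(int s) const {
        return seq_pos[s].empty() ? -1 : *seq_pos[s].rbegin();
    }
};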
examples/talk-llama/llama-memory.h CHANGED
@@ -73,8 +73,7 @@ struct llama_memory_i {
73
  virtual llama_memory_state_ptr init_batch(
74
  const llama_batch & batch,
75
  uint32_t n_ubatch,
76
- bool embd_pooled,
77
- bool logits_all) = 0;
78
 
79
  // simulate full cache, used for allocating worst-case compute buffers
80
  virtual llama_memory_state_ptr init_full() = 0;
 
73
  virtual llama_memory_state_ptr init_batch(
74
  const llama_batch & batch,
75
  uint32_t n_ubatch,
76
+ bool embd_all) = 0;
 
77
 
78
  // simulate full cache, used for allocating worst-case compute buffers
79
  virtual llama_memory_state_ptr init_full() = 0;
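The interface above folds the old embd_pooled/logits_all pair into a single embd_all flag ("all outputs are needed"). A hedged caller-side sketch (memory, batch, n_ubatch and output_all are assumed from the surrounding llama.cpp context; get_status() is taken from the existing llama_memory_state_i interface):

// sketch only -- not code from the commit
llama_memory_state_ptr mstate = memory->init_batch(batch, n_ubatch, /*embd_all=*/output_all);
if (!mstate || mstate->get_status() == LLAMA_MEMORY_STATUS_FAILED_PREPARE) {
    // the batch could not be placed in the cache; retry with a smaller
    // n_ubatch or a different splitting strategy
}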
examples/talk-llama/llama-model.cpp CHANGED
@@ -80,6 +80,7 @@ const char * llm_type_name(llm_type type) {
80
  case LLM_TYPE_40B: return "40B";
81
  case LLM_TYPE_65B: return "65B";
82
  case LLM_TYPE_70B: return "70B";
 
83
  case LLM_TYPE_236B: return "236B";
84
  case LLM_TYPE_290B: return "290B";
85
  case LLM_TYPE_314B: return "314B";
@@ -598,6 +599,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
598
  hparams.use_kq_norm = false;
599
  }
600
  } break;
601
  case LLM_ARCH_DECI:
602
  {
603
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -738,6 +749,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
738
  }
739
  }
740
  } break;
741
  case LLM_ARCH_BLOOM:
742
  {
743
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -1444,6 +1465,20 @@ void llama_model::load_hparams(llama_model_loader & ml) {
1444
  default: type = LLM_TYPE_UNKNOWN;
1445
  }
1446
  } break;
1447
  default: throw std::runtime_error("unsupported model architecture");
1448
  }
1449
 
@@ -2187,6 +2222,32 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
2187
  layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0);
2188
  }
2189
  } break;
2190
  case LLM_ARCH_JINA_BERT_V2:
2191
  {
2192
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings
@@ -2224,8 +2285,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
2224
  layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
2225
  layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
2226
 
2227
- layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
2228
- layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
2229
 
2230
  layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
2231
  layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
@@ -4123,6 +4184,89 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
4123
  layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
4124
  }
4125
  } break;
4126
  default:
4127
  throw std::runtime_error("unknown architecture");
4128
  }
@@ -6043,7 +6187,7 @@ struct llm_build_bert : public llm_graph_context {
6043
  model.layers[il].ffn_gate, NULL, NULL,
6044
  model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
6045
  NULL,
6046
- LLM_FFN_GELU, LLM_FFN_PAR, il);
6047
  cb(cur, "ffn_out", il);
6048
  } else {
6049
  cur = build_ffn(cur,
@@ -6074,6 +6218,117 @@ struct llm_build_bert : public llm_graph_context {
6074
  }
6075
  };
6076
 
6077
  struct llm_build_bloom : public llm_graph_context {
6078
  llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
6079
  const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -8857,7 +9112,6 @@ struct llm_build_mamba : public llm_graph_context {
8857
  inpL = build_inp_embd(model.tok_embd);
8858
 
8859
  ggml_tensor * state_copy = build_inp_s_copy();
8860
- ggml_tensor * state_mask = build_inp_s_mask();
8861
 
8862
  for (int il = 0; il < n_layer; ++il) {
8863
  // norm
@@ -8866,8 +9120,7 @@ struct llm_build_mamba : public llm_graph_context {
8866
  LLM_NORM_RMS, il);
8867
  cb(cur, "attn_norm", il);
8868
 
8869
- //cur = build_mamba_layer(gf, cur, state_copy, state_mask, il);
8870
- cur = build_mamba_layer(gf, cur, state_copy, state_mask, ubatch, il);
8871
 
8872
  if (il == n_layer - 1) {
8873
  // skip computing output for unused tokens
@@ -8908,7 +9161,6 @@ struct llm_build_mamba : public llm_graph_context {
8908
  ggml_cgraph * gf,
8909
  ggml_tensor * cur,
8910
  ggml_tensor * state_copy,
8911
- ggml_tensor * state_mask,
8912
  const llama_ubatch & ubatch,
8913
  int il) const {
8914
  const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -8935,12 +9187,12 @@ struct llm_build_mamba : public llm_graph_context {
8935
  ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
8936
 
8937
  // (ab)using the KV cache to store the states
8938
- ggml_tensor * conv = build_copy_mask_state(
8939
- gf, conv_states_all, state_copy, state_mask,
8940
  hparams.n_embd_k_s(), n_seqs);
8941
  conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
8942
- ggml_tensor * ssm = build_copy_mask_state(
8943
- gf, ssm_states_all, state_copy, state_mask,
8944
  hparams.n_embd_v_s(), n_seqs);
8945
  ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
8946
 
@@ -11656,7 +11908,6 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11656
  ggml_tensor * cur,
11657
  ggml_tensor * x_prev,
11658
  ggml_tensor * state_copy,
11659
- ggml_tensor * state_mask,
11660
  const llama_ubatch & ubatch,
11661
  int il) const {
11662
  const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -11780,8 +12031,8 @@ struct llm_build_rwkv6_base : public llm_graph_context {
11780
  k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
11781
  }
11782
 
11783
- ggml_tensor * wkv_state = build_copy_mask_state(
11784
- gf, kv_state->get_v_l(il), state_copy, state_mask,
11785
  hparams.n_embd_v_s(), n_seqs);
11786
 
11787
  ggml_tensor * wkv_output;
@@ -11837,7 +12088,6 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
11837
  inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
11838
 
11839
  ggml_tensor * state_copy = build_inp_s_copy();
11840
- ggml_tensor * state_mask = build_inp_s_mask();
11841
 
11842
  const auto n_embd = hparams.n_embd;
11843
  const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -11848,7 +12098,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
11848
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
11849
 
11850
  ggml_tensor * token_shift = build_rwkv_token_shift_load(
11851
- gf, state_copy, state_mask, ubatch, il
11852
  );
11853
 
11854
  ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
@@ -11864,7 +12114,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
11864
  1
11865
  );
11866
 
11867
- cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
11868
 
11869
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
11870
  cb(ffn_inp, "ffn_inp", il);
@@ -11935,7 +12185,6 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
11935
  inpL = build_inp_embd(model.tok_embd);
11936
 
11937
  ggml_tensor * state_copy = build_inp_s_copy();
11938
- ggml_tensor * state_mask = build_inp_s_mask();
11939
 
11940
  const auto n_embd = hparams.n_embd;
11941
  const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -11946,7 +12195,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
11946
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
11947
 
11948
  ggml_tensor * token_shift = build_rwkv_token_shift_load(
11949
- gf, state_copy, state_mask, ubatch, il
11950
  );
11951
 
11952
  ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
@@ -11959,7 +12208,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
11959
  1
11960
  );
11961
 
11962
- cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
11963
 
11964
  token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
11965
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -12051,7 +12300,6 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12051
  ggml_tensor * cur,
12052
  ggml_tensor * x_prev,
12053
  ggml_tensor * state_copy,
12054
- ggml_tensor * state_mask,
12055
  ggml_tensor *& first_layer_value,
12056
  const llama_ubatch & ubatch,
12057
  int il) const {
@@ -12134,8 +12382,8 @@ struct llm_build_rwkv7_base : public llm_graph_context {
12134
  v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
12135
  a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
12136
 
12137
- ggml_tensor * wkv_state = build_copy_mask_state(
12138
- gf, kv_state->get_v_l(il), state_copy, state_mask,
12139
  hparams.n_embd_v_s(), n_seqs);
12140
 
12141
  ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@@ -12193,7 +12441,6 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
12193
  inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
12194
 
12195
  ggml_tensor * state_copy = build_inp_s_copy();
12196
- ggml_tensor * state_mask = build_inp_s_mask();
12197
 
12198
  const auto n_embd = hparams.n_embd;
12199
  const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12204,7 +12451,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
12204
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12205
 
12206
  ggml_tensor * token_shift = build_rwkv_token_shift_load(
12207
- gf, state_copy, state_mask, ubatch, il
12208
  );
12209
 
12210
  ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
@@ -12220,7 +12467,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
12220
  1
12221
  );
12222
 
12223
- cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
12224
 
12225
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
12226
  cb(ffn_inp, "ffn_inp", il);
@@ -12287,7 +12534,6 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
12287
  inpL = build_inp_embd(model.tok_embd);
12288
 
12289
  ggml_tensor * state_copy = build_inp_s_copy();
12290
- ggml_tensor * state_mask = build_inp_s_mask();
12291
 
12292
  const auto n_embd = hparams.n_embd;
12293
  const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12298,7 +12544,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
12298
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12299
 
12300
  ggml_tensor * token_shift = build_rwkv_token_shift_load(
12301
- gf, state_copy, state_mask, ubatch, il
12302
  );
12303
 
12304
  ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
@@ -12311,7 +12557,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
12311
  1
12312
  );
12313
 
12314
- cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
12315
 
12316
  token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
12317
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -13203,6 +13449,291 @@ struct llm_build_bailingmoe : public llm_graph_context {
13203
  }
13204
  };
13205
 
13206
  llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
13207
  llama_memory_i * res;
13208
 
@@ -13211,6 +13742,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
13211
  case LLM_ARCH_JINA_BERT_V2:
13212
  case LLM_ARCH_NOMIC_BERT:
13213
  case LLM_ARCH_NOMIC_BERT_MOE:
 
13214
  case LLM_ARCH_WAVTOKENIZER_DEC:
13215
  {
13216
  res = nullptr;
@@ -13319,6 +13851,10 @@ llm_graph_result_ptr llama_model::build_graph(
13319
  {
13320
  llm = std::make_unique<llm_build_bert>(*this, params, gf);
13321
  } break;
13322
  case LLM_ARCH_BLOOM:
13323
  {
13324
  llm = std::make_unique<llm_build_bloom>(*this, params, gf);
@@ -13541,6 +14077,14 @@ llm_graph_result_ptr llama_model::build_graph(
13541
  {
13542
  llm = std::make_unique<llm_build_bailingmoe>(*this, params, gf);
13543
  } break;
13544
  default:
13545
  GGML_ABORT("fatal error");
13546
  }
@@ -13690,6 +14234,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
13690
  case LLM_ARCH_GRANITE_MOE:
13691
  case LLM_ARCH_CHAMELEON:
13692
  case LLM_ARCH_BAILINGMOE:
 
 
13693
  return LLAMA_ROPE_TYPE_NORM;
13694
 
13695
  // the pairs of head values are offset by n_rot/2
@@ -13723,6 +14269,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
13723
  case LLM_ARCH_NEMOTRON:
13724
  case LLM_ARCH_EXAONE:
13725
  case LLM_ARCH_MINICPM3:
 
13726
  return LLAMA_ROPE_TYPE_NEOX;
13727
 
13728
  case LLM_ARCH_QWEN2VL:
 
80
  case LLM_TYPE_40B: return "40B";
81
  case LLM_TYPE_65B: return "65B";
82
  case LLM_TYPE_70B: return "70B";
83
+ case LLM_TYPE_142B: return "142B";
84
  case LLM_TYPE_236B: return "236B";
85
  case LLM_TYPE_290B: return "290B";
86
  case LLM_TYPE_314B: return "314B";
 
599
  hparams.use_kq_norm = false;
600
  }
601
  } break;
602
+ case LLM_ARCH_ARCEE:
603
+ {
604
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
605
+
606
+ // Arcee uses the same structure as Llama
607
+ switch (hparams.n_layer) {
608
+ case 36: type = LLM_TYPE_4B; break;
609
+ default: type = LLM_TYPE_UNKNOWN;
610
+ }
611
+ } break;
612
  case LLM_ARCH_DECI:
613
  {
614
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
749
  }
750
  }
751
  } break;
752
+ case LLM_ARCH_NEO_BERT:
753
+ {
754
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
755
+ ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
756
+ ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
757
+
758
+ if (hparams.n_layer == 28) {
759
+ type = LLM_TYPE_250M;
760
+ }
761
+ } break;
762
  case LLM_ARCH_BLOOM:
763
  {
764
  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
 
1465
  default: type = LLM_TYPE_UNKNOWN;
1466
  }
1467
  } break;
1468
+ case LLM_ARCH_DOTS1:
1469
+ {
1470
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1471
+ ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
1472
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
1473
+ ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
1474
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
1475
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
1476
+ ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
1477
+ switch (hparams.n_layer) {
1478
+ case 62: type = LLM_TYPE_142B; break;
1479
+ default: type = LLM_TYPE_UNKNOWN;
1480
+ }
1481
+ } break;
1482
  default: throw std::runtime_error("unsupported model architecture");
1483
  }
1484
 
 
2222
  layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0);
2223
  }
2224
  } break;
2225
+ case LLM_ARCH_NEO_BERT:
2226
+ {
2227
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
2228
+
2229
+ cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
2230
+ cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED);
2231
+
2232
+ cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
2233
+ cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
2234
+
2235
+ output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
2236
+
2237
+ for (int i = 0; i < n_layer; ++i) {
2238
+ auto & layer = layers[i];
2239
+
2240
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
2241
+
2242
+ layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
2243
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
2244
+
2245
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
2246
+
2247
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff*2}, 0);
2248
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
2249
+ }
2250
+ } break;
2251
  case LLM_ARCH_JINA_BERT_V2:
2252
  {
2253
  tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings
 
2285
  layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
2286
  layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
2287
 
2288
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
2289
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, layer.ffn_gate ? n_ff : n_ff * 2}, 0);
2290
 
2291
  layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
2292
  layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
 
4184
  layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
4185
  }
4186
  } break;
4187
+ case LLM_ARCH_DOTS1:
4188
+ {
4189
+ const int64_t n_ff_exp = hparams.n_ff_exp;
4190
+ const int64_t n_expert_shared = hparams.n_expert_shared;
4191
+
4192
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
4193
+
4194
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
4195
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
4196
+
4197
+ for (int i = 0; i < n_layer; ++i) {
4198
+ auto & layer = layers[i];
4199
+
4200
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
4201
+
4202
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
4203
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
4204
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
4205
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
4206
+
4207
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
4208
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
4209
+
4210
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
4211
+
4212
+ if (i < (int) hparams.n_layer_dense_lead) {
4213
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
4214
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
4215
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4216
+ } else {
4217
+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
4218
+ layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
4219
+
4220
+ if (n_expert == 0) {
4221
+ throw std::runtime_error("n_expert must be > 0");
4222
+ }
4223
+ if (n_expert_used == 0) {
4224
+ throw std::runtime_error("n_expert_used must be > 0");
4225
+ }
4226
+
4227
+ // MoE branch
4228
+ layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
4229
+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
4230
+ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
4231
+
4232
+ // Shared expert branch
4233
+ layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
4234
+ layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0);
4235
+ layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
4236
+ }
4237
+ }
4238
+ } break;
4239
+ case LLM_ARCH_ARCEE:
4240
+ {
4241
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
4242
+
4243
+ // output
4244
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
4245
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
4246
+
4247
+ // if output is NULL, init from the input tok embed
4248
+ if (output == NULL) {
4249
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
4250
+ }
4251
+
4252
+ for (int i = 0; i < n_layer; ++i) {
4253
+ auto & layer = layers[i];
4254
+
4255
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
4256
+
4257
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
4258
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
4259
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
4260
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
4261
+
4262
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
4263
+
4264
+ layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
4265
+
4266
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
4267
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4268
+ }
4269
+ } break;
4270
  default:
4271
  throw std::runtime_error("unknown architecture");
4272
  }
 
6187
  model.layers[il].ffn_gate, NULL, NULL,
6188
  model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
6189
  NULL,
6190
+ model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_GEGLU, LLM_FFN_PAR, il);
6191
  cb(cur, "ffn_out", il);
6192
  } else {
6193
  cur = build_ffn(cur,
 
6218
  }
6219
  };
6220
 
6221
+ struct llm_build_neo_bert : public llm_graph_context {
6222
+ llm_build_neo_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
6223
+ const int64_t n_embd_head = hparams.n_embd_head_v;
6224
+ const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
6225
+
6226
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
6227
+
6228
+ ggml_tensor * cur;
6229
+ ggml_tensor * inpL;
6230
+ ggml_tensor * inp_pos = build_inp_pos();
6231
+
6232
+ // construct input embeddings (token, type, position)
6233
+ inpL = build_inp_embd(model.tok_embd);
6234
+ cb(inpL, "inp_embd", -1);
6235
+
6236
+ auto * inp_attn = build_attn_inp_no_cache();
6237
+
6238
+ // iterate layers
6239
+ for (int il = 0; il < n_layer; ++il) {
6240
+ ggml_tensor * cur = inpL;
6241
+
6242
+ ggml_tensor * Qcur;
6243
+ ggml_tensor * Kcur;
6244
+ ggml_tensor * Vcur;
6245
+
6246
+ // pre-norm
6247
+ cur = build_norm(inpL,
6248
+ model.layers[il].attn_norm, NULL,
6249
+ LLM_NORM_RMS, il);
6250
+
6251
+ // self-attention
6252
+ cur = build_lora_mm(model.layers[il].wqkv, cur);
6253
+ cb(cur, "wqkv", il);
6254
+
6255
+ Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
6256
+ Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
6257
+ Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
6258
+
6259
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
6260
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
6261
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
6262
+
6263
+ // RoPE
6264
+ Qcur = ggml_rope_ext(
6265
+ ctx0, Qcur, inp_pos, nullptr,
6266
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6267
+ ext_factor, attn_factor, beta_fast, beta_slow
6268
+ );
6269
+
6270
+ Kcur = ggml_rope_ext(
6271
+ ctx0, Kcur, inp_pos, nullptr,
6272
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
6273
+ ext_factor, attn_factor, beta_fast, beta_slow
6274
+ );
6275
+
6276
+ cb(Qcur, "Qcur", il);
6277
+ cb(Kcur, "Kcur", il);
6278
+ cb(Vcur, "Vcur", il);
6279
+
6280
+ cur = build_attn(inp_attn, gf,
6281
+ model.layers[il].wo, nullptr,
6282
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
6283
+ cb(cur, "kqv_out", il);
6284
+
6285
+ if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
6286
+ // skip computing output for unused tokens
6287
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
6288
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
6289
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
6290
+ }
6291
+
6292
+ // re-add the layer input
6293
+ cur = ggml_add(ctx0, cur, inpL);
6294
+
6295
+ ggml_tensor * ffn_inp = cur;
6296
+ cb(ffn_inp, "ffn_inp", il);
6297
+
6298
+ // pre-norm
6299
+ cur = build_norm(ffn_inp,
6300
+ model.layers[il].ffn_norm, NULL,
6301
+ LLM_NORM_RMS, il);
6302
+ cb(cur, "ffn_norm", il);
6303
+
6304
+ // feed-forward network
6305
+ cur = build_ffn(cur,
6306
+ model.layers[il].ffn_up,
6307
+ NULL, NULL, NULL, NULL, NULL,
6308
+ model.layers[il].ffn_down,
6309
+ NULL, NULL, NULL,
6310
+ LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
6311
+
6312
+ // attentions bypass the intermediate layer
6313
+ cur = ggml_add(ctx0, cur, ffn_inp);
6314
+
6315
+ // input for next layer
6316
+ inpL = cur;
6317
+ }
6318
+
6319
+ cur = inpL;
6320
+
6321
+ cur = build_norm(cur,
6322
+ model.output_norm_enc, NULL,
6323
+ LLM_NORM_RMS, -1);
6324
+
6325
+ cb(cur, "result_embd", -1);
6326
+ res->t_embd = cur;
6327
+
6328
+ ggml_build_forward_expand(gf, cur);
6329
+ }
6330
+ };
6331
+
6332
  struct llm_build_bloom : public llm_graph_context {
6333
  llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
6334
  const int64_t n_embd_head = hparams.n_embd_head_v;
 
9112
  inpL = build_inp_embd(model.tok_embd);
9113
 
9114
  ggml_tensor * state_copy = build_inp_s_copy();
 
9115
 
9116
  for (int il = 0; il < n_layer; ++il) {
9117
  // norm
 
9120
  LLM_NORM_RMS, il);
9121
  cb(cur, "attn_norm", il);
9122
 
9123
+ cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
 
9124
 
9125
  if (il == n_layer - 1) {
9126
  // skip computing output for unused tokens
 
9161
  ggml_cgraph * gf,
9162
  ggml_tensor * cur,
9163
  ggml_tensor * state_copy,
 
9164
  const llama_ubatch & ubatch,
9165
  int il) const {
9166
  const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
9187
  ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
9188
 
9189
  // (ab)using the KV cache to store the states
9190
+ ggml_tensor * conv = build_recurrent_state(
9191
+ gf, conv_states_all, state_copy,
9192
  hparams.n_embd_k_s(), n_seqs);
9193
  conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
9194
+ ggml_tensor * ssm = build_recurrent_state(
9195
+ gf, ssm_states_all, state_copy,
9196
  hparams.n_embd_v_s(), n_seqs);
9197
  ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
9198
 
 
11908
  ggml_tensor * cur,
11909
  ggml_tensor * x_prev,
11910
  ggml_tensor * state_copy,
 
11911
  const llama_ubatch & ubatch,
11912
  int il) const {
11913
  const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
12031
  k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
12032
  }
12033
 
12034
+ ggml_tensor * wkv_state = build_recurrent_state(
12035
+ gf, kv_state->get_v_l(il), state_copy,
12036
  hparams.n_embd_v_s(), n_seqs);
12037
 
12038
  ggml_tensor * wkv_output;
 
12088
  inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
12089
 
12090
  ggml_tensor * state_copy = build_inp_s_copy();
 
12091
 
12092
  const auto n_embd = hparams.n_embd;
12093
  const auto n_seq_tokens = ubatch.n_seq_tokens;
 
12098
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12099
 
12100
  ggml_tensor * token_shift = build_rwkv_token_shift_load(
12101
+ gf, state_copy, ubatch, il
12102
  );
12103
 
12104
  ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
 
12114
  1
12115
  );
12116
 
12117
+ cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
12118
 
12119
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
12120
  cb(ffn_inp, "ffn_inp", il);
 
12185
  inpL = build_inp_embd(model.tok_embd);
12186
 
12187
  ggml_tensor * state_copy = build_inp_s_copy();
 
12188
 
12189
  const auto n_embd = hparams.n_embd;
12190
  const auto n_seq_tokens = ubatch.n_seq_tokens;
 
12195
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12196
 
12197
  ggml_tensor * token_shift = build_rwkv_token_shift_load(
12198
+ gf, state_copy, ubatch, il
12199
  );
12200
 
12201
  ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
 
12208
  1
12209
  );
12210
 
12211
+ cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
12212
 
12213
  token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
12214
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
 
12300
  ggml_tensor * cur,
12301
  ggml_tensor * x_prev,
12302
  ggml_tensor * state_copy,
 
12303
  ggml_tensor *& first_layer_value,
12304
  const llama_ubatch & ubatch,
12305
  int il) const {
 
12382
  v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
12383
  a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
12384
 
12385
+ ggml_tensor * wkv_state = build_recurrent_state(
12386
+ gf, kv_state->get_v_l(il), state_copy,
12387
  hparams.n_embd_v_s(), n_seqs);
12388
 
12389
  ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
 
12441
  inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
12442
 
12443
  ggml_tensor * state_copy = build_inp_s_copy();
 
12444
 
12445
  const auto n_embd = hparams.n_embd;
12446
  const auto n_seq_tokens = ubatch.n_seq_tokens;
 
12451
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12452
 
12453
  ggml_tensor * token_shift = build_rwkv_token_shift_load(
12454
+ gf, state_copy, ubatch, il
12455
  );
12456
 
12457
  ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
 
12467
  1
12468
  );
12469
 
12470
+ cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
12471
 
12472
  ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
12473
  cb(ffn_inp, "ffn_inp", il);
 
12534
  inpL = build_inp_embd(model.tok_embd);
12535
 
12536
  ggml_tensor * state_copy = build_inp_s_copy();
 
12537
 
12538
  const auto n_embd = hparams.n_embd;
12539
  const auto n_seq_tokens = ubatch.n_seq_tokens;
 
12544
  inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
12545
 
12546
  ggml_tensor * token_shift = build_rwkv_token_shift_load(
12547
+ gf, state_copy, ubatch, il
12548
  );
12549
 
12550
  ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
 
12557
  1
12558
  );
12559
 
12560
+ cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
12561
 
12562
  token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
12563
  ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
 
13449
  }
13450
  };
13451
 
13452
+ struct llm_build_dots1 : public llm_graph_context {
13453
+ llm_build_dots1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
13454
+ const int64_t n_embd_head = hparams.n_embd_head_v;
13455
+
13456
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
13457
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
13458
+
13459
+ ggml_tensor * cur;
13460
+ ggml_tensor * inpL;
13461
+
13462
+ inpL = build_inp_embd(model.tok_embd);
13463
+
13464
+ // inp_pos - contains the positions
13465
+ ggml_tensor * inp_pos = build_inp_pos();
13466
+
13467
+ auto * inp_attn = build_attn_inp_kv_unified();
13468
+
13469
+ for (int il = 0; il < n_layer; ++il) {
13470
+ ggml_tensor * inpSA = inpL;
13471
+
13472
+ // norm
13473
+ cur = build_norm(inpL,
13474
+ model.layers[il].attn_norm, NULL,
13475
+ LLM_NORM_RMS, il);
13476
+ cb(cur, "attn_norm", il);
13477
+
13478
+ // self_attention
13479
+ {
13480
+ // compute Q and K and RoPE them
13481
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
13482
+ cb(Qcur, "Qcur", il);
13483
+
13484
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
13485
+ cb(Kcur, "Kcur", il);
13486
+
13487
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
13488
+ cb(Vcur, "Vcur", il);
13489
+
13490
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
13491
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
13492
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
13493
+
13494
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
13495
+ cb(Qcur, "Qcur_normed", il);
13496
+
13497
+ Qcur = ggml_rope_ext(
13498
+ ctx0, Qcur, inp_pos, nullptr,
13499
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
13500
+ ext_factor, attn_factor, beta_fast, beta_slow
13501
+ );
13502
+
13503
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
13504
+ cb(Kcur, "Kcur_normed", il);
13505
+
13506
+ Kcur = ggml_rope_ext(
13507
+ ctx0, Kcur, inp_pos, nullptr,
13508
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
13509
+ ext_factor, attn_factor, beta_fast, beta_slow
13510
+ );
13511
+
13512
+ cb(Qcur, "Qcur", il);
13513
+ cb(Kcur, "Kcur", il);
13514
+ cb(Vcur, "Vcur", il);
13515
+
13516
+ cur = build_attn(inp_attn, gf,
13517
+ model.layers[il].wo, model.layers[il].bo,
13518
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
13519
+ }
13520
+
13521
+ if (il == n_layer - 1) {
13522
+ // skip computing output for unused tokens
13523
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
13524
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
13525
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
13526
+ }
13527
+
13528
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
13529
+ cb(ffn_inp, "ffn_inp", il);
13530
+
13531
+ // MoE branch
13532
+ cur = build_norm(ffn_inp,
13533
+ model.layers[il].ffn_norm, NULL,
13534
+ LLM_NORM_RMS, il);
13535
+ cb(cur, "ffn_norm", il);
13536
+
13537
+ if ((uint32_t) il < hparams.n_layer_dense_lead) {
13538
+ cur = build_ffn(cur,
13539
+ model.layers[il].ffn_up, NULL, NULL,
13540
+ model.layers[il].ffn_gate, NULL, NULL,
13541
+ model.layers[il].ffn_down, NULL, NULL,
13542
+ NULL,
13543
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
13544
+ cb(cur, "ffn_out", il);
13545
+ } else {
13546
+ ggml_tensor * moe_out =
13547
+ build_moe_ffn(cur,
13548
+ model.layers[il].ffn_gate_inp,
13549
+ model.layers[il].ffn_up_exps,
13550
+ model.layers[il].ffn_gate_exps,
13551
+ model.layers[il].ffn_down_exps,
13552
+ model.layers[il].ffn_exp_probs_b,
13553
+ n_expert, n_expert_used,
13554
+ LLM_FFN_SILU, hparams.expert_weights_norm,
13555
+ true, hparams.expert_weights_scale,
13556
+ (llama_expert_gating_func_type) hparams.expert_gating_func,
13557
+ il);
13558
+ cb(moe_out, "ffn_moe_out", il);
13559
+
13560
+ {
13561
+ ggml_tensor * ffn_shexp = build_ffn(cur,
13562
+ model.layers[il].ffn_up_shexp, NULL, NULL,
13563
+ model.layers[il].ffn_gate_shexp, NULL, NULL,
13564
+ model.layers[il].ffn_down_shexp, NULL, NULL,
13565
+ NULL,
13566
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
13567
+ cb(ffn_shexp, "ffn_shexp", il);
13568
+
13569
+ cur = ggml_add(ctx0, moe_out, ffn_shexp);
13570
+ cb(cur, "ffn_out", il);
13571
+ }
13572
+ }
13573
+
13574
+ cur = ggml_add(ctx0, cur, ffn_inp);
13575
+
13576
+ cur = build_cvec(cur, il);
13577
+ cb(cur, "l_out", il);
13578
+
13579
+ // input for next layer
13580
+ inpL = cur;
13581
+ }
13582
+
13583
+ cur = inpL;
13584
+
13585
+ cur = build_norm(cur,
13586
+ model.output_norm, NULL,
13587
+ LLM_NORM_RMS, -1);
13588
+
13589
+ cb(cur, "result_norm", -1);
13590
+ res->t_embd = cur;
13591
+
13592
+ // lm_head
13593
+ cur = build_lora_mm(model.output, cur);
13594
+
13595
+ cb(cur, "result_output", -1);
13596
+ res->t_logits = cur;
13597
+
13598
+ ggml_build_forward_expand(gf, cur);
13599
+ }
13600
+ };
13601
+
13602
+ struct llm_build_arcee : public llm_graph_context {
13603
+ llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
13604
+ const int64_t n_embd_head = hparams.n_embd_head_v;
13605
+
13606
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
13607
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
13608
+
13609
+ ggml_tensor * cur;
13610
+ ggml_tensor * inpL;
13611
+
13612
+ inpL = build_inp_embd(model.tok_embd);
13613
+
13614
+ // inp_pos - contains the positions
13615
+ ggml_tensor * inp_pos = build_inp_pos();
13616
+
13617
+ auto * inp_attn = build_attn_inp_kv_unified();
13618
+
13619
+ const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
13620
+
13621
+ for (int il = 0; il < n_layer; ++il) {
13622
+ ggml_tensor * inpSA = inpL;
13623
+
13624
+ // norm
13625
+ cur = build_norm(inpL,
13626
+ model.layers[il].attn_norm, NULL,
13627
+ LLM_NORM_RMS, il);
13628
+ cb(cur, "attn_norm", il);
13629
+
13630
+ // self-attention
13631
+ {
13632
+ // rope freq factors for llama3; may return nullptr for llama2 and other models
13633
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
13634
+
13635
+ // compute Q and K and RoPE them
13636
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
13637
+ cb(Qcur, "Qcur", il);
13638
+ if (model.layers[il].bq) {
13639
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
13640
+ cb(Qcur, "Qcur", il);
13641
+ }
13642
+
13643
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
13644
+ cb(Kcur, "Kcur", il);
13645
+ if (model.layers[il].bk) {
13646
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
13647
+ cb(Kcur, "Kcur", il);
13648
+ }
13649
+
13650
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
13651
+ cb(Vcur, "Vcur", il);
13652
+ if (model.layers[il].bv) {
13653
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
13654
+ cb(Vcur, "Vcur", il);
13655
+ }
13656
+
13657
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
13658
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
13659
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
13660
+
13661
+ Qcur = ggml_rope_ext(
13662
+ ctx0, Qcur, inp_pos, rope_factors,
13663
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
13664
+ ext_factor, attn_factor, beta_fast, beta_slow
13665
+ );
13666
+
13667
+ Kcur = ggml_rope_ext(
13668
+ ctx0, Kcur, inp_pos, rope_factors,
13669
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
13670
+ ext_factor, attn_factor, beta_fast, beta_slow
13671
+ );
13672
+
13673
+ cb(Qcur, "Qcur", il);
13674
+ cb(Kcur, "Kcur", il);
13675
+ cb(Vcur, "Vcur", il);
13676
+
13677
+ cur = build_attn(inp_attn, gf,
13678
+ model.layers[il].wo, model.layers[il].bo,
13679
+ Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
13680
+ cb(cur, "attn_out", il);
13681
+ }
13682
+
13683
+ if (il == n_layer - 1) {
13684
+ // skip computing output for unused tokens
13685
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
13686
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
13687
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
13688
+ }
13689
+
13690
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
13691
+ cb(ffn_inp, "ffn_inp", il);
13692
+
13693
+ // feed-forward network
13694
+ // ARCEE uses relu^2 instead of silu
13695
+ cur = build_norm(ffn_inp,
13696
+ model.layers[il].ffn_norm, NULL,
13697
+ LLM_NORM_RMS, il);
13698
+ cb(cur, "ffn_norm", il);
13699
+
13700
+ cur = build_ffn(cur,
13701
+ model.layers[il].ffn_up, NULL, NULL,
13702
+ NULL, NULL, NULL,
13703
+ model.layers[il].ffn_down, NULL, NULL,
13704
+ NULL,
13705
+ LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
13706
+ cb(cur, "ffn_out", il);
13707
+
13708
+ cur = ggml_add(ctx0, cur, ffn_inp);
13709
+ cb(cur, "ffn_out", il);
13710
+
13711
+ cur = build_cvec(cur, il);
13712
+ cb(cur, "l_out", il);
13713
+
13714
+ // input for next layer
13715
+ inpL = cur;
13716
+ }
13717
+
13718
+ cur = inpL;
13719
+
13720
+ cur = build_norm(cur,
13721
+ model.output_norm, NULL,
13722
+ LLM_NORM_RMS, -1);
13723
+
13724
+ cb(cur, "result_norm", -1);
13725
+ res->t_embd = cur;
13726
+
13727
+ // lm_head
13728
+ cur = build_lora_mm(model.output, cur);
13729
+
13730
+ cb(cur, "result_output", -1);
13731
+ res->t_logits = cur;
13732
+
13733
+ ggml_build_forward_expand(gf, cur);
13734
+ }
13735
+ };
13736
+
13737
  llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
13738
  llama_memory_i * res;
13739
 
 
13742
  case LLM_ARCH_JINA_BERT_V2:
13743
  case LLM_ARCH_NOMIC_BERT:
13744
  case LLM_ARCH_NOMIC_BERT_MOE:
13745
+ case LLM_ARCH_NEO_BERT:
13746
  case LLM_ARCH_WAVTOKENIZER_DEC:
13747
  {
13748
  res = nullptr;
 
13851
  {
13852
  llm = std::make_unique<llm_build_bert>(*this, params, gf);
13853
  } break;
13854
+ case LLM_ARCH_NEO_BERT:
13855
+ {
13856
+ llm = std::make_unique<llm_build_neo_bert>(*this, params, gf);
13857
+ } break;
13858
  case LLM_ARCH_BLOOM:
13859
  {
13860
  llm = std::make_unique<llm_build_bloom>(*this, params, gf);
 
14077
  {
14078
  llm = std::make_unique<llm_build_bailingmoe>(*this, params, gf);
14079
  } break;
14080
+ case LLM_ARCH_DOTS1:
14081
+ {
14082
+ llm = std::make_unique<llm_build_dots1>(*this, params, gf);
14083
+ } break;
14084
+ case LLM_ARCH_ARCEE:
14085
+ {
14086
+ llm = std::make_unique<llm_build_arcee>(*this, params, gf);
14087
+ } break;
14088
  default:
14089
  GGML_ABORT("fatal error");
14090
  }
 
14234
  case LLM_ARCH_GRANITE_MOE:
14235
  case LLM_ARCH_CHAMELEON:
14236
  case LLM_ARCH_BAILINGMOE:
14237
+ case LLM_ARCH_NEO_BERT:
14238
+ case LLM_ARCH_ARCEE:
14239
  return LLAMA_ROPE_TYPE_NORM;
14240
 
14241
  // the pairs of head values are offset by n_rot/2
 
14269
  case LLM_ARCH_NEMOTRON:
14270
  case LLM_ARCH_EXAONE:
14271
  case LLM_ARCH_MINICPM3:
14272
+ case LLM_ARCH_DOTS1:
14273
  return LLAMA_ROPE_TYPE_NEOX;
14274
 
14275
  case LLM_ARCH_QWEN2VL:
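The last hunk above assigns LLM_ARCH_NEO_BERT and LLM_ARCH_ARCEE to LLAMA_ROPE_TYPE_NORM and LLM_ARCH_DOTS1 to LLAMA_ROPE_TYPE_NEOX. From the outside this is observable through llama_model_rope_type(); the following is a hedged sketch of a caller, assuming the public llama.h entry points as of this sync, with the model path supplied on the command line:

    #include "llama.h"
    #include <cstdio>

    int main(int argc, char ** argv) {
        if (argc < 2) { fprintf(stderr, "usage: %s model.gguf\n", argv[0]); return 1; }

        llama_model * model = llama_model_load_from_file(argv[1], llama_model_default_params());
        if (!model) { return 1; }

        switch (llama_model_rope_type(model)) {
            case LLAMA_ROPE_TYPE_NORM: printf("rope: norm (adjacent pairs rotated)\n");  break;
            case LLAMA_ROPE_TYPE_NEOX: printf("rope: neox (pairs offset by n_rot/2)\n"); break;
            default:                   printf("rope: other/none\n");                     break;
        }

        llama_model_free(model);
        return 0;
    }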
examples/talk-llama/llama-model.h CHANGED
@@ -73,6 +73,7 @@ enum llm_type {
73
  LLM_TYPE_40B,
74
  LLM_TYPE_65B,
75
  LLM_TYPE_70B,
 
76
  LLM_TYPE_236B,
77
  LLM_TYPE_290B,
78
  LLM_TYPE_314B,
 
73
  LLM_TYPE_40B,
74
  LLM_TYPE_65B,
75
  LLM_TYPE_70B,
76
+ LLM_TYPE_142B,
77
  LLM_TYPE_236B,
78
  LLM_TYPE_290B,
79
  LLM_TYPE_314B,
examples/talk-llama/llama-quant.cpp CHANGED
@@ -585,7 +585,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
585
  if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
586
  gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
587
  } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
588
- gguf_set_val_i32(ctx_out.get(), o.key, o.val_i64);
 
589
  } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
590
  gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
591
  } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
 
585
  if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
586
  gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
587
  } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
588
+ // Setting type to UINT32. See https://github.com/ggml-org/llama.cpp/pull/14182 for context
589
+ gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)abs(o.val_i64));
590
  } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
591
  gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
592
  } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
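The replacement above changes how integer KV overrides are written during quantization: they now land in the output GGUF as UINT32 values, with negative inputs passed through abs(), per the linked PR. The sketch below shows the caller side under the assumption that the llama.h override struct and quantize entry points are as in this tree; the key used here is only an example:

    #include "llama.h"
    #include <cstring>
    #include <vector>

    int main() {
        llama_model_kv_override o = {};
        o.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
        std::strncpy(o.key, "general.quantization_version", sizeof(o.key) - 1);
        o.val_i64 = 2; // per the hunk above, written to the output GGUF as a u32

        std::vector<llama_model_kv_override> overrides;
        overrides.push_back(o);
        overrides.emplace_back();          // entry with an empty key terminates the list
        overrides.back().key[0] = 0;

        llama_model_quantize_params qparams = llama_model_quantize_default_params();
        qparams.kv_overrides = &overrides; // the quantizer reads this vector

        // llama_model_quantize("in.gguf", "out-q4_k_m.gguf", &qparams);
        return 0;
    }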
examples/talk-llama/llama-vocab.cpp CHANGED
@@ -9,16 +9,16 @@
9
 
10
  #include <algorithm>
11
  #include <cassert>
 
12
  #include <cfloat>
13
- #include <climits>
14
  #include <cstdarg>
15
  #include <cstring>
16
  #include <forward_list>
 
17
  #include <map>
18
  #include <queue>
19
  #include <set>
20
  #include <unordered_map>
21
- #include <cctype>
22
 
23
  //
24
  // helpers
@@ -1987,6 +1987,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1987
  || t.first == "<|eom_id|>"
1988
  || t.first == "<EOT>"
1989
  || t.first == "_<EOT>"
 
1990
  ) {
1991
  special_eog_ids.insert(t.second);
1992
  if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2572,6 +2573,10 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t
2572
  // copy piece chars to output text buffer
2573
  // skip up to 'lstrip' leading spaces before copying
2574
  auto _try_copy = [=] (const char * token, size_t size) -> int32_t {
2575
  for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) {
2576
  token++;
2577
  size--;
@@ -2768,26 +2773,26 @@ void llama_vocab::impl::print_info() const {
2768
  LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) bpe_ranks.size());
2769
 
2770
  // special tokens
2771
- if (special_bos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, special_bos_id, id_to_token[special_bos_id].text.c_str() ); }
2772
- if (special_eos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, special_eos_id, id_to_token[special_eos_id].text.c_str() ); }
2773
- if (special_eot_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, special_eot_id, id_to_token[special_eot_id].text.c_str() ); }
2774
- if (special_eom_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, special_eom_id, id_to_token[special_eom_id].text.c_str() ); }
2775
- if (special_unk_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, special_unk_id, id_to_token[special_unk_id].text.c_str() ); }
2776
- if (special_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, special_sep_id, id_to_token[special_sep_id].text.c_str() ); }
2777
- if (special_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, special_pad_id, id_to_token[special_pad_id].text.c_str() ); }
2778
- if (special_mask_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, special_mask_id, id_to_token[special_mask_id].text.c_str() ); }
2779
-
2780
- if (linefeed_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, linefeed_id, id_to_token[linefeed_id].text.c_str() ); }
2781
-
2782
- if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token = %d '%s'\n", __func__, special_fim_pre_id, id_to_token[special_fim_pre_id].text.c_str() ); }
2783
- if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token = %d '%s'\n", __func__, special_fim_suf_id, id_to_token[special_fim_suf_id].text.c_str() ); }
2784
- if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token = %d '%s'\n", __func__, special_fim_mid_id, id_to_token[special_fim_mid_id].text.c_str() ); }
2785
- if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token = %d '%s'\n", __func__, special_fim_pad_id, id_to_token[special_fim_pad_id].text.c_str() ); }
2786
- if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token = %d '%s'\n", __func__, special_fim_rep_id, id_to_token[special_fim_rep_id].text.c_str() ); }
2787
- if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token = %d '%s'\n", __func__, special_fim_sep_id, id_to_token[special_fim_sep_id].text.c_str() ); }
2788
 
2789
  for (const auto & id : special_eog_ids) {
2790
- LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, id_to_token[id].text.c_str() );
2791
  }
2792
 
2793
  LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
 
9
 
10
  #include <algorithm>
11
  #include <cassert>
12
+ #include <cctype>
13
  #include <cfloat>
 
14
  #include <cstdarg>
15
  #include <cstring>
16
  #include <forward_list>
17
+ #include <limits>
18
  #include <map>
19
  #include <queue>
20
  #include <set>
21
  #include <unordered_map>
 
22
 
23
  //
24
  // helpers
 
1987
  || t.first == "<|eom_id|>"
1988
  || t.first == "<EOT>"
1989
  || t.first == "_<EOT>"
1990
+ || t.first == "<|end_of_text|>"
1991
  ) {
1992
  special_eog_ids.insert(t.second);
1993
  if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
 
2573
  // copy piece chars to output text buffer
2574
  // skip up to 'lstrip' leading spaces before copying
2575
  auto _try_copy = [=] (const char * token, size_t size) -> int32_t {
2576
+ if (size >= static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
2577
+ GGML_ABORT("invalid token size: %zu exceeds int32_t limit", size);
2578
+ }
2579
+
2580
  for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) {
2581
  token++;
2582
  size--;
 
2773
  LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) bpe_ranks.size());
2774
 
2775
  // special tokens
2776
+ if (special_bos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, special_bos_id, id_to_token.at(special_bos_id).text.c_str() ); }
2777
+ if (special_eos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, special_eos_id, id_to_token.at(special_eos_id).text.c_str() ); }
2778
+ if (special_eot_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, special_eot_id, id_to_token.at(special_eot_id).text.c_str() ); }
2779
+ if (special_eom_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, special_eom_id, id_to_token.at(special_eom_id).text.c_str() ); }
2780
+ if (special_unk_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, special_unk_id, id_to_token.at(special_unk_id).text.c_str() ); }
2781
+ if (special_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, special_sep_id, id_to_token.at(special_sep_id).text.c_str() ); }
2782
+ if (special_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, special_pad_id, id_to_token.at(special_pad_id).text.c_str() ); }
2783
+ if (special_mask_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, special_mask_id, id_to_token.at(special_mask_id).text.c_str() ); }
2784
+
2785
+ if (linefeed_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, linefeed_id, id_to_token.at(linefeed_id).text.c_str() ); }
2786
+
2787
+ if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); }
2788
+ if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); }
2789
+ if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); }
2790
+ if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); }
2791
+ if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); }
2792
+ if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); }
2793
 
2794
  for (const auto & id : special_eog_ids) {
2795
+ LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() );
2796
  }
2797
 
2798
  LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
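The new guard in _try_copy only aborts on pieces whose size cannot be represented as int32_t; the normal caller contract is unchanged: when the destination buffer is too small, llama_token_to_piece() returns a negative value whose magnitude is the required size. A hedged sketch of the usual two-pass caller, assuming the vocab-based signature used in this tree and a placeholder model path:

    #include "llama.h"
    #include <cstdio>
    #include <string>

    // grow-on-demand wrapper around llama_token_to_piece()
    static std::string token_to_piece(const llama_vocab * vocab, llama_token tok) {
        std::string piece(8, '\0');
        int32_t n = llama_token_to_piece(vocab, tok, &piece[0], (int32_t) piece.size(), /*lstrip=*/0, /*special=*/true);
        if (n < 0) {
            piece.resize((size_t) -n); // negative return = required size
            n = llama_token_to_piece(vocab, tok, &piece[0], (int32_t) piece.size(), 0, true);
        }
        piece.resize(n > 0 ? (size_t) n : 0);
        return piece;
    }

    int main(int argc, char ** argv) {
        if (argc < 2) { return 1; }

        llama_model * model = llama_model_load_from_file(argv[1], llama_model_default_params());
        if (!model) { return 1; }

        const llama_vocab * vocab = llama_model_get_vocab(model);
        printf("BOS piece: '%s'\n", token_to_piece(vocab, llama_vocab_bos(vocab)).c_str());

        llama_model_free(model);
        return 0;
    }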
examples/talk-llama/llama.cpp CHANGED
@@ -198,14 +198,18 @@ static struct llama_model * llama_model_load_from_file_impl(
198
 
199
  // if using single GPU mode, remove all except the main GPU
200
  if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
201
- if (params.main_gpu < 0 || params.main_gpu >= (int)model->devices.size()) {
202
- LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %d)\n", __func__, params.main_gpu, (int)model->devices.size());
203
- llama_model_free(model);
204
- return nullptr;
 
 
 
 
 
 
 
205
  }
206
- ggml_backend_dev_t main_gpu = model->devices[params.main_gpu];
207
- model->devices.clear();
208
- model->devices.push_back(main_gpu);
209
  }
210
 
211
  for (auto * dev : model->devices) {
 
198
 
199
  // if using single GPU mode, remove all except the main GPU
200
  if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
201
+ if (params.main_gpu < 0) {
202
+ model->devices.clear();
203
+ } else {
204
+ if (params.main_gpu >= (int)model->devices.size()) {
205
+ LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %zu)\n", __func__, params.main_gpu, model->devices.size());
206
+ llama_model_free(model);
207
+ return nullptr;
208
+ }
209
+ ggml_backend_dev_t main_gpu = model->devices[params.main_gpu];
210
+ model->devices.clear();
211
+ model->devices.push_back(main_gpu);
212
  }
213
  }
214
 
215
  for (auto * dev : model->devices) {
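For callers, the rewritten block above means that in LLAMA_SPLIT_MODE_NONE a negative main_gpu is no longer rejected; it now clears the device list so the model loads without any GPU device, while an out-of-range positive value still fails. A hedged sketch, assuming the public llama.h API as of this sync and a placeholder model path:

    #include "llama.h"
    #include <cstdio>

    int main(int argc, char ** argv) {
        if (argc < 2) { return 1; }

        llama_model_params mparams = llama_model_default_params();
        mparams.split_mode = LLAMA_SPLIT_MODE_NONE;
        mparams.main_gpu   = -1; // previously an error; now: drop all GPU devices

        llama_model * model = llama_model_load_from_file(argv[1], mparams);
        if (!model) {
            fprintf(stderr, "load failed\n");
            return 1;
        }

        llama_model_free(model);
        return 0;
    }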
examples/talk-llama/llama.h CHANGED
@@ -243,18 +243,21 @@ extern "C" {
243
 
244
  typedef bool (*llama_progress_callback)(float progress, void * user_data);
245
 
246
- // Input data for llama_decode
247
  // A llama_batch object can contain input about one or many sequences
248
  // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
249
  //
250
  // - token : the token ids of the input (used when embd is NULL)
251
  // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
252
  // - pos : the positions of the respective token in the sequence
253
- // (if set to NULL, the token position will be tracked automatically by llama_decode)
254
  // - seq_id : the sequence to which the respective token belongs
255
  // (if set to NULL, the sequence ID will be assumed to be 0)
256
  // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
257
- // (if set to NULL, only the logits for last token will be returned)
258
  //
259
  typedef struct llama_batch {
260
  int32_t n_tokens;
@@ -262,8 +265,8 @@ extern "C" {
262
  llama_token * token;
263
  float * embd;
264
  llama_pos * pos;
265
- int32_t * n_seq_id; // TODO: remove, should belong to only 1 sequence
266
- llama_seq_id ** seq_id; // TODO: become llama_seq_id * seq_id;
267
  int8_t * logits; // TODO: rename this to "output"
268
  } llama_batch;
269
 
@@ -961,8 +964,8 @@ extern "C" {
961
  // Get the number of threads used for prompt and batch processing (multiple token).
962
  LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
963
 
964
- // Set whether the model is in embeddings mode or not
965
- // If true, embeddings will be returned but logits will not
966
  LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
967
 
968
  // Set whether to use causal attention or not
 
243
 
244
  typedef bool (*llama_progress_callback)(float progress, void * user_data);
245
 
246
+ // Input data for llama_encode/llama_decode
247
  // A llama_batch object can contain input about one or many sequences
248
  // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
249
  //
250
  // - token : the token ids of the input (used when embd is NULL)
251
  // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
252
  // - pos : the positions of the respective token in the sequence
253
+ // (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode)
254
  // - seq_id : the sequence to which the respective token belongs
255
  // (if set to NULL, the sequence ID will be assumed to be 0)
256
  // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
257
+ // (if set to NULL:
258
+ // - if embeddings: all tokens are output
259
+ // - if not: only the last token is output
260
+ // )
261
  //
262
  typedef struct llama_batch {
263
  int32_t n_tokens;
 
265
  llama_token * token;
266
  float * embd;
267
  llama_pos * pos;
268
+ int32_t * n_seq_id;
269
+ llama_seq_id ** seq_id;
270
  int8_t * logits; // TODO: rename this to "output"
271
  } llama_batch;
272
 
 
964
  // Get the number of threads used for prompt and batch processing (multiple token).
965
  LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
966
 
967
+ // Set whether the context outputs embeddings or not
968
+ // TODO: rename to avoid confusion with llama_get_embeddings()
969
  LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
970
 
971
  // Set whether to use causal attention or not
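To tie the llama.h comment changes above together: llama_set_embeddings() switches the context to producing embeddings for its output tokens, and with a NULL logits array the batch then outputs all tokens rather than only the last one. The sketch below is a hedged illustration of that flow, assuming the llama.h entry points as of this sync; tokenization is skipped and the token ids are placeholders:

    #include "llama.h"
    #include <cstdio>
    #include <vector>

    int main(int argc, char ** argv) {
        if (argc < 2) { return 1; }

        llama_model * model = llama_model_load_from_file(argv[1], llama_model_default_params());
        if (!model) { return 1; }

        llama_context * ctx = llama_init_from_model(model, llama_context_default_params());

        llama_set_embeddings(ctx, true); // embeddings will be returned for output tokens

        std::vector<llama_token> toks = { 1, 2, 3 }; // placeholder ids
        llama_batch batch = llama_batch_get_one(toks.data(), (int32_t) toks.size());

        if (llama_decode(ctx, batch) == 0) {
            const int     n_embd = llama_model_n_embd(model);
            const float * embd   = llama_get_embeddings_ith(ctx, -1); // last output token
            if (embd) {
                printf("n_embd = %d, embd[0] = %f\n", n_embd, embd[0]);
            }
        }

        llama_free(ctx);
        llama_model_free(model);
        return 0;
    }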