ggerganov committed
Commit 844e617 · 1 Parent(s): 7d38d31

talk-llama : sync llama.cpp

examples/talk-llama/llama-arch.cpp CHANGED
@@ -34,6 +34,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_PHI3, "phi3" },
     { LLM_ARCH_PHIMOE, "phimoe" },
     { LLM_ARCH_PLAMO, "plamo" },
+    { LLM_ARCH_PLAMO2, "plamo2" },
     { LLM_ARCH_CODESHELL, "codeshell" },
     { LLM_ARCH_ORION, "orion" },
     { LLM_ARCH_INTERNLM2, "internlm2" },
@@ -67,6 +68,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_JAIS, "jais" },
     { LLM_ARCH_NEMOTRON, "nemotron" },
     { LLM_ARCH_EXAONE, "exaone" },
+    { LLM_ARCH_EXAONE4, "exaone4" },
     { LLM_ARCH_RWKV6, "rwkv6" },
     { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
     { LLM_ARCH_RWKV7, "rwkv7" },
@@ -81,9 +83,11 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DOTS1, "dots1" },
     { LLM_ARCH_ARCEE, "arcee" },
     { LLM_ARCH_ERNIE4_5, "ernie4_5" },
+    { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
     { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
     { LLM_ARCH_SMOLLM3, "smollm3" },
     { LLM_ARCH_LFM2, "lfm2" },
+    { LLM_ARCH_DREAM, "dream" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
@@ -784,6 +788,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_PLAMO2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
+            { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
+            { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" },
+            { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" },
+            { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_CODESHELL,
         {
@@ -1477,6 +1511,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_EXAONE4,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+        }
+    },
     {
         LLM_ARCH_RWKV6,
         {
@@ -1793,6 +1847,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_ERNIE4_5_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
+        },
+    },
     {
         LLM_ARCH_HUNYUAN_MOE,
         {
@@ -1854,6 +1933,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
         }
     },
+    {
+        LLM_ARCH_DREAM,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -2094,6 +2190,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
     switch (arch) {
         case LLM_ARCH_JAMBA:
         case LLM_ARCH_FALCON_H1:
+        case LLM_ARCH_PLAMO2:
         case LLM_ARCH_GRANITE_HYBRID:
         case LLM_ARCH_LFM2:
             return true;
@@ -2101,3 +2198,12 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
             return false;
     }
 }
+
+bool llm_arch_is_diffusion(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_DREAM:
+            return true;
+        default:
+            return false;
+    }
+}
examples/talk-llama/llama-arch.h CHANGED
@@ -38,6 +38,7 @@ enum llm_arch {
     LLM_ARCH_PHI3,
     LLM_ARCH_PHIMOE,
     LLM_ARCH_PLAMO,
+    LLM_ARCH_PLAMO2,
     LLM_ARCH_CODESHELL,
     LLM_ARCH_ORION,
     LLM_ARCH_INTERNLM2,
@@ -71,6 +72,7 @@ enum llm_arch {
     LLM_ARCH_JAIS,
     LLM_ARCH_NEMOTRON,
     LLM_ARCH_EXAONE,
+    LLM_ARCH_EXAONE4,
     LLM_ARCH_RWKV6,
     LLM_ARCH_RWKV6QWEN2,
     LLM_ARCH_RWKV7,
@@ -85,9 +87,11 @@ enum llm_arch {
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
     LLM_ARCH_ERNIE4_5,
+    LLM_ARCH_ERNIE4_5_MOE,
     LLM_ARCH_HUNYUAN_MOE,
     LLM_ARCH_SMOLLM3,
     LLM_ARCH_LFM2,
+    LLM_ARCH_DREAM,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -478,3 +482,4 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
 
 bool llm_arch_is_recurrent(const llm_arch & arch);
 bool llm_arch_is_hybrid (const llm_arch & arch);
+bool llm_arch_is_diffusion(const llm_arch & arch);
examples/talk-llama/llama-batch.cpp CHANGED
@@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
         const llama_vocab & vocab,
         const llama_memory_i * memory,
         uint32_t n_embd,
+        uint32_t n_seq_max,
         bool output_all) {
     clear();
 
@@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
     // validate input batch
     //
 
+    if (n_seq_max > LLAMA_MAX_SEQ) {
+        LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
+        return false;
+    }
+
     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
@@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
     if (batch.seq_id) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
                     return false;
                 }
             }
@@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
 
     // initialize the starting position for each sequence based on the positions in the memory
     llama_pos p0[LLAMA_MAX_SEQ];
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (!memory) {
             // if no memory -> start from 0
             p0[s] = 0;
@@ -143,13 +149,16 @@ bool llama_batch_allocr::init(
     // compute stats
     //
 
-    this->n_embd = n_embd;
+    this->n_embd = n_embd;
+    this->n_seq_max = n_seq_max;
 
     // count the outputs in this batch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         n_outputs += batch.logits[i] != 0;
     }
 
+    has_cpl = false;
+
     // determine coupled sequences
     // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -189,7 +198,7 @@ bool llama_batch_allocr::init(
         seq_set_map[cur].push_back(i);
     }
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
             seq_idx[s] = seq_id_unq.size();
             seq_id_unq.push_back(s);
@@ -201,7 +210,7 @@ bool llama_batch_allocr::init(
     LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
 
     llama_ubatch ubatch {
-        /*.equal_seqs =*/ false,
+        /*.b_equal_seqs =*/ false,
         /*.n_tokens =*/ (uint32_t) batch.n_tokens,
         /*.n_seq_tokens =*/ (uint32_t) 1,
         /*.n_seqs =*/ (uint32_t) batch.n_tokens,
@@ -214,6 +223,7 @@ bool llama_batch_allocr::init(
         /*.seq_id_unq =*/ this->seq_id_unq.data(),
         /*.seq_idx =*/ this->seq_idx.data(),
         /*.output =*/ batch.logits,
+        /*.data =*/ {},
     };
 
     ubatch_print(ubatch, debug);
@@ -241,7 +251,7 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_pos[s].empty()) {
             continue;
         }
@@ -284,8 +294,8 @@ bool llama_batch_allocr::init(
     }
 
     if (memory) {
-        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
-            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
+        for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
+            for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
                 if (seq_cpl[s0][s1]) {
                     if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                         memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
@@ -316,12 +326,12 @@ bool llama_batch_allocr::init(
     //
     {
         seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             cur_seq_set[s].set();
         }
 
         llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
            cur_seq_pos[s] = -1;
        }
 
@@ -357,39 +367,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
     clear();
     split_reset();
 
-    ubatches.emplace_back();
+    auto udata = std::make_shared<llama_ubatch::data_t>();
 
-    auto & ubatch = ubatches.back();
-
-    ubatch.token .resize(n_tokens);
-    ubatch.embd .clear();
-    ubatch.pos .resize(n_tokens);
-    ubatch.n_seq_id .resize(n_tokens);
-    ubatch.seq_id .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output .resize(n_tokens);
+    udata->token .resize(n_tokens);
+    udata->embd .clear();
+    udata->pos .resize(n_tokens);
+    udata->n_seq_id .resize(n_tokens);
+    udata->seq_id .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
+    udata->output .resize(n_tokens);
 
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        ubatch.seq_idx[s] = s;
-        ubatch.seq_id_unq.push_back(s);
+        udata->seq_idx[s] = s;
+        udata->seq_id_unq.push_back(s);
     }
 
     llama_ubatch res {
-        /*.equal_seqs =*/ true,
+        /*.b_equal_seqs =*/ true,
         /*.n_tokens =*/ n_tokens,
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs =*/ n_seqs,
         /*.n_seqs_unq =*/ n_seqs,
 
-        /*.token =*/ ubatch.token.data(),
+        /*.token =*/ udata->token.data(),
         /*.embd =*/ nullptr,
-        /*.pos =*/ ubatch.pos.data(),
-        /*.n_seq_id =*/ ubatch.n_seq_id.data(),
-        /*.seq_id =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx =*/ ubatch.seq_idx.data(),
-        /*.output =*/ ubatch.output.data(),
+        /*.pos =*/ udata->pos.data(),
+        /*.n_seq_id =*/ udata->n_seq_id.data(),
+        /*.seq_id =*/ udata->seq_id.data(),
+        /*.seq_id_unq =*/ udata->seq_id_unq.data(),
+        /*.seq_idx =*/ udata->seq_idx.data(),
+        /*.output =*/ udata->output.data(),
+        /*.data =*/ std::move(udata),
     };
 
     return res;
@@ -430,8 +439,6 @@ void llama_batch_allocr::split_reset() {
 
     used.clear();
     used.resize(get_n_tokens(), false);
-
-    ubatches.clear();
 }
 
 llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
@@ -646,78 +653,77 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
 
     assert(n_tokens%n_seqs == 0);
 
-    ubatches.emplace_back();
-
-    auto & ubatch = ubatches.back();
+    auto udata = std::make_shared<llama_ubatch::data_t>();
 
     const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
 
     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
 
-    ubatch.token .resize(n_tokens);
-    ubatch.embd .resize(n_embd_all);
-    ubatch.pos .resize(n_pos_all);
-    ubatch.n_seq_id .resize(n_tokens);
-    ubatch.seq_id .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output .resize(n_tokens);
+    udata->token .resize(n_tokens);
+    udata->embd .resize(n_embd_all);
+    udata->pos .resize(n_pos_all);
+    udata->n_seq_id .resize(n_tokens);
+    udata->seq_id .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
+    udata->output .resize(n_tokens);
 
     seq_set_t seq_set_unq;
 
     for (size_t i = 0; i < idxs.size(); ++i) {
         if (batch.token) {
-            ubatch.token[i] = batch.token[idxs[i]];
+            udata->token[i] = batch.token[idxs[i]];
         }
 
         if (batch.embd) {
-            memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
+            memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
        }
 
         for (int j = 0; j < n_pos_cur; ++j) {
-            ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
+            udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
         }
 
-        ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
-        ubatch.seq_id[i] = batch.seq_id[idxs[i]];
-        ubatch.output[i] = batch.logits[idxs[i]];
+        udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
+        udata->seq_id[i] = batch.seq_id[idxs[i]];
+        udata->output[i] = batch.logits[idxs[i]];
 
-        for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
-            seq_set_unq.set(ubatch.seq_id[i][s]);
+        for (int s = 0; s < udata->n_seq_id[i]; ++s) {
+            seq_set_unq.set(udata->seq_id[i][s]);
         }
 
-        if (ubatch.output[i]) {
+        if (udata->output[i]) {
             out_ids.push_back(idxs[i]);
         }
     }
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
-            ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
-            ubatch.seq_id_unq.push_back(s);
+            udata->seq_idx[s] = udata->seq_id_unq.size();
+            udata->seq_id_unq.push_back(s);
         }
     }
 
     llama_ubatch res {
-        /*.equal_seqs =*/ equal_seqs,
+        /*.b_equal_seqs =*/ equal_seqs,
         /*.n_tokens =*/ n_tokens,
         /*.n_seq_tokens =*/ n_tokens/n_seqs,
         /*.n_seqs =*/ n_seqs,
-        /*.n_seqs_unq =*/ (uint32_t) ubatch.seq_id_unq.size(),
-
-        /*.token =*/ batch.token ? ubatch.token.data() : nullptr,
-        /*.embd =*/ batch.embd ? ubatch.embd.data() : nullptr,
-        /*.pos =*/ ubatch.pos.data(),
-        /*.n_seq_id =*/ ubatch.n_seq_id.data(),
-        /*.seq_id =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx =*/ ubatch.seq_idx.data(),
-        /*.output =*/ ubatch.output.data(),
+        /*.n_seqs_unq =*/ (uint32_t) udata->seq_id_unq.size(),
+
+        /*.token =*/ batch.token ? udata->token.data() : nullptr,
+        /*.embd =*/ batch.embd ? udata->embd.data() : nullptr,
+        /*.pos =*/ udata->pos.data(),
+        /*.n_seq_id =*/ udata->n_seq_id.data(),
+        /*.seq_id =*/ udata->seq_id.data(),
+        /*.seq_id_unq =*/ udata->seq_id_unq.data(),
+        /*.seq_idx =*/ udata->seq_idx.data(),
+        /*.output =*/ udata->output.data(),
+        /*.data =*/ std::move(udata),
     };
 
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1);
+        LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);
 
         ubatch_print(res, debug);
     }
@@ -727,7 +733,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
 
 void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs());
         LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens);
         LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
         LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs);
examples/talk-llama/llama-batch.h CHANGED
@@ -8,12 +8,17 @@
 #include <vector>
 #include <set>
 #include <bitset>
+#include <memory>
 #include <unordered_map>
 
 // keep this struct lightweight
-// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
-    bool equal_seqs;
+    bool equal_seqs() const {
+        return b_equal_seqs != 0;
+    }
+
+    uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
+                           // otherwise address sanitizer complains
     // TODO: whole_seqs for embeddings?
 
     uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
@@ -34,6 +39,20 @@ struct llama_ubatch {
     llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
     int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
     int8_t * output; // [n_tokens] | i | -
+
+    struct data_t {
+        std::vector<llama_token> token;
+        std::vector<float> embd;
+        std::vector<llama_pos> pos;
+        std::vector<int32_t> n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id> seq_id_unq;
+        std::vector<int32_t> seq_idx;
+        std::vector<int8_t> output;
+    };
+
+    // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
+    std::shared_ptr<data_t> data;
 };
 
 // a helper for sanitizing, fulfilling and splitting a batch
@@ -48,6 +67,7 @@ public:
             const llama_vocab & vocab,
             const llama_memory_i * memory,
             uint32_t n_embd,
+            uint32_t n_seq_max,
             bool output_all);
 
     const llama_batch & get_batch() const;
@@ -100,6 +120,7 @@ private:
     const uint32_t n_pos_per_embd;
 
     uint32_t n_embd;
+    uint32_t n_seq_max;
     uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
@@ -115,7 +136,7 @@ private:
     using seq_cpl_t = std::vector<bool>;
 
     // helper flag to quickly determine if there are any coupled sequences in the batch
-    bool has_cpl;
+    bool has_cpl = false;
 
     std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
@@ -135,20 +156,5 @@ private:
     // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;
 
-    // llama_ubatch points to this data:
-    struct ubatch {
-        std::vector<llama_token> token;
-        std::vector<float> embd;
-        std::vector<llama_pos> pos;
-        std::vector<int32_t> n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<llama_seq_id> seq_id_unq;
-        std::vector<int32_t> seq_idx;
-        std::vector<int8_t> output;
-    };
-
-    // current splitting state:
-    std::vector<ubatch> ubatches;
-
     int debug;
 };
examples/talk-llama/llama-chat.cpp CHANGED
@@ -56,6 +56,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
     { "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
     { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
+    { "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 },
     { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
     { "granite", LLM_CHAT_TEMPLATE_GRANITE },
     { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
@@ -65,6 +66,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
     { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
     { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
+    { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
 };
 
 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -167,10 +169,13 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
     } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) {
         return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
     } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
+        if (tmpl_contains("[|tool|]")) {
+            return LLM_CHAT_TEMPLATE_EXAONE_4;
+        }
         // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
         // EXAONE-3.0-7.8B-Instruct
         return LLM_CHAT_TEMPLATE_EXAONE_3;
-    } else if (tmpl_contains("rwkv-world")) {
+    } else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) {
         return LLM_CHAT_TEMPLATE_RWKV_WORLD;
     } else if (tmpl_contains("<|start_of_role|>")) {
         return LLM_CHAT_TEMPLATE_GRANITE;
@@ -188,6 +193,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_DOTS1;
     } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
         return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
+    } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
+        return LLM_CHAT_TEMPLATE_KIMI_K2;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -529,6 +536,22 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "[|assistant|]";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_4) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
+            } else if (role == "user") {
+                ss << "[|user|]" << trim(message->content) << "\n";
+            } else if (role == "assistant") {
+                ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
+            } else if (role == "tool") {
+                ss << "[|tool|]" << trim(message->content) << "[|endofturn|]\n";
+            }
+        }
+        if (add_ass) {
+            ss << "[|assistant|]";
+        }
     } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
         // this template requires the model to have "\n\n" as EOT token
         for (size_t i = 0; i < chat.size(); i++) {
@@ -680,6 +703,25 @@ int32_t llm_chat_apply_template(
                 ss << "<|startoftext|>" << message->content << "<|extra_0|>";
             }
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) {
+        // moonshotai/Kimi-K2-Instruct
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|im_system|>system<|im_middle|>";
+            } else if (role == "user") {
+                ss << "<|im_user|>user<|im_middle|>";
+            } else if (role == "assistant") {
+                ss << "<|im_assistant|>assistant<|im_middle|>";
+            } else if (role == "tool") {
+                ss << "<|im_system|>tool<|im_middle|>";
+            }
+
+            ss << message->content << "<|im_end|>";
+        }
+        if (add_ass) {
+            ss << "<|im_assistant|>assistant<|im_middle|>";
+        }
     } else {
         // template not supported
         return -1;
examples/talk-llama/llama-chat.h CHANGED
@@ -35,6 +35,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_GLMEDGE,
     LLM_CHAT_TEMPLATE_MINICPM,
     LLM_CHAT_TEMPLATE_EXAONE_3,
+    LLM_CHAT_TEMPLATE_EXAONE_4,
     LLM_CHAT_TEMPLATE_RWKV_WORLD,
     LLM_CHAT_TEMPLATE_GRANITE,
     LLM_CHAT_TEMPLATE_GIGACHAT,
@@ -45,6 +46,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_SMOLVLM,
     LLM_CHAT_TEMPLATE_DOTS1,
     LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
+    LLM_CHAT_TEMPLATE_KIMI_K2,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 
examples/talk-llama/llama-context.cpp CHANGED
@@ -98,10 +98,20 @@ llama_context::llama_context(
98
  LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
99
  cparams.n_batch = GGML_KQ_MASK_PAD;
100
  }
101
-
102
  cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
103
 
104
  cparams.op_offload = params.op_offload;
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
107
 
@@ -112,6 +122,7 @@ llama_context::llama_context(
112
  LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
113
  LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
114
  LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
 
115
  LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
116
  LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
117
 
@@ -227,8 +238,8 @@ llama_context::llama_context(
227
 
228
  LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
229
 
230
- // buffer used to store the computation graph and the tensor meta data
231
- buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
232
 
233
  // TODO: move these checks to ggml_backend_sched
234
  // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
@@ -267,7 +278,7 @@ llama_context::llama_context(
267
 
268
  // reserve worst-case graph
269
  if (!hparams.vocab_only && memory) {
270
- const uint32_t n_seqs = cparams.n_seq_max;
271
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
272
 
273
  LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@@ -287,7 +298,7 @@ llama_context::llama_context(
287
 
288
  cross.v_embd.clear();
289
 
290
- // reserve pp graph first so that buffers are only allocated once
291
  {
292
  auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
293
  if (!gf) {
@@ -298,9 +309,9 @@ llama_context::llama_context(
298
  n_nodes_pp = ggml_graph_n_nodes(gf);
299
  }
300
 
301
- // reserve with tg graph to get the number of splits and nodes
302
  {
303
- auto * gf = graph_reserve(1, 1, 1, mctx.get());
304
  if (!gf) {
305
  throw std::runtime_error("failed to allocate compute tg buffers");
306
  }
@@ -311,6 +322,10 @@ llama_context::llama_context(
311
 
312
  // reserve again with pp graph to avoid ggml-alloc reallocations during inference
313
  {
 
 
 
 
314
  auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
315
  if (!gf) {
316
  throw std::runtime_error("failed to allocate compute pp buffers");
@@ -388,10 +403,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
388
  return sched.get();
389
  }
390
 
391
- ggml_context * llama_context::get_ctx_compute() const {
392
- return ctx_compute.get();
393
- }
394
-
395
  uint32_t llama_context::n_ctx() const {
396
  return cparams.n_ctx;
397
  }
@@ -463,6 +474,11 @@ bool llama_context::kv_self_update(bool optimize) {
463
  }
464
  }
465
 
 
 
 
 
 
466
  if (!mctx->apply()) {
467
  LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
468
  }
@@ -475,7 +491,7 @@ bool llama_context::kv_self_update(bool optimize) {
475
  throw std::runtime_error("failed to initialize memory context");
476
  }
477
 
478
- const uint32_t n_seqs = cparams.n_seq_max;
479
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
480
 
481
  auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -492,12 +508,16 @@ enum llama_pooling_type llama_context::pooling_type() const {
492
  }
493
 
494
  float * llama_context::get_logits() {
 
 
495
  return logits;
496
  }
497
 
498
  float * llama_context::get_logits_ith(int32_t i) {
499
  int64_t j = -1;
500
 
 
 
501
  try {
502
  if (logits == nullptr) {
503
  throw std::runtime_error("no logits");
@@ -534,12 +554,16 @@ float * llama_context::get_logits_ith(int32_t i) {
534
  }
535
 
536
  float * llama_context::get_embeddings() {
 
 
537
  return embd;
538
  }
539
 
540
  float * llama_context::get_embeddings_ith(int32_t i) {
541
  int64_t j = -1;
542
 
 
 
543
  try {
544
  if (embd == nullptr) {
545
  throw std::runtime_error("no embeddings");
@@ -678,38 +702,59 @@ bool llama_context::apply_adapter_cvec(
678
  return cvec.apply(model, data, len, n_embd, il_start, il_end);
679
  }
680
 
681
- llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
682
  if (mctx && !mctx->apply()) {
683
  LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
684
  ret = GGML_STATUS_FAILED;
685
  return nullptr;
686
  }
687
 
688
- auto * gf = graph_init();
689
- if (!gf) {
690
- LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
691
- ret = GGML_STATUS_FAILED;
692
- return nullptr;
693
- }
694
 
695
- auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
696
- if (!res) {
697
- LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
698
- ret = GGML_STATUS_FAILED;
699
- return nullptr;
700
- }
701
 
702
- // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
703
 
704
- if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
705
- LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
706
- ret = GGML_STATUS_ALLOC_FAILED;
707
- return nullptr;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
708
  }
709
 
710
- res->set_inputs(&ubatch);
 
 
 
 
 
 
 
711
 
712
- const auto status = graph_compute(gf, ubatch.n_tokens > 1);
713
  if (status != GGML_STATUS_SUCCESS) {
714
  LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
715
  ret = status;
@@ -731,16 +776,19 @@ int llama_context::encode(const llama_batch & batch_inp) {
731
 
732
  const auto & hparams = model.hparams;
733
 
734
- const int64_t n_embd = hparams.n_embd;
 
735
 
736
  // note: during encode, we always pass the full sequence starting from pos = 0
737
- if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
738
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
739
  return -1;
740
  }
741
 
742
  const uint32_t n_tokens = balloc->get_n_tokens();
743
 
 
 
744
  const llama_ubatch ubatch = balloc->split_simple(n_tokens);
745
 
746
  // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
@@ -767,9 +815,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
767
 
768
  n_outputs = n_tokens;
769
 
770
- ggml_backend_sched_reset(sched.get());
771
- ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
772
-
773
  const auto causal_attn_org = cparams.causal_attn;
774
 
775
  // always use non-causal attention for encoder graphs
@@ -778,7 +823,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
778
  cparams.causal_attn = false;
779
 
780
  ggml_status status;
781
- const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
782
 
783
  cparams.causal_attn = causal_attn_org;
784
 
@@ -791,10 +836,20 @@ int llama_context::encode(const llama_batch & batch_inp) {
791
  }
792
  }
793
 
 
794
  auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
795
 
 
 
 
 
 
 
 
 
 
796
  // extract embeddings
797
- if (t_embd) {
798
  ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
799
  GGML_ASSERT(backend_embd != nullptr);
800
 
@@ -844,9 +899,11 @@ int llama_context::encode(const llama_batch & batch_inp) {
844
  }
845
  }
846
 
847
- // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
848
- // overlap with device computation.
849
- ggml_backend_sched_reset(sched.get());
 
 
850
 
851
  // TODO: hacky solution
852
  if (model.arch == LLM_ARCH_T5 && t_embd) {
@@ -899,7 +956,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
899
  // when computing embeddings, all tokens are output
900
  const bool output_all = cparams.embeddings;
901
 
902
- if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
903
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
904
  return -1;
905
  }
@@ -927,6 +984,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
927
 
928
  // TODO: this clear of the buffer can easily be forgotten - need something better
929
  embd_seq.clear();
 
930
 
931
  bool did_optimize = false;
932
 
@@ -1005,11 +1063,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
1005
  n_outputs = n_outputs_new;
1006
  }
1007
 
1008
- ggml_backend_sched_reset(sched.get());
1009
- ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
1010
-
1011
  ggml_status status;
1012
- const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
1013
 
1014
  if (!res) {
1015
  // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1149,9 +1204,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
1149
  // make the outputs have the same order they had in the user-provided batch
1150
  // note: this is mostly relevant for recurrent models atm
1151
  if (!sorted_output) {
1152
- const uint32_t n_vocab = model.vocab.n_tokens();
1153
- const uint64_t n_embd = model.hparams.n_embd;
1154
-
1155
  GGML_ASSERT((size_t) n_outputs == out_ids.size());
1156
 
1157
  // TODO: is there something more efficient which also minimizes swaps?
@@ -1167,16 +1219,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
1167
  continue;
1168
  }
1169
  std::swap(out_ids[i], out_ids[j_min]);
1170
- if (logits_size > 0) {
1171
- for (uint32_t k = 0; k < n_vocab; k++) {
1172
- std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
1173
- }
1174
- }
1175
- if (embd_size > 0) {
1176
- for (uint32_t k = 0; k < n_embd; k++) {
1177
- std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
1178
- }
1179
- }
1180
  }
1181
 
1182
  std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1190,9 +1235,11 @@ int llama_context::decode(const llama_batch & batch_inp) {
1190
  // wait for the computation to finish (automatically done when obtaining the model output)
1191
  //synchronize();
1192
 
1193
- // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
1194
- // overlap with device computation.
1195
- ggml_backend_sched_reset(sched.get());
 
 
1196
 
1197
  return 0;
1198
  }
@@ -1271,24 +1318,40 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1271
  return n_outputs_max;
1272
  }
1273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1274
  //
1275
  // graph
1276
  //
1277
 
1278
- int32_t llama_context::graph_max_nodes() const {
1279
- return std::max<int32_t>(65536, 5*model.n_tensors());
1280
  }
1281
 
1282
- ggml_cgraph * llama_context::graph_init() {
1283
- ggml_init_params params = {
1284
- /*.mem_size =*/ buf_compute_meta.size(),
1285
- /*.mem_buffer =*/ buf_compute_meta.data(),
1286
- /*.no_alloc =*/ true,
1287
- };
1288
-
1289
- ctx_compute.reset(ggml_init(params));
1290
-
1291
- return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
1292
  }
1293
 
1294
  ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
@@ -1301,6 +1364,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
1301
  LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
1302
  }
1303
 
 
 
 
 
 
1304
  // store the n_outputs as it is, and restore it afterwards
1305
  // TODO: not sure if needed, might simplify in the future by removing this
1306
  const auto save_n_outputs = this->n_outputs;
@@ -1310,17 +1378,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
1310
  llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
1311
  llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
1312
 
1313
- auto * gf = graph_init();
1314
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
1315
 
1316
- this->n_outputs = save_n_outputs;
1317
 
1318
- if (!res) {
1319
- LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
1320
- return nullptr;
1321
- }
1322
 
1323
- ggml_backend_sched_reset(sched.get());
 
 
1324
 
1325
  // initialize scheduler with the specified graph
1326
  if (!ggml_backend_sched_reserve(sched.get(), gf)) {
@@ -1331,28 +1397,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
1331
  return gf;
1332
  }
1333
 
1334
- llm_graph_result_ptr llama_context::graph_build(
1335
- ggml_context * ctx,
1336
- ggml_cgraph * gf,
1337
- const llama_ubatch & ubatch,
1338
- llm_graph_type gtype,
1339
- const llama_memory_context_i * mctx) {
1340
- return model.build_graph(
1341
- {
1342
- /*.ctx =*/ ctx,
1343
- /*.arch =*/ model.arch,
1344
- /*.hparams =*/ model.hparams,
1345
- /*.cparams =*/ cparams,
1346
- /*.ubatch =*/ ubatch,
1347
- /*.sched =*/ sched.get(),
1348
- /*.backend_cpu =*/ backend_cpu,
1349
- /*.cvec =*/ &cvec,
1350
- /*.loras =*/ &loras,
1351
- /*.mctx =*/ mctx,
1352
- /*.cross =*/ &cross,
1353
- /*.n_outputs =*/ n_outputs,
1354
- /*.cb =*/ graph_get_cb(),
1355
- }, gf, gtype);
1356
  }
1357
 
1358
  ggml_status llama_context::graph_compute(
@@ -1930,6 +1995,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
1930
  data.t_eval_ms = 1e-3 * t_eval_us;
1931
  data.n_p_eval = std::max(1, n_p_eval);
1932
  data.n_eval = std::max(1, n_eval);
 
1933
 
1934
  return data;
1935
  }
@@ -1938,6 +2004,7 @@ void llama_context::perf_reset() {
1938
  t_start_us = ggml_time_us();
1939
  t_eval_us = n_eval = 0;
1940
  t_p_eval_us = n_p_eval = 0;
 
1941
  }
1942
 
1943
  //
@@ -2028,7 +2095,7 @@ void llama_context::opt_epoch_iter(
2028
  batch.logits [pos_batch] = true;
2029
  }
2030
 
2031
- if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
2032
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
2033
  return;
2034
  }
@@ -2064,8 +2131,13 @@ void llama_context::opt_epoch_iter(
2064
  break;
2065
  }
2066
 
2067
- auto * gf = graph_init();
2068
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
 
 
 
 
 
2069
 
2070
  struct ggml_context * ctx_compute_opt;
2071
  {
@@ -2187,6 +2259,7 @@ llama_context_params llama_context_default_params() {
2187
  /*.no_perf =*/ true,
2188
  /*.op_offload =*/ true,
2189
  /*.swa_full =*/ true,
 
2190
  };
2191
 
2192
  return result;
@@ -2807,6 +2880,7 @@ void llama_perf_context_print(const llama_context * ctx) {
2807
  LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
2808
  __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
2809
  LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
 
2810
  }
2811
 
2812
  void llama_perf_context_reset(llama_context * ctx) {
 
98
  LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
99
  cparams.n_batch = GGML_KQ_MASK_PAD;
100
  }
 
101
  cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
102
 
103
  cparams.op_offload = params.op_offload;
104
+ cparams.kv_unified = params.kv_unified;
105
+
106
+ {
107
+ const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
108
+ supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : false;
109
+
110
+ if (!supports_set_rows && !cparams.kv_unified) {
111
+ LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
112
+ cparams.kv_unified = true;
113
+ }
114
+ }
115
 
116
  const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
117
 
 
122
  LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
123
  LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
124
  LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
125
+ LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
126
  LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
127
  LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
128
 
 
238
 
239
  LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
240
 
241
+ gf_res_prev.reset(new llm_graph_result(max_nodes));
242
+ gf_res_reserve.reset(new llm_graph_result(max_nodes));
243
 
244
  // TODO: move these checks to ggml_backend_sched
245
  // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
 
278
 
279
  // reserve worst-case graph
280
  if (!hparams.vocab_only && memory) {
281
+ const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
282
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
283
 
284
  LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
298
 
299
  cross.v_embd.clear();
300
 
301
+ // reserve pp (prompt processing) graph first so that buffers are only allocated once
302
  {
303
  auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
304
  if (!gf) {
 
309
  n_nodes_pp = ggml_graph_n_nodes(gf);
310
  }
311
 
312
+ // reserve with tg (token generation) graph to get the number of splits and nodes
313
  {
314
+ auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
315
  if (!gf) {
316
  throw std::runtime_error("failed to allocate compute tg buffers");
317
  }
 
322
 
323
  // reserve again with pp graph to avoid ggml-alloc reallocations during inference
324
  {
325
+ // TODO: not sure if the following graph would be worster case for multi-stream KV caches:
326
+ //
327
+ // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
328
+ //
329
  auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
330
  if (!gf) {
331
  throw std::runtime_error("failed to allocate compute pp buffers");
 
403
  return sched.get();
404
  }
405
 
 
 
 
 
406
  uint32_t llama_context::n_ctx() const {
407
  return cparams.n_ctx;
408
  }
 
474
  }
475
  }
476
 
477
+ // reset the previous graph result to make sure that it won't be reused
478
+ // TODO: change the mctx->apply() to return information if a graph reserve is needed
479
+ // reset the graph result only if the memory module did reset the scheduler
480
+ gf_res_prev->reset();
481
+
482
  if (!mctx->apply()) {
483
  LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
484
  }
 
491
  throw std::runtime_error("failed to initialize memory context");
492
  }
493
 
494
+ const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
495
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
496
 
497
  auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
 
508
  }
509
 
510
  float * llama_context::get_logits() {
511
+ output_reorder();
512
+
513
  return logits;
514
  }
515
 
516
  float * llama_context::get_logits_ith(int32_t i) {
517
  int64_t j = -1;
518
 
519
+ output_reorder();
520
+
521
  try {
522
  if (logits == nullptr) {
523
  throw std::runtime_error("no logits");
 
554
  }
555
 
556
  float * llama_context::get_embeddings() {
557
+ output_reorder();
558
+
559
  return embd;
560
  }
561
 
562
  float * llama_context::get_embeddings_ith(int32_t i) {
563
  int64_t j = -1;
564
 
565
+ output_reorder();
566
+
567
  try {
568
  if (embd == nullptr) {
569
  throw std::runtime_error("no embeddings");
 
702
  return cvec.apply(model, data, len, n_embd, il_start, il_end);
703
  }
704
 
705
+ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
706
  if (mctx && !mctx->apply()) {
707
  LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
708
  ret = GGML_STATUS_FAILED;
709
  return nullptr;
710
  }
711
 
712
+ auto * res = gf_res_prev.get();
713
+ auto * gf = res->get_gf();
 
 
 
 
714
 
715
+ // the new graph parameters
716
+ // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters
717
+ const auto gparams = graph_params(res, ubatch, mctx, gtype);
 
 
 
718
 
719
+ if (res->can_reuse(gparams)) {
720
+ //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
721
 
722
+ n_reused++;
723
+ } else {
724
+ res->reset();
725
+
726
+ ggml_backend_sched_reset(sched.get());
727
+ ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
728
+
729
+ //const auto t_start_us = ggml_time_us();
730
+
731
+ gf = model.build_graph(gparams);
732
+
733
+ //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
734
+
735
+ if (!gf) {
736
+ LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
737
+ ret = GGML_STATUS_FAILED;
738
+ return nullptr;
739
+ }
740
+
741
+ if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
742
+ LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
743
+ ret = GGML_STATUS_ALLOC_FAILED;
744
+ return nullptr;
745
+ }
746
  }
747
 
748
+ // set the input data for the input tensors
749
+ {
750
+ //const auto t_start_us = ggml_time_us();
751
+
752
+ res->set_inputs(&ubatch);
753
+
754
+ //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
755
+ }
756
 
757
+ const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
758
  if (status != GGML_STATUS_SUCCESS) {
759
  LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
760
  ret = status;
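The path above rebuilds and reallocates the graph only when the new parameters would change its topology; otherwise it just refreshes the input tensors and bumps n_reused. The following is a simplified standalone sketch of that pattern with made-up graph_params/graph_result types (not the llama.cpp API), assuming a steady token-generation loop where the topology stays constant after the first step.

#include <cstdint>
#include <cstdio>

struct graph_params {
    uint32_t n_tokens;
    uint32_t n_outputs;
    bool     causal_attn;

    // same idea as allow_reuse(): identical parameters imply an identical topology
    bool allow_reuse(const graph_params & other) const {
        return n_tokens    == other.n_tokens  &&
               n_outputs   == other.n_outputs &&
               causal_attn == other.causal_attn;
    }
};

struct graph_result {
    graph_params params{};
    bool built = false;

    bool can_reuse(const graph_params & p) const { return built && params.allow_reuse(p); }
    void build    (const graph_params & p)       { params = p; built = true; /* build + alloc would go here */ }
};

int main() {
    graph_result res;
    int n_reused = 0;

    const graph_params decode_step = { /*n_tokens =*/ 1, /*n_outputs =*/ 1, /*causal_attn =*/ true };

    for (int i = 0; i < 4; ++i) {
        if (res.can_reuse(decode_step)) {
            n_reused++;             // same topology -> skip rebuild and reallocation
        } else {
            res.build(decode_step); // first step (or topology change) -> rebuild
        }
        // setting the input data would happen in both branches
    }

    std::printf("graphs reused = %d\n", n_reused); // 3
    return 0;
}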
 
776
 
777
  const auto & hparams = model.hparams;
778
 
779
+ const int64_t n_embd = hparams.n_embd;
780
+ const int32_t n_vocab = model.vocab.n_tokens();
781
 
782
  // note: during encode, we always pass the full sequence starting from pos = 0
783
+ if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
784
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
785
  return -1;
786
  }
787
 
788
  const uint32_t n_tokens = balloc->get_n_tokens();
789
 
790
+ // [TAG_NO_CACHE_PAD]
791
+ // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
792
  const llama_ubatch ubatch = balloc->split_simple(n_tokens);
793
 
794
  // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
 
815
 
816
  n_outputs = n_tokens;
817
 
 
 
 
818
  const auto causal_attn_org = cparams.causal_attn;
819
 
820
  // always use non-causal attention for encoder graphs
 
823
  cparams.causal_attn = false;
824
 
825
  ggml_status status;
826
+ const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
827
 
828
  cparams.causal_attn = causal_attn_org;
829
 
 
836
  }
837
  }
838
 
839
+ auto * t_logits = res->get_logits();
840
  auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
841
 
842
+ // extract logits
843
+ if (logits && t_logits) {
844
+ ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
845
+ GGML_ASSERT(backend_res != nullptr);
846
+ GGML_ASSERT(logits != nullptr);
847
+
848
+ ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
849
+ }
850
+
851
  // extract embeddings
852
+ if (embd && t_embd) {
853
  ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
854
  GGML_ASSERT(backend_embd != nullptr);
855
 
 
899
  }
900
  }
901
 
902
+ if (!supports_set_rows) {
903
+ // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
904
+ // overlap with device computation.
905
+ ggml_backend_sched_reset(sched.get());
906
+ }
907
 
908
  // TODO: hacky solution
909
  if (model.arch == LLM_ARCH_T5 && t_embd) {
 
956
  // when computing embeddings, all tokens are output
957
  const bool output_all = cparams.embeddings;
958
 
959
+ if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
960
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
961
  return -1;
962
  }
 
984
 
985
  // TODO: this clear of the buffer can easily be forgotten - need something better
986
  embd_seq.clear();
987
+ output_swaps.clear();
988
 
989
  bool did_optimize = false;
990
 
 
1063
  n_outputs = n_outputs_new;
1064
  }
1065
 
 
 
 
1066
  ggml_status status;
1067
+ const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
1068
 
1069
  if (!res) {
1070
  // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
 
1204
  // make the outputs have the same order they had in the user-provided batch
1205
  // note: this is mostly relevant for recurrent models atm
1206
  if (!sorted_output) {
 
 
 
1207
  GGML_ASSERT((size_t) n_outputs == out_ids.size());
1208
 
1209
  // TODO: is there something more efficient which also minimizes swaps?
 
1219
  continue;
1220
  }
1221
  std::swap(out_ids[i], out_ids[j_min]);
1222
+
1223
+ // remember the swaps and apply them lazily upon logits/embeddings access
1224
+ output_swaps.push_back({ i, j_min });
 
 
 
 
 
 
 
1225
  }
1226
 
1227
  std::fill(output_ids.begin(), output_ids.end(), -1);
 
1235
  // wait for the computation to finish (automatically done when obtaining the model output)
1236
  //synchronize();
1237
 
1238
+ if (!supports_set_rows) {
1239
+ // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
1240
+ // overlap with device computation.
1241
+ ggml_backend_sched_reset(sched.get());
1242
+ }
1243
 
1244
  return 0;
1245
  }
 
1318
  return n_outputs_max;
1319
  }
1320
 
1321
+ void llama_context::output_reorder() {
1322
+ const uint32_t n_vocab = model.vocab.n_tokens();
1323
+ const uint64_t n_embd = model.hparams.n_embd;
1324
+
1325
+ for (uint32_t s = 0; s < output_swaps.size(); ++s) {
1326
+ const uint32_t i0 = output_swaps[s].i0;
1327
+ const uint32_t i1 = output_swaps[s].i1;
1328
+
1329
+ if (logits_size > 0) {
1330
+ for (uint32_t k = 0; k < n_vocab; k++) {
1331
+ std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
1332
+ }
1333
+ }
1334
+
1335
+ if (embd_size > 0) {
1336
+ for (uint32_t k = 0; k < n_embd; k++) {
1337
+ std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
1338
+ }
1339
+ }
1340
+ }
1341
+
1342
+ output_swaps.clear();
1343
+ }
1344
+
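For reference, a standalone sketch of the lazy reordering implemented above, using plain standard C++ containers instead of the context's buffers: the swaps recorded while sorting the output ids at decode time are replayed on the row-major logits buffer only when the data is actually accessed.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

struct swap_info { uint32_t i0, i1; };

int main() {
    const uint32_t n_vocab = 4;

    // two output rows, stored in the order they were computed
    std::vector<float> logits = { 0, 1, 2, 3,   10, 11, 12, 13 };

    // suppose restoring the user-provided batch order requires exchanging rows 0 and 1
    std::vector<swap_info> output_swaps = { { 0, 1 } };

    // lazy reorder: runs once, on the first access to the logits
    for (const auto & s : output_swaps) {
        for (uint32_t k = 0; k < n_vocab; ++k) {
            std::swap(logits[s.i0*n_vocab + k], logits[s.i1*n_vocab + k]);
        }
    }
    output_swaps.clear();

    std::printf("row 0 now starts with %.0f\n", logits[0]); // 10
    return 0;
}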
1345
  //
1346
  // graph
1347
  //
1348
 
1349
+ uint32_t llama_context::graph_max_nodes() const {
1350
+ return std::max<uint32_t>(1024u, 8u*model.n_tensors());
1351
  }
1352
 
1353
+ llm_graph_result * llama_context::get_gf_res_reserve() const {
1354
+ return static_cast<llm_graph_result *>(gf_res_reserve.get());
 
 
 
 
 
 
 
 
1355
  }
1356
 
1357
  ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
 
1364
  LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
1365
  }
1366
 
1367
+ ggml_backend_sched_reset(sched.get());
1368
+
1369
+ // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that
1370
+ gf_res_prev->reset();
1371
+
1372
  // store the n_outputs as it is, and restore it afterwards
1373
  // TODO: not sure if needed, might simplify in the future by removing this
1374
  const auto save_n_outputs = this->n_outputs;
 
1378
  llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
1379
  llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
1380
 
1381
+ auto * res = gf_res_reserve.get();
 
1382
 
1383
+ const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
1384
 
1385
+ res->reset();
 
 
 
1386
 
1387
+ auto * gf = model.build_graph(gparams);
1388
+
1389
+ this->n_outputs = save_n_outputs;
1390
 
1391
  // initialize scheduler with the specified graph
1392
  if (!ggml_backend_sched_reserve(sched.get(), gf)) {
 
1397
  return gf;
1398
  }
1399
 
1400
+ llm_graph_params llama_context::graph_params(
1401
+ llm_graph_result * res,
1402
+ const llama_ubatch & ubatch,
1403
+ const llama_memory_context_i * mctx,
1404
+ llm_graph_type gtype) const {
1405
+ return {
1406
+ /*.arch =*/ model.arch,
1407
+ /*.hparams =*/ model.hparams,
1408
+ /*.cparams =*/ cparams,
1409
+ /*.ubatch =*/ ubatch,
1410
+ /*.gtype =*/ gtype,
1411
+ /*.sched =*/ sched.get(),
1412
+ /*.backend_cpu =*/ backend_cpu,
1413
+ /*.cvec =*/ &cvec,
1414
+ /*.loras =*/ &loras,
1415
+ /*.mctx =*/ mctx,
1416
+ /*.cross =*/ &cross,
1417
+ /*.n_outputs =*/ n_outputs,
1418
+ /*.cb =*/ graph_get_cb(),
1419
+ /*.res =*/ res,
1420
+ };
 
1421
  }
1422
 
1423
  ggml_status llama_context::graph_compute(
 
1995
  data.t_eval_ms = 1e-3 * t_eval_us;
1996
  data.n_p_eval = std::max(1, n_p_eval);
1997
  data.n_eval = std::max(1, n_eval);
1998
+ data.n_reused = std::max(0, n_reused);
1999
 
2000
  return data;
2001
  }
 
2004
  t_start_us = ggml_time_us();
2005
  t_eval_us = n_eval = 0;
2006
  t_p_eval_us = n_p_eval = 0;
2007
+ n_reused = 0;
2008
  }
2009
 
2010
  //
 
2095
  batch.logits [pos_batch] = true;
2096
  }
2097
 
2098
+ if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
2099
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
2100
  return;
2101
  }
 
2131
  break;
2132
  }
2133
 
2134
+ auto * res = gf_res_prev.get();
2135
+
2136
+ const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
2137
+
2138
+ res->reset();
2139
+
2140
+ auto * gf = model.build_graph(gparams);
2141
 
2142
  struct ggml_context * ctx_compute_opt;
2143
  {
 
2259
  /*.no_perf =*/ true,
2260
  /*.op_offload =*/ true,
2261
  /*.swa_full =*/ true,
2262
+ /*.kv_unified =*/ false,
2263
  };
2264
 
2265
  return result;
 
2880
  LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
2881
  __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
2882
  LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
2883
+ LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
2884
  }
2885
 
2886
  void llama_perf_context_reset(llama_context * ctx) {
examples/talk-llama/llama-context.h CHANGED
@@ -35,8 +35,6 @@ struct llama_context {
35
 
36
  ggml_backend_sched_t get_sched() const;
37
 
38
- ggml_context * get_ctx_compute() const;
39
-
40
  uint32_t n_ctx() const;
41
  uint32_t n_ctx_per_seq() const;
42
  uint32_t n_batch() const;
@@ -96,7 +94,7 @@ struct llama_context {
96
  // if memory_context is provided, it will be applied first to the context's memory
97
  // ret contains the status of the graph computation
98
  // returns nullptr only if ret != GGML_STATUS_SUCCESS
99
- llm_graph_result_ptr process_ubatch(
100
  const llama_ubatch & ubatch,
101
  llm_graph_type gtype,
102
  llama_memory_context_i * mctx,
@@ -183,15 +181,17 @@ private:
183
  // Returns max number of outputs for which space was reserved.
184
  uint32_t output_reserve(int32_t n_outputs);
185
 
 
 
186
  //
187
  // graph
188
  //
189
 
190
  public:
191
- int32_t graph_max_nodes() const;
192
 
193
- // zero-out inputs and create the ctx_compute for the compute graph
194
- ggml_cgraph * graph_init();
195
 
196
  // returns the result of ggml_backend_sched_graph_compute_async execution
197
  ggml_status graph_compute(ggml_cgraph * gf, bool batched);
@@ -200,12 +200,11 @@ public:
200
  ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
201
 
202
  private:
203
- llm_graph_result_ptr graph_build(
204
- ggml_context * ctx,
205
- ggml_cgraph * gf,
206
- const llama_ubatch & ubatch,
207
- llm_graph_type gtype,
208
- const llama_memory_context_i * mctx);
209
 
210
  llm_graph_cb graph_get_cb() const;
211
 
@@ -253,13 +252,18 @@ private:
253
 
254
  std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
255
 
 
 
 
 
 
 
 
256
  ggml_backend_sched_ptr sched;
257
 
258
  ggml_backend_t backend_cpu = nullptr;
259
  std::vector<ggml_backend_ptr> backends;
260
 
261
- ggml_context_ptr ctx_compute;
262
-
263
  // training
264
  ggml_opt_context_t opt_ctx = nullptr;
265
 
@@ -275,14 +279,18 @@ private:
275
  std::vector<ggml_backend_t> backend_ptrs;
276
  std::vector<ggml_backend_buffer_type_t> backend_buft;
277
 
278
- // memory buffers used to evaluate the model
279
- std::vector<uint8_t> buf_compute_meta;
280
 
281
  // host buffer for the model output (logits and embeddings)
282
  ggml_backend_buffer_ptr buf_output;
283
 
284
  bool has_evaluated_once = false;
285
 
 
 
 
 
286
  // perf
287
  mutable int64_t t_start_us = 0;
288
  mutable int64_t t_load_us = 0;
@@ -294,4 +302,6 @@ private:
294
 
295
  mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
296
  mutable int32_t n_eval = 0; // number of eval calls
 
 
297
  };
 
35
 
36
  ggml_backend_sched_t get_sched() const;
37
 
 
 
38
  uint32_t n_ctx() const;
39
  uint32_t n_ctx_per_seq() const;
40
  uint32_t n_batch() const;
 
94
  // if memory_context is provided, it will be applied first to the context's memory
95
  // ret contains the status of the graph computation
96
  // returns nullptr only if ret != GGML_STATUS_SUCCESS
97
+ llm_graph_result * process_ubatch(
98
  const llama_ubatch & ubatch,
99
  llm_graph_type gtype,
100
  llama_memory_context_i * mctx,
 
181
  // Returns max number of outputs for which space was reserved.
182
  uint32_t output_reserve(int32_t n_outputs);
183
 
184
+ void output_reorder();
185
+
186
  //
187
  // graph
188
  //
189
 
190
  public:
191
+ uint32_t graph_max_nodes() const;
192
 
193
+ // can reuse the llm_graph_result instance of the context (for example to update a memory module)
194
+ llm_graph_result * get_gf_res_reserve() const;
195
 
196
  // returns the result of ggml_backend_sched_graph_compute_async execution
197
  ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
200
  ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
201
 
202
  private:
203
+ llm_graph_params graph_params(
204
+ llm_graph_result * res,
205
+ const llama_ubatch & ubatch,
206
+ const llama_memory_context_i * mctx,
207
+ llm_graph_type gtype) const;
 
208
 
209
  llm_graph_cb graph_get_cb() const;
210
 
 
252
 
253
  std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
254
 
255
+ struct swap_info {
256
+ uint32_t i0;
257
+ uint32_t i1;
258
+ };
259
+
260
+ std::vector<swap_info> output_swaps;
261
+
262
  ggml_backend_sched_ptr sched;
263
 
264
  ggml_backend_t backend_cpu = nullptr;
265
  std::vector<ggml_backend_ptr> backends;
266
 
 
 
267
  // training
268
  ggml_opt_context_t opt_ctx = nullptr;
269
 
 
279
  std::vector<ggml_backend_t> backend_ptrs;
280
  std::vector<ggml_backend_buffer_type_t> backend_buft;
281
 
282
+ llm_graph_result_ptr gf_res_prev;
283
+ llm_graph_result_ptr gf_res_reserve;
284
 
285
  // host buffer for the model output (logits and embeddings)
286
  ggml_backend_buffer_ptr buf_output;
287
 
288
  bool has_evaluated_once = false;
289
 
290
+ // env: LLAMA_SET_ROWS (temporary)
291
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14285
292
+ bool supports_set_rows = false;
293
+
294
  // perf
295
  mutable int64_t t_start_us = 0;
296
  mutable int64_t t_load_us = 0;
 
302
 
303
  mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
304
  mutable int32_t n_eval = 0; // number of eval calls
305
+
306
+ mutable int32_t n_reused = 0; // number of times the previous graph was reused
307
  };
examples/talk-llama/llama-cparams.h CHANGED
@@ -11,8 +11,8 @@ struct llama_cparams {
11
  uint32_t n_batch;
12
  uint32_t n_ubatch;
13
  uint32_t n_seq_max;
14
- int n_threads; // number of threads to use for generation
15
- int n_threads_batch; // number of threads to use for batch processing
16
 
17
  float rope_freq_base;
18
  float rope_freq_scale;
@@ -33,6 +33,7 @@ struct llama_cparams {
33
  bool no_perf;
34
  bool warmup;
35
  bool op_offload;
 
36
 
37
  enum llama_pooling_type pooling_type;
38
 
 
11
  uint32_t n_batch;
12
  uint32_t n_ubatch;
13
  uint32_t n_seq_max;
14
+ int32_t n_threads; // number of threads to use for generation
15
+ int32_t n_threads_batch; // number of threads to use for batch processing
16
 
17
  float rope_freq_base;
18
  float rope_freq_scale;
 
33
  bool no_perf;
34
  bool warmup;
35
  bool op_offload;
36
+ bool kv_unified;
37
 
38
  enum llama_pooling_type pooling_type;
39
 
examples/talk-llama/llama-graph.cpp CHANGED
@@ -28,6 +28,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
28
  }
29
  }
30
 
 
 
 
 
 
 
 
 
 
31
  void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
32
  if (ubatch->pos && pos) {
33
  const int64_t n_tokens = ubatch->n_tokens;
@@ -50,6 +59,14 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
50
  }
51
  }
52
 
 
 
 
 
 
 
 
 
53
  void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
54
  if (ubatch->pos && attn_scale) {
55
  const int64_t n_tokens = ubatch->n_tokens;
@@ -71,7 +88,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
71
  const int64_t n_tokens = ubatch->n_tokens;
72
 
73
  GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
74
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
75
 
76
  int32_t * data = (int32_t *) pos_bucket->data;
77
 
@@ -118,6 +135,14 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
118
  }
119
  }
120
 
 
 
 
 
 
 
 
 
121
  void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
122
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
123
  const int64_t n_tokens = ubatch->n_tokens;
@@ -287,6 +312,24 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
287
  mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
288
  }
289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
291
  mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
292
  mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
@@ -299,6 +342,30 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
299
  mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
300
  }
301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
303
  GGML_ASSERT(cross_kq_mask);
304
 
@@ -306,7 +373,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
306
  const int64_t n_tokens = ubatch->n_tokens;
307
 
308
  GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
309
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
310
 
311
  float * data = (float *) cross_kq_mask->data;
312
 
@@ -340,6 +407,91 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
340
  inp_rs->set_input(ubatch);
341
  }
342
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  //
344
  // llm_graph_context
345
  //
@@ -374,7 +526,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
374
  n_ctx_orig (cparams.n_ctx_orig_yarn),
375
  pooling_type (cparams.pooling_type),
376
  rope_type (hparams.rope_type),
377
- ctx0 (params.ctx),
378
  sched (params.sched),
379
  backend_cpu (params.backend_cpu),
380
  cvec (params.cvec),
@@ -382,7 +533,10 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
382
  mctx (params.mctx),
383
  cross (params.cross),
384
  cb_func (params.cb),
385
- res (std::make_unique<llm_graph_result>()) {
 
 
 
386
  }
387
 
388
  void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
@@ -753,20 +907,28 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
753
  cb(cur, "ffn_moe_weighted", il);
754
  }
755
 
 
 
 
 
 
 
 
 
 
 
 
756
  // aggregate experts
757
- ggml_tensor * moe_out = nullptr;
758
- for (int i = 0; i < n_expert_used; ++i) {
759
- ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
760
- experts->nb[2], i*experts->nb[1]);
761
 
762
- if (i == 0) {
763
- moe_out = cur_expert;
764
- } else {
765
- moe_out = ggml_add(ctx0, moe_out, cur_expert);
766
- }
767
  }
768
 
769
- if (n_expert_used == 1) {
770
  // avoid returning a non-contiguous tensor
771
  moe_out = ggml_cont(ctx0, moe_out);
772
  }
@@ -972,7 +1134,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
972
  }
973
 
974
  ggml_tensor * llm_graph_context::build_attn_mha(
975
- ggml_cgraph * gf,
976
  ggml_tensor * q,
977
  ggml_tensor * k,
978
  ggml_tensor * v,
@@ -982,13 +1143,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
982
  float kq_scale) const {
983
  const bool v_trans = v->nb[1] > v->nb[2];
984
 
 
 
 
 
 
985
  q = ggml_permute(ctx0, q, 0, 2, 1, 3);
986
  k = ggml_permute(ctx0, k, 0, 2, 1, 3);
987
  v = ggml_permute(ctx0, v, 0, 2, 1, 3);
988
 
989
- const auto n_tokens = q->ne[1];
990
- const auto n_head = q->ne[2];
991
- const auto n_kv = k->ne[1];
992
 
993
  ggml_tensor * cur;
994
 
@@ -1030,7 +1194,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1030
  #endif
1031
  }
1032
 
1033
- cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
1034
  } else {
1035
  ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
1036
 
@@ -1075,7 +1239,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1075
 
1076
  cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
1077
 
1078
- cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
 
1079
 
1080
  if (!cparams.offload_kqv) {
1081
  // all nodes between the KV store and the attention output are run on the CPU
@@ -1102,7 +1267,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
1102
 
1103
  ggml_tensor * llm_graph_context::build_attn(
1104
  llm_graph_input_attn_no_cache * inp,
1105
- ggml_cgraph * gf,
1106
  ggml_tensor * wo,
1107
  ggml_tensor * wo_b,
1108
  ggml_tensor * q_cur,
@@ -1122,11 +1286,15 @@ ggml_tensor * llm_graph_context::build_attn(
1122
 
1123
  const auto & kq_mask = inp->get_kq_mask();
1124
 
 
 
 
 
1125
  ggml_tensor * q = q_cur;
1126
  ggml_tensor * k = k_cur;
1127
  ggml_tensor * v = v_cur;
1128
 
1129
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1130
  cb(cur, "kqv_out", il);
1131
 
1132
  if (wo) {
@@ -1156,13 +1324,14 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
1156
  {
1157
  GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
1158
 
1159
- const auto n_kv = mctx_cur->get_n_kv();
1160
  const auto n_tokens = ubatch.n_tokens;
 
1161
 
1162
  inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
1163
  inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
1164
 
1165
- inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1166
  ggml_set_input(inp->self_kq_mask);
1167
 
1168
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1181,7 +1350,6 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
1181
 
1182
  ggml_tensor * llm_graph_context::build_attn(
1183
  llm_graph_input_attn_kv_unified * inp,
1184
- ggml_cgraph * gf,
1185
  ggml_tensor * wo,
1186
  ggml_tensor * wo_b,
1187
  ggml_tensor * q_cur,
@@ -1214,7 +1382,7 @@ ggml_tensor * llm_graph_context::build_attn(
1214
  ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1215
  ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1216
 
1217
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1218
  cb(cur, "kqv_out", il);
1219
 
1220
  if (wo) {
@@ -1234,7 +1402,6 @@ ggml_tensor * llm_graph_context::build_attn(
1234
 
1235
  ggml_tensor * llm_graph_context::build_attn(
1236
  llm_graph_input_attn_kv_unified_iswa * inp,
1237
- ggml_cgraph * gf,
1238
  ggml_tensor * wo,
1239
  ggml_tensor * wo_b,
1240
  ggml_tensor * q_cur,
@@ -1281,7 +1448,7 @@ ggml_tensor * llm_graph_context::build_attn(
1281
  ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1282
  ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1283
 
1284
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1285
  cb(cur, "kqv_out", il);
1286
 
1287
  if (wo) {
@@ -1314,7 +1481,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
1314
 
1315
  ggml_tensor * llm_graph_context::build_attn(
1316
  llm_graph_input_attn_cross * inp,
1317
- ggml_cgraph * gf,
1318
  ggml_tensor * wo,
1319
  ggml_tensor * wo_b,
1320
  ggml_tensor * q_cur,
@@ -1336,7 +1502,7 @@ ggml_tensor * llm_graph_context::build_attn(
1336
  ggml_tensor * k = k_cur;
1337
  ggml_tensor * v = v_cur;
1338
 
1339
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1340
  cb(cur, "kqv_out", il);
1341
 
1342
  if (wo) {
@@ -1362,13 +1528,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1362
 
1363
  auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
1364
 
 
 
1365
  {
1366
  const auto n_kv = mctx_cur->get_base()->get_n_kv();
1367
 
1368
  inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
1369
  inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
1370
 
1371
- inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1372
  ggml_set_input(inp->self_kq_mask);
1373
 
1374
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1382,7 +1550,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1382
  inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
1383
  inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
1384
 
1385
- inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1386
  ggml_set_input(inp->self_kq_mask_swa);
1387
 
1388
  inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@@ -1392,7 +1560,6 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1392
  }
1393
 
1394
  ggml_tensor * llm_graph_context::build_rs(
1395
- ggml_cgraph * gf,
1396
  ggml_tensor * s,
1397
  ggml_tensor * state_copy,
1398
  int32_t state_size,
@@ -1450,21 +1617,19 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
1450
 
1451
  ggml_tensor * llm_graph_context::build_rs(
1452
  llm_graph_input_rs * inp,
1453
- ggml_cgraph * gf,
1454
  ggml_tensor * s,
1455
  int32_t state_size,
1456
  int32_t n_seqs,
1457
  const llm_graph_get_rows_fn & get_state_rows) const {
1458
  const auto * kv_state = inp->mctx;
1459
 
1460
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
1461
  }
1462
 
1463
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1464
  llm_graph_input_rs * inp,
1465
- ggml_cgraph * gf,
1466
  const llama_ubatch & ubatch,
1467
- int il) const {
1468
  const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1469
 
1470
  const auto token_shift_count = hparams.token_shift_count;
@@ -1474,7 +1639,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1474
  ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
1475
 
1476
  ggml_tensor * token_shift = build_rs(
1477
- inp, gf, token_shift_all,
1478
  hparams.n_embd_r(), n_seqs);
1479
 
1480
  token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
@@ -1514,7 +1679,6 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
1514
  }
1515
 
1516
  void llm_graph_context::build_pooling(
1517
- ggml_cgraph * gf,
1518
  ggml_tensor * cls,
1519
  ggml_tensor * cls_b,
1520
  ggml_tensor * cls_out,
 
28
  }
29
  }
30
 
31
+ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
32
+ bool res = true;
33
+
34
+ res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
35
+ res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
36
+
37
+ return res;
38
+ }
39
+
40
  void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
41
  if (ubatch->pos && pos) {
42
  const int64_t n_tokens = ubatch->n_tokens;
 
59
  }
60
  }
61
 
62
+ bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
63
+ bool res = true;
64
+
65
+ res &= pos->ne[0] == params.ubatch.n_tokens;
66
+
67
+ return res;
68
+ }
69
+
70
  void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
71
  if (ubatch->pos && attn_scale) {
72
  const int64_t n_tokens = ubatch->n_tokens;
 
88
  const int64_t n_tokens = ubatch->n_tokens;
89
 
90
  GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
91
+ GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
92
 
93
  int32_t * data = (int32_t *) pos_bucket->data;
94
 
 
135
  }
136
  }
137
 
138
+ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
139
+ bool res = true;
140
+
141
+ res &= n_outputs == params.n_outputs;
142
+
143
+ return res;
144
+ }
145
+
146
  void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
147
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
148
  const int64_t n_tokens = ubatch->n_tokens;
 
312
  mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
313
  }
314
 
315
+ bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
316
+ const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
317
+
318
+ this->mctx = mctx;
319
+
320
+ bool res = true;
321
+
322
+ res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
323
+ //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
324
+
325
+ res &= self_kq_mask->ne[0] == mctx->get_n_kv();
326
+ res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
327
+
328
+ res &= mctx->get_supports_set_rows(); // TODO: tmp
329
+
330
+ return res;
331
+ }
332
+
333
  void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
334
  mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
335
  mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
 
342
  mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
343
  }
344
 
345
+ bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
346
+ const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
347
+
348
+ this->mctx = mctx;
349
+
350
+ bool res = true;
351
+
352
+ res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
353
+ //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
354
+
355
+ res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
356
+ //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
357
+
358
+ res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
359
+ res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
360
+
361
+ res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
362
+ res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
363
+
364
+ res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp
365
+
366
+ return res;
367
+ }
368
+
369
  void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
370
  GGML_ASSERT(cross_kq_mask);
371
 
 
373
  const int64_t n_tokens = ubatch->n_tokens;
374
 
375
  GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
376
+ GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
377
 
378
  float * data = (float *) cross_kq_mask->data;
379
 
 
407
  inp_rs->set_input(ubatch);
408
  }
409
 
410
+ //
411
+ // llm_graph_result
412
+ //
413
+
414
+ llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
415
+ reset();
416
+
417
+ const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
418
+ debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
419
+ }
420
+
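A hypothetical usage sketch for the debug switch read by the constructor above: judging from the checks further down in this file, a value of 1 logs the overall reuse decision and 2 additionally logs the per-input compatibility checks. Exporting the variable in the shell before launching works just as well; setenv is POSIX-only.

#include <cstdlib>

int main() {
    // "1" -> log the overall can-reuse decision, "2" -> also log the per-input checks
    setenv("LLAMA_GRAPH_RESULT_DEBUG", "2", /*overwrite =*/ 1);

    // ... create the llama context afterwards so that the constructor picks the value up ...
    return 0;
}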
421
+ int64_t llm_graph_result::get_max_nodes() const {
422
+ return max_nodes;
423
+ }
424
+
425
+ void llm_graph_result::reset() {
426
+ t_tokens = nullptr;
427
+ t_logits = nullptr;
428
+ t_embd = nullptr;
429
+ t_embd_pooled = nullptr;
430
+
431
+ params = {};
432
+
433
+ inputs.clear();
434
+
435
+ buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
436
+
437
+ ggml_init_params params = {
438
+ /*.mem_size =*/ buf_compute_meta.size(),
439
+ /*.mem_buffer =*/ buf_compute_meta.data(),
440
+ /*.no_alloc =*/ true,
441
+ };
442
+
443
+ ctx_compute.reset(ggml_init(params));
444
+
445
+ gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
446
+ }
447
+
448
+ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
449
+ for (auto & input : inputs) {
450
+ input->set_input(ubatch);
451
+ }
452
+ }
453
+
454
+ bool llm_graph_result::can_reuse(const llm_graph_params & params) {
455
+ if (!this->params.allow_reuse(params)) {
456
+ if (debug > 1) {
457
+ LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
458
+ }
459
+
460
+ return false;
461
+ }
462
+
463
+ if (debug > 1) {
464
+ LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
465
+ }
466
+
467
+ bool res = true;
468
+
469
+ for (auto & input : inputs) {
470
+ const bool cur = input->can_reuse(params);
471
+
472
+ if (debug > 1) {
473
+ LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
474
+ }
475
+
476
+ res = res && cur;
477
+ }
478
+
479
+ if (debug > 0) {
480
+ LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
481
+ }
482
+
483
+ return res;
484
+ }
485
+
486
+ llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
487
+ inputs.emplace_back(std::move(input));
488
+ return inputs.back().get();
489
+ }
490
+
491
+ void llm_graph_result::set_params(const llm_graph_params & params) {
492
+ this->params = params;
493
+ }
494
+
495
  //
496
  // llm_graph_context
497
  //
 
526
  n_ctx_orig (cparams.n_ctx_orig_yarn),
527
  pooling_type (cparams.pooling_type),
528
  rope_type (hparams.rope_type),
 
529
  sched (params.sched),
530
  backend_cpu (params.backend_cpu),
531
  cvec (params.cvec),
 
533
  mctx (params.mctx),
534
  cross (params.cross),
535
  cb_func (params.cb),
536
+ res (params.res),
537
+ ctx0 (res->get_ctx()),
538
+ gf (res->get_gf()) {
539
+ res->set_params(params);
540
  }
541
 
542
  void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
 
907
  cb(cur, "ffn_moe_weighted", il);
908
  }
909
 
910
+ ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
911
+
912
+ assert(n_expert_used > 0);
913
+
914
+ // order the views before the adds
915
+ for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
916
+ cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
917
+
918
+ ggml_build_forward_expand(gf, cur_experts[i]);
919
+ }
920
+
921
  // aggregate experts
922
+ // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
923
+ // to avoid a potentially large number of add nodes during warmup
924
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14753
925
+ ggml_tensor * moe_out = cur_experts[0];
926
 
927
+ for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
928
+ moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
 
 
 
929
  }
930
 
931
+ if (hparams.n_expert_used == 1) {
932
  // avoid returning a non-contiguous tensor
933
  moe_out = ggml_cont(ctx0, moe_out);
934
  }
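A standalone illustration (plain C++, made-up values) of why summing a fixed hparams-derived number of expert slices helps: the number of add operations, and therefore the number of graph nodes, is the same for every ubatch, which keeps the aggregation from changing the graph topology between calls.

#include <cstdio>
#include <vector>

int main() {
    const int n_expert_used = 2;   // fixed by the model hyperparameters
    const int n_embd        = 4;

    // pretend outputs of the two selected experts for a single token
    std::vector<std::vector<float>> experts = {
        { 1, 1, 1, 1 },
        { 2, 2, 2, 2 },
    };

    std::vector<float> moe_out = experts[0];
    for (int i = 1; i < n_expert_used; ++i) {        // always exactly n_expert_used - 1 adds
        for (int k = 0; k < n_embd; ++k) {
            moe_out[k] += experts[i][k];
        }
    }

    std::printf("moe_out[0] = %.0f\n", moe_out[0]);  // 3
    return 0;
}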
 
1134
  }
1135
 
1136
  ggml_tensor * llm_graph_context::build_attn_mha(
 
1137
  ggml_tensor * q,
1138
  ggml_tensor * k,
1139
  ggml_tensor * v,
 
1143
  float kq_scale) const {
1144
  const bool v_trans = v->nb[1] > v->nb[2];
1145
 
1146
+ // split the batch into streams if needed
1147
+ const auto n_stream = k->ne[3];
1148
+
1149
+ q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
1150
+
1151
  q = ggml_permute(ctx0, q, 0, 2, 1, 3);
1152
  k = ggml_permute(ctx0, k, 0, 2, 1, 3);
1153
  v = ggml_permute(ctx0, v, 0, 2, 1, 3);
1154
 
1155
+ const auto n_kv = k->ne[1];
 
 
1156
 
1157
  ggml_tensor * cur;
1158
 
 
1194
  #endif
1195
  }
1196
 
1197
+ cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
1198
  } else {
1199
  ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
1200
 
 
1239
 
1240
  cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
1241
 
1242
+ // recombine streams
1243
+ cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
1244
 
1245
  if (!cparams.offload_kqv) {
1246
  // all nodes between the KV store and the attention output are run on the CPU
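A standalone sketch of the stream split/recombine shape arithmetic above, with made-up head/token/stream counts and no ggml calls: the query is viewed as [d_head, n_head, n_tokens/n_stream, n_stream] before the permute, and after attention the result is flattened back to [d_head*n_head, n_tokens] across all streams.

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t d_head = 128, n_head = 8, n_tokens = 16;
    const int64_t n_stream = 2; // e.g. two sequences, each with its own KV stream

    // split: group the tokens of each stream into their own slice
    const int64_t q_ne[4] = { d_head, n_head, n_tokens / n_stream, n_stream };

    // recombine after attention: rows = d_head*n_head, cols = tokens across all streams
    const int64_t out_ne[2] = { q_ne[0] * q_ne[1], q_ne[2] * q_ne[3] };

    std::printf("q   : [%lld, %lld, %lld, %lld]\n",
                (long long) q_ne[0], (long long) q_ne[1], (long long) q_ne[2], (long long) q_ne[3]);
    std::printf("out : [%lld, %lld]\n", (long long) out_ne[0], (long long) out_ne[1]);
    return 0;
}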
 
1267
 
1268
  ggml_tensor * llm_graph_context::build_attn(
1269
  llm_graph_input_attn_no_cache * inp,
 
1270
  ggml_tensor * wo,
1271
  ggml_tensor * wo_b,
1272
  ggml_tensor * q_cur,
 
1286
 
1287
  const auto & kq_mask = inp->get_kq_mask();
1288
 
1289
+ // [TAG_NO_CACHE_PAD]
1290
+ // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
1291
+ assert(!ubatch.equal_seqs());
1292
+
1293
  ggml_tensor * q = q_cur;
1294
  ggml_tensor * k = k_cur;
1295
  ggml_tensor * v = v_cur;
1296
 
1297
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1298
  cb(cur, "kqv_out", il);
1299
 
1300
  if (wo) {
 
1324
  {
1325
  GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
1326
 
1327
+ const auto n_kv = mctx_cur->get_n_kv();
1328
  const auto n_tokens = ubatch.n_tokens;
1329
+ const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1330
 
1331
  inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
1332
  inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
1333
 
1334
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
1335
  ggml_set_input(inp->self_kq_mask);
1336
 
1337
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
 
1350
 
1351
  ggml_tensor * llm_graph_context::build_attn(
1352
  llm_graph_input_attn_kv_unified * inp,
 
1353
  ggml_tensor * wo,
1354
  ggml_tensor * wo_b,
1355
  ggml_tensor * q_cur,
 
1382
  ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1383
  ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1384
 
1385
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1386
  cb(cur, "kqv_out", il);
1387
 
1388
  if (wo) {
 
1402
 
1403
  ggml_tensor * llm_graph_context::build_attn(
1404
  llm_graph_input_attn_kv_unified_iswa * inp,
 
1405
  ggml_tensor * wo,
1406
  ggml_tensor * wo_b,
1407
  ggml_tensor * q_cur,
 
1448
  ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1449
  ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1450
 
1451
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1452
  cb(cur, "kqv_out", il);
1453
 
1454
  if (wo) {
 
1481
 
1482
  ggml_tensor * llm_graph_context::build_attn(
1483
  llm_graph_input_attn_cross * inp,
 
1484
  ggml_tensor * wo,
1485
  ggml_tensor * wo_b,
1486
  ggml_tensor * q_cur,
 
1502
  ggml_tensor * k = k_cur;
1503
  ggml_tensor * v = v_cur;
1504
 
1505
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1506
  cb(cur, "kqv_out", il);
1507
 
1508
  if (wo) {
 
1528
 
1529
  auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
1530
 
1531
+ const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1532
+
1533
  {
1534
  const auto n_kv = mctx_cur->get_base()->get_n_kv();
1535
 
1536
  inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
1537
  inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
1538
 
1539
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
1540
  ggml_set_input(inp->self_kq_mask);
1541
 
1542
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
 
1550
  inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
1551
  inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
1552
 
1553
+ inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
1554
  ggml_set_input(inp->self_kq_mask_swa);
1555
 
1556
  inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
 
1560
  }
1561
 
1562
  ggml_tensor * llm_graph_context::build_rs(
 
1563
  ggml_tensor * s,
1564
  ggml_tensor * state_copy,
1565
  int32_t state_size,
 
1617
 
1618
  ggml_tensor * llm_graph_context::build_rs(
1619
  llm_graph_input_rs * inp,
 
1620
  ggml_tensor * s,
1621
  int32_t state_size,
1622
  int32_t n_seqs,
1623
  const llm_graph_get_rows_fn & get_state_rows) const {
1624
  const auto * kv_state = inp->mctx;
1625
 
1626
+ return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
1627
  }
1628
 
1629
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1630
  llm_graph_input_rs * inp,
 
1631
  const llama_ubatch & ubatch,
1632
+ int il) const {
1633
  const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1634
 
1635
  const auto token_shift_count = hparams.token_shift_count;
 
1639
  ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
1640
 
1641
  ggml_tensor * token_shift = build_rs(
1642
+ inp, token_shift_all,
1643
  hparams.n_embd_r(), n_seqs);
1644
 
1645
  token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
 
1679
  }
1680
 
1681
  void llm_graph_context::build_pooling(
 
1682
  ggml_tensor * cls,
1683
  ggml_tensor * cls_b,
1684
  ggml_tensor * cls_out,
examples/talk-llama/llama-graph.h CHANGED
@@ -1,6 +1,7 @@
1
  #pragma once
2
 
3
  #include "llama-arch.h"
 
4
  #include "llama-hparams.h"
5
  #include "llama-adapter.h"
6
 
@@ -14,7 +15,6 @@ struct ggml_cgraph;
14
  struct ggml_context;
15
  struct ggml_tensor;
16
 
17
- struct llama_ubatch;
18
  struct llama_cparams;
19
 
20
  struct llama_memory_context_i;
@@ -69,6 +69,8 @@ struct llama_cross {
69
  std::vector<std::set<llama_seq_id>> seq_ids_enc;
70
  };
71
 
 
 
72
  //
73
  // llm_graph_input
74
  //
@@ -78,11 +80,19 @@ public:
78
  virtual ~llm_graph_input_i() = default;
79
 
80
  virtual void set_input(const llama_ubatch * ubatch) = 0;
 
 
 
 
 
 
 
 
 
81
  };
82
 
83
  using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
84
 
85
-
86
  class llm_graph_input_embd : public llm_graph_input_i {
87
  public:
88
  llm_graph_input_embd() = default;
@@ -90,6 +100,8 @@ public:
90
 
91
  void set_input(const llama_ubatch * ubatch) override;
92
 
 
 
93
  ggml_tensor * tokens = nullptr; // I32 [n_batch]
94
  ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
95
  };
@@ -101,6 +113,8 @@ public:
101
 
102
  void set_input(const llama_ubatch * ubatch) override;
103
 
 
 
104
  ggml_tensor * pos = nullptr; // I32 [n_batch]
105
 
106
  const uint32_t n_pos_per_embd = 1;
@@ -154,17 +168,19 @@ public:
154
  llm_graph_input_out_ids(
155
  const llama_hparams & hparams,
156
  const llama_cparams & cparams,
157
- int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
158
  virtual ~llm_graph_input_out_ids() = default;
159
 
160
  void set_input(const llama_ubatch * ubatch) override;
161
 
 
 
162
  ggml_tensor * out_ids; // I32 [n_outputs]
163
 
164
  const llama_hparams & hparams;
165
  const llama_cparams & cparams;
166
 
167
- const int32_t n_outputs;
168
  };
169
 
170
  class llm_graph_input_mean : public llm_graph_input_i {
@@ -249,16 +265,18 @@ public:
249
 
250
  void set_input(const llama_ubatch * ubatch) override;
251
 
 
 
252
  ggml_tensor * get_k_idxs() const { return self_k_idxs; }
253
  ggml_tensor * get_v_idxs() const { return self_v_idxs; }
254
 
255
  ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
256
 
257
  ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
258
- ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
259
 
260
- ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
261
- ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
262
 
263
  const llama_hparams & hparams;
264
  const llama_cparams & cparams;
@@ -280,6 +298,8 @@ public:
280
 
281
  void set_input(const llama_ubatch * ubatch) override;
282
 
 
 
283
  ggml_tensor * get_k_idxs() const { return self_k_idxs; }
284
  ggml_tensor * get_v_idxs() const { return self_v_idxs; }
285
  ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
@@ -289,14 +309,14 @@ public:
289
  ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
290
 
291
  ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
292
- ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
293
  ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
294
- ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
295
 
296
- ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
297
- ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
298
- ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
299
- ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]
300
 
301
  const llama_hparams & hparams;
302
  const llama_cparams & cparams;
@@ -351,40 +371,108 @@ public:
351
  // along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
352
  // these are used by the llama_context to extract the relevant data, based on the compute parameters
353
 
354
- class llm_graph_result_i {
355
- public:
356
- virtual ~llm_graph_result_i() = default;
357
 
358
- virtual ggml_tensor * get_tokens() = 0;
359
- virtual ggml_tensor * get_logits() = 0;
360
- virtual ggml_tensor * get_embd() = 0;
361
- virtual ggml_tensor * get_embd_pooled() = 0;
362
 
363
- virtual void set_inputs(const llama_ubatch * ubatch) = 0;
364
- };
365
 
366
- using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
 
367
 
 
368
 
369
- class llm_graph_result : public llm_graph_result_i {
370
- public:
371
- virtual ~llm_graph_result() = default;
372
 
373
- ggml_tensor * get_tokens() override { return t_tokens; }
374
- ggml_tensor * get_logits() override { return t_logits; }
375
- ggml_tensor * get_embd() override { return t_embd; }
376
- ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
377
 
378
- void set_inputs(const llama_ubatch * ubatch) override {
379
- for (auto & input : inputs) {
380
- input->set_input(ubatch);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  }
382
- }
383
 
384
- llm_graph_input_i * add_input(llm_graph_input_ptr input) {
385
- inputs.emplace_back(std::move(input));
386
- return inputs.back().get();
 
 
 
 
 
 
 
 
 
 
387
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
 
389
  // important graph nodes
390
  ggml_tensor * t_tokens = nullptr;
@@ -393,36 +481,31 @@ public:
393
  ggml_tensor * t_embd_pooled = nullptr;
394
 
395
  std::vector<llm_graph_input_ptr> inputs;
396
- };
397
 
398
- //
399
- // llm_graph_context
400
- //
401
 
402
- // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
403
- using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
404
 
405
- struct llm_graph_params {
406
- ggml_context * ctx;
407
 
408
- const llm_arch arch;
409
 
410
- const llama_hparams & hparams;
411
- const llama_cparams & cparams;
412
- const llama_ubatch & ubatch;
 
 
413
 
414
- ggml_backend_sched_t sched;
415
- ggml_backend_t backend_cpu;
416
-
417
- const llama_adapter_cvec * cvec;
418
- const llama_adapter_loras * loras;
419
- const llama_memory_context_i * mctx;
420
- const llama_cross * cross;
421
 
422
- uint32_t n_outputs;
423
 
424
- const llm_graph_cb & cb;
425
- };
 
426
 
427
  // used in build_rs to properly order writes and avoid unnecessary copies
428
  using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
@@ -463,8 +546,6 @@ struct llm_graph_context {
463
  const enum llama_pooling_type pooling_type;
464
  const enum llama_rope_type rope_type;
465
 
466
- ggml_context * ctx0 = nullptr;
467
-
468
  ggml_backend_sched_t sched;
469
 
470
  ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
@@ -476,7 +557,10 @@ struct llm_graph_context {
476
 
477
  const llm_graph_cb & cb_func;
478
 
479
- std::unique_ptr<llm_graph_result> res;
 
 
 
480
 
481
  llm_graph_context(const llm_graph_params & params);
482
  virtual ~llm_graph_context() = default;
@@ -562,7 +646,6 @@ struct llm_graph_context {
562
  //
563
 
564
  ggml_tensor * build_attn_mha(
565
- ggml_cgraph * gf,
566
  ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
567
  ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
568
  ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
@@ -575,7 +658,6 @@ struct llm_graph_context {
575
 
576
  ggml_tensor * build_attn(
577
  llm_graph_input_attn_no_cache * inp,
578
- ggml_cgraph * gf,
579
  ggml_tensor * wo,
580
  ggml_tensor * wo_b,
581
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -590,7 +672,6 @@ struct llm_graph_context {
590
 
591
  ggml_tensor * build_attn(
592
  llm_graph_input_attn_kv_unified * inp,
593
- ggml_cgraph * gf,
594
  ggml_tensor * wo,
595
  ggml_tensor * wo_b,
596
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -606,7 +687,6 @@ struct llm_graph_context {
606
  // note: if k_cur or v_cur are not provided, they will not be stored in the memory
607
  ggml_tensor * build_attn(
608
  llm_graph_input_attn_kv_unified_iswa * inp,
609
- ggml_cgraph * gf,
610
  ggml_tensor * wo,
611
  ggml_tensor * wo_b,
612
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -621,7 +701,6 @@ struct llm_graph_context {
621
 
622
  ggml_tensor * build_attn(
623
  llm_graph_input_attn_cross * inp,
624
- ggml_cgraph * gf,
625
  ggml_tensor * wo,
626
  ggml_tensor * wo_b,
627
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -643,7 +722,6 @@ struct llm_graph_context {
643
  // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
644
  // `llama_memory_recurrent`
645
  ggml_tensor * build_rs(
646
- ggml_cgraph * gf,
647
  ggml_tensor * s,
648
  ggml_tensor * state_copy,
649
  int32_t state_size,
@@ -658,7 +736,6 @@ struct llm_graph_context {
658
 
659
  ggml_tensor * build_rs(
660
  llm_graph_input_rs * inp,
661
- ggml_cgraph * gf,
662
  ggml_tensor * s,
663
  int32_t state_size,
664
  int32_t n_seqs,
@@ -666,9 +743,8 @@ struct llm_graph_context {
666
 
667
  ggml_tensor * build_rwkv_token_shift_load(
668
  llm_graph_input_rs * inp,
669
- ggml_cgraph * gf,
670
  const llama_ubatch & ubatch,
671
- int il) const;
672
 
673
  ggml_tensor * build_rwkv_token_shift_store(
674
  ggml_tensor * token_shift,
@@ -685,7 +761,6 @@ struct llm_graph_context {
685
  //
686
 
687
  void build_pooling(
688
- ggml_cgraph * gf,
689
  ggml_tensor * cls,
690
  ggml_tensor * cls_b,
691
  ggml_tensor * cls_out,
 
1
  #pragma once
2
 
3
  #include "llama-arch.h"
4
+ #include "llama-batch.h"
5
  #include "llama-hparams.h"
6
  #include "llama-adapter.h"
7
 
 
15
  struct ggml_context;
16
  struct ggml_tensor;
17
 
 
18
  struct llama_cparams;
19
 
20
  struct llama_memory_context_i;
 
69
  std::vector<std::set<llama_seq_id>> seq_ids_enc;
70
  };
71
 
72
+ struct llm_graph_params;
73
+
74
  //
75
  // llm_graph_input
76
  //
 
80
  virtual ~llm_graph_input_i() = default;
81
 
82
  virtual void set_input(const llama_ubatch * ubatch) = 0;
83
+
84
+ // return true if the resulting input tensors using the provided graph parameters would be
85
+ // the same as the previous input tensors that we have currently stored in the object
86
+ virtual bool can_reuse(const llm_graph_params & params) {
87
+ // returning false here by default will prevent reusing the graph if the check
88
+ // for the input type has not been implemented yet
89
+ GGML_UNUSED(params);
90
+ return false;
91
+ }
92
  };
93
 
94
  using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
95
 
 
96
  class llm_graph_input_embd : public llm_graph_input_i {
97
  public:
98
  llm_graph_input_embd() = default;
 
100
 
101
  void set_input(const llama_ubatch * ubatch) override;
102
 
103
+ bool can_reuse(const llm_graph_params & params) override;
104
+
105
  ggml_tensor * tokens = nullptr; // I32 [n_batch]
106
  ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
107
  };
 
113
 
114
  void set_input(const llama_ubatch * ubatch) override;
115
 
116
+ bool can_reuse(const llm_graph_params & params) override;
117
+
118
  ggml_tensor * pos = nullptr; // I32 [n_batch]
119
 
120
  const uint32_t n_pos_per_embd = 1;
 
168
  llm_graph_input_out_ids(
169
  const llama_hparams & hparams,
170
  const llama_cparams & cparams,
171
+ uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
172
  virtual ~llm_graph_input_out_ids() = default;
173
 
174
  void set_input(const llama_ubatch * ubatch) override;
175
 
176
+ bool can_reuse(const llm_graph_params & params) override;
177
+
178
  ggml_tensor * out_ids; // I32 [n_outputs]
179
 
180
  const llama_hparams & hparams;
181
  const llama_cparams & cparams;
182
 
183
+ const uint32_t n_outputs;
184
  };
185
 
186
  class llm_graph_input_mean : public llm_graph_input_i {
 
265
 
266
  void set_input(const llama_ubatch * ubatch) override;
267
 
268
+ bool can_reuse(const llm_graph_params & params) override;
269
+
270
  ggml_tensor * get_k_idxs() const { return self_k_idxs; }
271
  ggml_tensor * get_v_idxs() const { return self_v_idxs; }
272
 
273
  ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
274
 
275
  ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
276
+ ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
277
 
278
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
279
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
280
 
281
  const llama_hparams & hparams;
282
  const llama_cparams & cparams;
 
298
 
299
  void set_input(const llama_ubatch * ubatch) override;
300
 
301
+ bool can_reuse(const llm_graph_params & params) override;
302
+
303
  ggml_tensor * get_k_idxs() const { return self_k_idxs; }
304
  ggml_tensor * get_v_idxs() const { return self_v_idxs; }
305
  ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
 
309
  ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
310
 
311
  ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
312
+ ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
313
  ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
314
+ ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
315
 
316
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
317
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
318
+ ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
319
+ ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
320
 
321
  const llama_hparams & hparams;
322
  const llama_cparams & cparams;
 
371
  // along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
372
  // these are used by the llama_context to extract the relevant data, based on the compute parameters
373
 
374
+ // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
375
+ using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
 
376
 
377
+ class llm_graph_result;
 
 
 
378
 
379
+ struct llm_graph_params {
380
+ llm_arch arch = LLM_ARCH_UNKNOWN;
381
 
382
+ llama_hparams hparams;
383
+ llama_cparams cparams;
384
 
385
+ llama_ubatch ubatch; // note: intentionally make a copy
386
 
387
+ llm_graph_type gtype;
 
 
388
 
389
+ ggml_backend_sched_t sched;
390
+ ggml_backend_t backend_cpu;
 
 
391
 
392
+ const llama_adapter_cvec * cvec;
393
+ const llama_adapter_loras * loras;
394
+ const llama_memory_context_i * mctx;
395
+ const llama_cross * cross;
396
+
397
+ uint32_t n_outputs;
398
+
399
+ llm_graph_cb cb;
400
+
401
+ llm_graph_result * res;
402
+
403
+ // return true if the "other" params would result in a graph with the same topology as the current params
404
+ // having the same topology allows us to reuse the graph in some cases
405
+ bool allow_reuse(const llm_graph_params & other) const {
406
+ // first check the ubatch
407
+ bool can_reuse_ubatch =
408
+ ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
409
+ ubatch.n_tokens == other.ubatch.n_tokens &&
410
+ ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
411
+ ubatch.n_seqs == other.ubatch.n_seqs &&
412
+ ubatch.n_seqs_unq == other.ubatch.n_seqs_unq &&
413
+ (
414
+ (!ubatch.token && !other.ubatch.token) ||
415
+ (!ubatch.embd && !other.ubatch.embd)
416
+ );
417
+
418
+ if (can_reuse_ubatch && !ubatch.equal_seqs()) {
419
+ if (!ubatch.data) {
420
+ // if the old ubatch does not own its data, then we cannot guarantee that it is still alive, and
421
+ // therefore we cannot perform the sequence id check. this should normally never happen
422
+ can_reuse_ubatch = false;
423
+ } else {
424
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
425
+ can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
426
+ }
427
+ }
428
  }
 
429
 
430
+ if (!can_reuse_ubatch) {
431
+ return false;
432
+ }
433
+
434
+ return
435
+ cparams.embeddings == other.cparams.embeddings &&
436
+ cparams.causal_attn == other.cparams.causal_attn &&
437
+ arch == other.arch &&
438
+ gtype == other.gtype &&
439
+ cvec == other.cvec &&
440
+ loras == other.loras &&
441
+ cross == other.cross &&
442
+ n_outputs == other.n_outputs;
443
  }
444
+ };
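Note (illustration only): the intended call pattern is that the parameters of the previously built graph gate reuse before any per-input checks run. A rough sketch with hypothetical variable names:

    // params_prv: parameters used to build the existing graph
    // params_cur: parameters for the upcoming computation
    if (params_prv.allow_reuse(params_cur)) {
        // same topology -> the existing graph may be reusable,
        // pending the per-input can_reuse() checks
    } else {
        // different topology (e.g. different n_tokens or n_outputs) -> rebuild the graph
    }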
445
+
446
+ class llm_graph_result {
447
+ public:
448
+ llm_graph_result(int64_t max_nodes);
449
+
450
+ virtual ~llm_graph_result() = default;
451
+
452
+ ggml_tensor * get_tokens() const { return t_tokens; }
453
+ ggml_tensor * get_logits() const { return t_logits; }
454
+ ggml_tensor * get_embd() const { return t_embd; }
455
+ ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
456
+
457
+ ggml_cgraph * get_gf() const { return gf; }
458
+ ggml_context * get_ctx() const { return ctx_compute.get(); }
459
+
460
+ int64_t get_max_nodes() const;
461
+
462
+ void reset();
463
+
464
+ void set_inputs(const llama_ubatch * ubatch);
465
+
466
+ // try to update the existing graph result using the new graph parameters in order to reuse it
467
+ // this can only be done if we determine that the resulting graph using the new graph parameters
468
+ // would be identical to the existing graph. in that case, we simply have to update the memory
469
+ // contexts of the input tensors of the graph and we can reuse it for another computation
470
+ // return true if the graph was updated and can be reused
471
+ bool can_reuse(const llm_graph_params & params);
472
+
473
+ llm_graph_input_i * add_input(llm_graph_input_ptr input);
474
+
475
+ void set_params(const llm_graph_params & params);
476
 
477
  // important graph nodes
478
  ggml_tensor * t_tokens = nullptr;
 
481
  ggml_tensor * t_embd_pooled = nullptr;
482
 
483
  std::vector<llm_graph_input_ptr> inputs;
 
484
 
485
+ ggml_context_ptr ctx_compute;
 
 
486
 
487
+ // memory buffers used to evaluate the model
488
+ std::vector<uint8_t> buf_compute_meta;
489
 
490
+ ggml_cgraph * gf;
 
491
 
492
+ int64_t max_nodes;
493
 
494
+ private:
495
+ // keep a copy of the previous graph parameters
496
+ // we will use this to determine whether the graph can be reused by comparing them with the new parameters
497
+ // note: these are updated after constructing the new graph
498
+ llm_graph_params params;
499
 
500
+ // env: LLAMA_GRAPH_RESULT_DEBUG
501
+ int debug = 0;
502
+ };
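Note (hedged sketch, not code from this commit): a caller holding a previously built result could tie these methods together roughly as follows, assuming a freshly filled llm_graph_params named params and a stored result gf_res_prev:

    if (gf_res_prev->can_reuse(params)) {
        // same topology and all inputs accept the new params:
        // only the input data has to be refreshed
        gf_res_prev->set_inputs(&params.ubatch);
    } else {
        // topology changed: start from a clean result and build a new graph
        gf_res_prev->reset();
        gf_res_prev->set_params(params);
        // ... construct the graph via llm_graph_context(params) and the model's build functions ...
    }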
 
 
 
 
503
 
504
+ using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
505
 
506
+ //
507
+ // llm_graph_context
508
+ //
509
 
510
  // used in build_rs to properly order writes and avoid unnecessary copies
511
  using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
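Note (illustrative only): a callback matching this signature typically just routes through ggml_get_rows; callers can wrap it to reorder writes. A minimal example:

    llm_graph_get_rows_fn get_rows_basic = [](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
        // gather the recurrent states selected by ids
        return ggml_get_rows(ctx, states, ids);
    };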
 
546
  const enum llama_pooling_type pooling_type;
547
  const enum llama_rope_type rope_type;
548
 
 
 
549
  ggml_backend_sched_t sched;
550
 
551
  ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
557
 
558
  const llm_graph_cb & cb_func;
559
 
560
+ llm_graph_result * res;
561
+
562
+ ggml_context * ctx0 = nullptr;
563
+ ggml_cgraph * gf = nullptr;
564
 
565
  llm_graph_context(const llm_graph_params & params);
566
  virtual ~llm_graph_context() = default;
 
646
  //
647
 
648
  ggml_tensor * build_attn_mha(
 
649
  ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
650
  ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
651
  ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
 
658
 
659
  ggml_tensor * build_attn(
660
  llm_graph_input_attn_no_cache * inp,
 
661
  ggml_tensor * wo,
662
  ggml_tensor * wo_b,
663
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
 
672
 
673
  ggml_tensor * build_attn(
674
  llm_graph_input_attn_kv_unified * inp,
 
675
  ggml_tensor * wo,
676
  ggml_tensor * wo_b,
677
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
 
687
  // note: if k_cur or v_cur are not provided, they will not be stored in the memory
688
  ggml_tensor * build_attn(
689
  llm_graph_input_attn_kv_unified_iswa * inp,
 
690
  ggml_tensor * wo,
691
  ggml_tensor * wo_b,
692
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
 
701
 
702
  ggml_tensor * build_attn(
703
  llm_graph_input_attn_cross * inp,
 
704
  ggml_tensor * wo,
705
  ggml_tensor * wo_b,
706
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
 
722
  // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
723
  // `llama_memory_recurrent`
724
  ggml_tensor * build_rs(
 
725
  ggml_tensor * s,
726
  ggml_tensor * state_copy,
727
  int32_t state_size,
 
736
 
737
  ggml_tensor * build_rs(
738
  llm_graph_input_rs * inp,
 
739
  ggml_tensor * s,
740
  int32_t state_size,
741
  int32_t n_seqs,
 
743
 
744
  ggml_tensor * build_rwkv_token_shift_load(
745
  llm_graph_input_rs * inp,
 
746
  const llama_ubatch & ubatch,
747
+ int il) const;
748
 
749
  ggml_tensor * build_rwkv_token_shift_store(
750
  ggml_tensor * token_shift,
 
761
  //
762
 
763
  void build_pooling(
 
764
  ggml_tensor * cls,
765
  ggml_tensor * cls_b,
766
  ggml_tensor * cls_out,
examples/talk-llama/llama-hparams.cpp CHANGED
@@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
65
  return n_embd_head_v * n_head_kv;
66
  }
67
 
68
  uint32_t llama_hparams::n_embd_r() const {
69
  if (wkv_head_size != 0) {
70
  // for RWKV models
 
65
  return n_embd_head_v * n_head_kv;
66
  }
67
 
68
+ bool llama_hparams::is_n_embd_k_gqa_variable() const {
69
+ const uint32_t val = n_embd_k_gqa();
70
+ for (uint32_t il = 0; il < n_layer; ++il) {
71
+ if (val != n_embd_k_gqa(il)) {
72
+ return true;
73
+ }
74
+ }
75
+
76
+ return false;
77
+ }
78
+
79
+ bool llama_hparams::is_n_embd_v_gqa_variable() const {
80
+ const uint32_t val = n_embd_v_gqa();
81
+ for (uint32_t il = 0; il < n_layer; ++il) {
82
+ if (val != n_embd_v_gqa(il)) {
83
+ return true;
84
+ }
85
+ }
86
+
87
+ return false;
88
+ }
89
+
90
+ uint32_t llama_hparams::n_embd_k_gqa_max() const {
91
+ uint32_t val = n_embd_k_gqa();
92
+ for (uint32_t il = 0; il < n_layer; ++il) {
93
+ val = std::max(val, n_embd_k_gqa(il));
94
+ }
95
+
96
+ return val;
97
+ }
98
+
99
+ uint32_t llama_hparams::n_embd_v_gqa_max() const {
100
+ uint32_t val = n_embd_v_gqa();
101
+ for (uint32_t il = 0; il < n_layer; ++il) {
102
+ val = std::max(val, n_embd_v_gqa(il));
103
+ }
104
+
105
+ return val;
106
+ }
107
+
108
  uint32_t llama_hparams::n_embd_r() const {
109
  if (wkv_head_size != 0) {
110
  // for RWKV models
examples/talk-llama/llama-hparams.h CHANGED
@@ -6,7 +6,7 @@
6
 
7
  // bump if necessary
8
  #define LLAMA_MAX_LAYERS 512
9
- #define LLAMA_MAX_EXPERTS 256 // DeepSeekV3
10
 
11
  enum llama_expert_gating_func_type {
12
  LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
@@ -98,7 +98,7 @@ struct llama_hparams {
98
  float rope_freq_scale_train;
99
  float rope_freq_scale_train_swa;
100
  uint32_t n_ctx_orig_yarn;
101
- float rope_yarn_log_mul;
102
 
103
  std::array<int, 4> rope_sections;
104
 
@@ -191,6 +191,14 @@ struct llama_hparams {
191
  // dimension of value embeddings across all k-v heads
192
  uint32_t n_embd_v_gqa(uint32_t il = 0) const;
193
 
 
 
 
 
 
 
 
 
194
  // dimension of the rolling state embeddings
195
  // corresponds to Mamba's conv_states size or RWKV's token_shift states size
196
  uint32_t n_embd_r() const;
 
6
 
7
  // bump if necessary
8
  #define LLAMA_MAX_LAYERS 512
9
+ #define LLAMA_MAX_EXPERTS 384 // Kimi-K2
10
 
11
  enum llama_expert_gating_func_type {
12
  LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
 
98
  float rope_freq_scale_train;
99
  float rope_freq_scale_train_swa;
100
  uint32_t n_ctx_orig_yarn;
101
+ float rope_yarn_log_mul = 0.0f;
102
 
103
  std::array<int, 4> rope_sections;
104
 
 
191
  // dimension of value embeddings across all k-v heads
192
  uint32_t n_embd_v_gqa(uint32_t il = 0) const;
193
 
194
+ // true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
195
+ bool is_n_embd_k_gqa_variable() const;
196
+ bool is_n_embd_v_gqa_variable() const;
197
+
198
+ // return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
199
+ uint32_t n_embd_k_gqa_max() const;
200
+ uint32_t n_embd_v_gqa_max() const;
201
+
202
  // dimension of the rolling state embeddings
203
  // corresponds to Mamba's conv_states size or RWKV's token_shift states size
204
  uint32_t n_embd_r() const;
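Note (hypothetical example): for a model whose layers mostly use n_embd_v_gqa = 1024 but a few use 512, is_n_embd_v_gqa_variable() returns true and n_embd_v_gqa_max() returns 1024; with a transposed V cache that maximum is what the cache rows get padded to (see the [TAG_V_CACHE_VARIABLE] handling in llama-kv-cache-unified.cpp below). A caller-side sketch, values made up:

    const uint32_t n_embd_v = hparams.is_n_embd_v_gqa_variable()
            ? hparams.n_embd_v_gqa_max()   // pad every layer to the largest V row
            : hparams.n_embd_v_gqa();      // all layers identical, no padding needed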
examples/talk-llama/llama-kv-cache-unified-iswa.cpp CHANGED
@@ -18,16 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
18
  bool v_trans,
19
  bool offload,
20
  bool swa_full,
 
21
  uint32_t kv_size,
22
  uint32_t n_seq_max,
23
  uint32_t n_ubatch,
24
- uint32_t n_pad) : hparams(model.hparams) {
25
  llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
26
  llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
27
 
28
  const uint32_t size_base = kv_size;
29
 
30
- uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
31
 
32
  // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
33
  if (swa_full) {
@@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
41
 
42
  kv_base = std::make_unique<llama_kv_cache_unified>(
43
  model, std::move(filter_base), type_k, type_v,
44
- v_trans, offload, size_base, n_seq_max, n_pad,
45
  0, LLAMA_SWA_TYPE_NONE);
46
 
47
  LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
48
 
49
  kv_swa = std::make_unique<llama_kv_cache_unified>(
50
  model, std::move(filter_swa), type_k, type_v,
51
- v_trans, offload, size_swa, n_seq_max, n_pad,
52
  hparams.n_swa, hparams.swa_type);
53
  }
54
 
@@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
100
 
101
  // first try simple split
102
  do {
 
 
 
 
 
103
  balloc.split_reset();
104
 
105
  std::vector<llama_ubatch> ubatches;
@@ -140,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
140
 
141
  std::vector<llama_ubatch> ubatches;
142
  while (true) {
143
- auto ubatch = balloc.split_equal(n_ubatch, false);
144
 
145
  if (ubatch.n_tokens == 0) {
146
  break;
 
18
  bool v_trans,
19
  bool offload,
20
  bool swa_full,
21
+ bool unified,
22
  uint32_t kv_size,
23
  uint32_t n_seq_max,
24
  uint32_t n_ubatch,
25
+ uint32_t n_pad) : hparams(model.hparams), unified(unified) {
26
  llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
27
  llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
28
 
29
  const uint32_t size_base = kv_size;
30
 
31
+ uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
32
 
33
  // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
34
  if (swa_full) {
 
42
 
43
  kv_base = std::make_unique<llama_kv_cache_unified>(
44
  model, std::move(filter_base), type_k, type_v,
45
+ v_trans, offload, unified, size_base, n_seq_max, n_pad,
46
  0, LLAMA_SWA_TYPE_NONE);
47
 
48
  LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
49
 
50
  kv_swa = std::make_unique<llama_kv_cache_unified>(
51
  model, std::move(filter_swa), type_k, type_v,
52
+ v_trans, offload, unified, size_swa, n_seq_max, n_pad,
53
  hparams.n_swa, hparams.swa_type);
54
  }
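Note: a worked example of the SWA cache sizing above, using made-up numbers for illustration only:

    //   n_swa = 4096, n_seq_max = 8, n_ubatch = 512, n_pad = 256
    //   unified:     size_swa = GGML_PAD(4096*8 + 512, 256) = 33280 cells (then capped by size_base)
    //   non-unified: size_swa = GGML_PAD(4096*1 + 512, 256) =  4608 cells, since each sequence has its own stream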
55
 
 
101
 
102
  // first try simple split
103
  do {
104
+ if (!unified) {
105
+ // requires equal splits, so we skip the simple split
106
+ break;
107
+ }
108
+
109
  balloc.split_reset();
110
 
111
  std::vector<llama_ubatch> ubatches;
 
146
 
147
  std::vector<llama_ubatch> ubatches;
148
  while (true) {
149
+ auto ubatch = balloc.split_equal(n_ubatch, !unified);
150
 
151
  if (ubatch.n_tokens == 0) {
152
  break;
examples/talk-llama/llama-kv-cache-unified-iswa.h CHANGED
@@ -20,6 +20,7 @@ public:
20
  bool v_trans,
21
  bool offload,
22
  bool swa_full,
 
23
  uint32_t kv_size,
24
  uint32_t n_seq_max,
25
  uint32_t n_ubatch,
@@ -68,6 +69,8 @@ public:
68
  private:
69
  const llama_hparams & hparams;
70
 
 
 
71
  std::unique_ptr<llama_kv_cache_unified> kv_base;
72
  std::unique_ptr<llama_kv_cache_unified> kv_swa;
73
  };
 
20
  bool v_trans,
21
  bool offload,
22
  bool swa_full,
23
+ bool unified,
24
  uint32_t kv_size,
25
  uint32_t n_seq_max,
26
  uint32_t n_ubatch,
 
69
  private:
70
  const llama_hparams & hparams;
71
 
72
+ const bool unified;
73
+
74
  std::unique_ptr<llama_kv_cache_unified> kv_base;
75
  std::unique_ptr<llama_kv_cache_unified> kv_swa;
76
  };
examples/talk-llama/llama-kv-cache-unified.cpp CHANGED
@@ -23,13 +23,14 @@ llama_kv_cache_unified::llama_kv_cache_unified(
23
  ggml_type type_v,
24
  bool v_trans,
25
  bool offload,
 
26
  uint32_t kv_size,
27
  uint32_t n_seq_max,
28
  uint32_t n_pad,
29
  uint32_t n_swa,
30
  llama_swa_type swa_type) :
31
  model(model), hparams(model.hparams), v_trans(v_trans),
32
- n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
33
 
34
  GGML_ASSERT(kv_size % n_pad == 0);
35
 
@@ -45,7 +46,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
45
  auto it = ctx_map.find(buft);
46
  if (it == ctx_map.end()) {
47
  ggml_init_params params = {
48
- /*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
49
  /*.mem_buffer =*/ NULL,
50
  /*.no_alloc =*/ true,
51
  };
@@ -64,9 +65,33 @@ llama_kv_cache_unified::llama_kv_cache_unified(
64
  return it->second;
65
  };
66
 
67
- head = 0;
68
 
69
- cells.resize(kv_size);
 
70
 
71
  for (uint32_t il = 0; il < n_layer_cache; il++) {
72
  if (filter && !filter(il)) {
@@ -74,8 +99,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
74
  continue;
75
  }
76
 
77
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
78
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
79
 
80
  const char * dev_name = "CPU";
81
 
@@ -98,14 +124,23 @@ llama_kv_cache_unified::llama_kv_cache_unified(
98
  ggml_tensor * k;
99
  ggml_tensor * v;
100
 
101
- k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
102
- v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
103
 
104
  ggml_format_name(k, "cache_k_l%d", il);
105
  ggml_format_name(v, "cache_v_l%d", il);
106
 
 
 
 
 
 
 
 
 
107
  map_layer_ids[il] = layers.size();
108
- layers.push_back({ il, k, v });
 
109
  }
110
 
111
  // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
@@ -148,8 +183,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
148
  const size_t memory_size_k = size_k_bytes();
149
  const size_t memory_size_v = size_v_bytes();
150
 
151
- LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
152
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
153
  ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
154
  ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
155
  }
@@ -158,7 +193,12 @@ llama_kv_cache_unified::llama_kv_cache_unified(
158
  debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
159
 
160
  const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
161
- supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
 
 
 
 
 
162
 
163
  if (!supports_set_rows) {
164
  LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
@@ -166,9 +206,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
166
  }
167
 
168
  void llama_kv_cache_unified::clear(bool data) {
169
- cells.reset();
170
-
171
- head = 0;
 
172
 
173
  if (data) {
174
  for (auto & buf : bufs) {
@@ -178,6 +219,11 @@ void llama_kv_cache_unified::clear(bool data) {
178
  }
179
 
180
  bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
 
 
 
 
 
181
  uint32_t new_head = cells.size();
182
 
183
  if (p0 < 0) {
@@ -224,30 +270,94 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
224
  }
225
 
226
  void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
227
- if (seq_id_src == seq_id_dst) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  return;
229
  }
230
 
231
- if (p0 < 0) {
232
- p0 = 0;
 
 
 
 
233
  }
234
 
235
- if (p1 < 0) {
236
- p1 = std::numeric_limits<llama_pos>::max();
237
  }
238
 
239
- for (uint32_t i = 0; i < cells.size(); ++i) {
240
- if (!cells.pos_in(i, p0, p1)) {
241
- continue;
242
- }
 
243
 
244
- if (cells.seq_has(i, seq_id_src)) {
245
- cells.seq_add(i, seq_id_dst);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  }
247
  }
 
 
 
 
 
 
248
  }
249
 
250
  void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
 
 
 
 
 
251
  uint32_t new_head = cells.size();
252
 
253
  for (uint32_t i = 0; i < cells.size(); ++i) {
@@ -265,6 +375,11 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
265
  }
266
 
267
  void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
 
 
 
 
 
268
  if (shift == 0) {
269
  return;
270
  }
@@ -304,6 +419,10 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
304
  }
305
 
306
  void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
 
 
 
 
307
  if (d == 1) {
308
  return;
309
  }
@@ -333,10 +452,18 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
333
  }
334
 
335
  llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
 
 
 
 
336
  return cells.seq_pos_min(seq_id);
337
  }
338
 
339
  llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 
 
 
 
340
  return cells.seq_pos_max(seq_id);
341
  }
342
 
@@ -351,7 +478,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
351
 
352
  std::vector<llama_ubatch> ubatches;
353
  while (true) {
354
- auto ubatch = balloc.split_simple(n_ubatch);
355
 
356
  if (ubatch.n_tokens == 0) {
357
  break;
@@ -387,7 +514,10 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
387
  defrag_info dinfo;
388
 
389
  // see if we need to defrag
390
- {
 
 
 
391
  bool do_defrag = optimize;
392
 
393
  const auto thold = lctx->get_cparams().defrag_thold;
@@ -411,22 +541,22 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
411
  }
412
  }
413
 
414
- return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
415
  }
416
 
417
  llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
418
  llama_kv_cache_unified::slot_info_vec_t res;
419
 
420
- struct state {
421
- uint32_t head_old; // old position of the head, before placing the ubatch
422
-
423
  slot_info sinfo; // slot info for the ubatch
424
 
425
- llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
 
 
426
  };
427
 
428
  // remember the old state of the cells so we can restore it in the end
429
- std::vector<state> states;
430
 
431
  bool success = true;
432
 
@@ -445,16 +575,35 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
445
  res.push_back(sinfo_new);
446
 
447
  // store the old state of the cells in the recovery stack
448
- states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
 
 
 
 
 
 
 
 
 
 
449
 
450
  // now emplace the ubatch
451
  apply_ubatch(sinfo_new, ubatch);
452
  }
453
 
 
 
454
  // iterate backwards and restore the cells to their original state
455
  for (auto it = states.rbegin(); it != states.rend(); ++it) {
456
- cells.set(it->sinfo.idxs, it->cells);
457
- head = it->head_old;
 
 
 
 
 
 
 
458
  }
459
 
460
  if (!success) {
@@ -464,11 +613,38 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
464
  return res;
465
  }
466
 
467
- bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
468
  bool updated = false;
469
 
470
  auto * sched = lctx->get_sched();
471
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
  if (do_shift) {
473
  if (!get_can_shift()) {
474
  GGML_ABORT("The current KV cache / model configuration does not support K-shift");
@@ -480,14 +656,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
480
  if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
481
  ggml_backend_sched_reset(sched);
482
 
483
- auto * gf = lctx->graph_init();
484
 
485
- auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
486
- if (!res) {
487
- LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
488
- return updated;
489
- }
490
 
 
491
  if (!ggml_backend_sched_alloc_graph(sched, gf)) {
492
  LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
493
  return updated;
@@ -503,12 +676,20 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
503
  updated = true;
504
  }
505
 
506
- cells.reset_shift();
 
 
 
 
507
  }
508
 
509
  if (!dinfo.empty()) {
510
  LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
511
 
 
 
 
 
512
  // apply moves:
513
  {
514
  const auto n_kv = dinfo.ids.size();
@@ -529,14 +710,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
529
 
530
  ggml_backend_sched_reset(sched);
531
 
532
- auto * gf = lctx->graph_init();
533
 
534
- auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
535
- if (!res) {
536
- LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
537
- return updated;
538
- }
539
 
 
540
  if (!ggml_backend_sched_alloc_graph(sched, gf)) {
541
  LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
542
  return updated;
@@ -556,23 +734,13 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
556
  }
557
 
558
  llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
559
- const uint32_t n_tokens = ubatch.n_tokens;
 
560
 
561
- uint32_t head_cur = this->head;
562
 
563
- // if we have enough unused cells before the current head ->
564
- // better to start searching from the beginning of the cache, hoping to fill it
565
- if (head_cur > cells.get_used() + 2*ubatch.n_tokens) {
566
- head_cur = 0;
567
- }
568
-
569
- if (n_tokens > cells.size()) {
570
- LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
571
- return { };
572
- }
573
-
574
- if (debug > 0) {
575
- LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
576
 
577
  if ((debug == 2 && n_swa > 0) || debug > 2) {
578
  std::string ss;
@@ -629,86 +797,133 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
629
  }
630
  }
631
 
632
- uint32_t n_tested = 0;
 
 
 
 
633
 
634
- // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
635
- // for non-continuous slots, we test the tokens one by one
636
- const uint32_t n_test = cont ? n_tokens : 1;
637
 
638
- slot_info res;
 
 
 
 
 
639
 
640
- auto & idxs = res.idxs;
641
 
642
- idxs.reserve(n_tokens);
 
643
 
644
- while (true) {
645
- if (head_cur + n_test > cells.size()) {
646
- n_tested += cells.size() - head_cur;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
647
  head_cur = 0;
648
- continue;
649
  }
650
 
651
- for (uint32_t i = 0; i < n_test; i++) {
652
- const auto idx = head_cur;
 
 
 
 
 
 
 
 
653
 
654
- //const llama_pos pos = ubatch.pos[i];
655
- //const llama_seq_id seq_id = ubatch.seq_id[i][0];
 
 
 
 
656
 
657
- // can we use this cell? either:
658
- // - the cell is empty
659
- // - the cell is occupied only by one sequence:
660
- // - (disabled) mask causally, if the sequence is the same as the one we are inserting
661
- // - mask SWA, using current max pos for that sequence in the cache
662
- // always insert in the cell with minimum pos
663
- bool can_use = cells.is_empty(idx);
664
 
665
- if (!can_use && cells.seq_count(idx) == 1) {
666
- const llama_pos pos_cell = cells.pos_get(idx);
667
 
668
- // (disabled) causal mask
669
- // note: it's better to purge any "future" tokens beforehand
670
- //if (cells.seq_has(idx, seq_id)) {
671
- // can_use = pos_cell >= pos;
672
- //}
673
 
674
- if (!can_use) {
675
- const llama_seq_id seq_id_cell = cells.seq_get(idx);
 
 
 
 
 
676
 
677
- // SWA mask
678
- if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
679
- can_use = true;
 
 
 
 
 
 
 
 
 
 
 
 
 
680
  }
681
  }
682
- }
683
 
684
- head_cur++;
685
- n_tested++;
 
 
 
 
 
 
686
 
687
- if (can_use) {
688
- idxs.push_back(idx);
689
- } else {
690
  break;
691
  }
692
- }
693
 
694
- if (idxs.size() == n_tokens) {
695
- break;
696
- }
697
 
698
- if (cont) {
699
- idxs.clear();
 
 
700
  }
701
 
702
- if (n_tested >= cells.size()) {
703
- //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
704
  return { };
705
  }
706
  }
707
 
708
- // we didn't find a suitable slot - return empty result
709
- if (idxs.size() < n_tokens) {
710
- res.clear();
711
- }
712
 
713
  return res;
714
  }
@@ -717,41 +932,51 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
717
  // keep track of the max sequence position that we would overwrite with this ubatch
718
  // for non-SWA cache, this would be always empty
719
  llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
720
- for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
721
  seq_pos_max_rm[s] = -1;
722
  }
723
 
724
- assert(ubatch.n_tokens == sinfo.idxs.size());
725
 
726
- for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
727
- const auto idx = sinfo.idxs.at(i);
 
728
 
729
- if (!cells.is_empty(idx)) {
730
- assert(cells.seq_count(idx) == 1);
731
 
732
- const llama_seq_id seq_id = cells.seq_get(idx);
733
- const llama_pos pos = cells.pos_get(idx);
734
 
735
- seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
 
736
 
737
- cells.rm(idx);
738
- }
739
 
740
- cells.pos_set(idx, ubatch.pos[i]);
 
 
 
741
 
742
- for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
743
- cells.seq_add(idx, ubatch.seq_id[i][s]);
 
 
 
744
  }
745
  }
746
 
747
  // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
748
  // will be present in the cache. so we have to purge any position which is less than those we would overwrite
749
  // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
750
- for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
751
  if (seq_pos_max_rm[s] == -1) {
752
  continue;
753
  }
754
 
 
 
 
 
755
  if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
756
  LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
757
  __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
@@ -761,7 +986,11 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
761
  }
762
 
763
  // move the head at the end of the slot
764
- head = sinfo.idxs.back() + 1;
 
 
 
 
765
  }
766
 
767
  bool llama_kv_cache_unified::get_can_shift() const {
@@ -769,49 +998,91 @@ bool llama_kv_cache_unified::get_can_shift() const {
769
  }
770
 
771
  uint32_t llama_kv_cache_unified::get_size() const {
 
 
772
  return cells.size();
773
  }
774
 
 
 
 
 
775
  bool llama_kv_cache_unified::get_has_shift() const {
776
- return cells.get_has_shift();
 
 
 
 
 
 
777
  }
778
 
779
  uint32_t llama_kv_cache_unified::get_n_kv() const {
780
- return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
 
 
 
 
 
 
 
 
781
  }
782
 
783
- ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
 
 
 
 
784
  const int32_t ikv = map_layer_ids.at(il);
785
 
786
  auto * k = layers[ikv].k;
787
 
788
- return ggml_view_3d(ctx, k,
789
- hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
 
 
 
 
 
 
 
790
  ggml_row_size(k->type, hparams.n_embd_head_k),
791
- ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
792
- 0);
 
793
  }
794
 
795
- ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
796
  const int32_t ikv = map_layer_ids.at(il);
797
 
798
  auto * v = layers[ikv].v;
799
 
 
 
 
 
 
 
 
 
800
  if (!v_trans) {
801
  // note: v->nb[1] <= v->nb[2]
802
- return ggml_view_3d(ctx, v,
803
- hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
804
- ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
805
- ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
806
- 0);
 
807
  }
808
 
809
  // note: v->nb[1] > v->nb[2]
810
- return ggml_view_3d(ctx, v,
811
- n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
812
- ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
813
- ggml_row_size(v->type, v->ne[1]), // v->nb[2]
814
- 0);
 
815
  }
816
 
817
  ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
@@ -825,12 +1096,18 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
825
  k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
826
 
827
  if (k_idxs && supports_set_rows) {
 
 
 
 
828
  return ggml_set_rows(ctx, k, k_cur, k_idxs);
829
  }
830
 
831
  // TODO: fallback to old ggml_cpy() method for backwards compatibility
832
  // will be removed when ggml_set_rows() is adopted by all backends
833
 
 
 
834
  ggml_tensor * k_view = ggml_view_1d(ctx, k,
835
  n_tokens*n_embd_k_gqa,
836
  ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
@@ -843,37 +1120,38 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
843
 
844
  auto * v = layers[ikv].v;
845
 
846
- const int64_t n_embd_v_gqa = v->ne[0];
847
- const int64_t n_tokens = v_cur->ne[2];
848
 
849
  v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
850
 
851
  if (v_idxs && supports_set_rows) {
852
  if (!v_trans) {
 
 
 
 
853
  return ggml_set_rows(ctx, v, v_cur, v_idxs);
854
  }
855
 
856
- // the row becomes a single element
857
- ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
 
 
858
 
859
- // note: the V cache is transposed when not using flash attention
860
- v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
861
 
862
- // note: we can be more explicit here at the cost of extra cont
863
- // however, above we take advantage that a row of single element is always continuous regardless of the row stride
864
- //v_cur = ggml_transpose(ctx, v_cur);
865
- //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
866
 
867
- // we broadcast the KV indices n_embd_v_gqa times
868
- // v [1, n_kv, n_embd_v_gqa]
869
- // v_cur [1, n_tokens, n_embd_v_gqa]
870
- // v_idxs [n_tokens, 1, 1]
871
  return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
872
  }
873
 
874
  // TODO: fallback to old ggml_cpy() method for backwards compatibility
875
  // will be removed when ggml_set_rows() is adopted by all backends
876
 
 
 
877
  ggml_tensor * v_view = nullptr;
878
 
879
  if (!v_trans) {
@@ -904,7 +1182,13 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
904
  ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
905
  const uint32_t n_tokens = ubatch.n_tokens;
906
 
907
- ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
 
 
 
 
 
 
908
 
909
  ggml_set_input(v_idxs);
910
 
@@ -917,12 +1201,17 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
917
  }
918
 
919
  const uint32_t n_tokens = ubatch->n_tokens;
 
920
 
921
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
922
  int64_t * data = (int64_t *) dst->data;
923
 
924
- for (int64_t i = 0; i < n_tokens; ++i) {
925
- data[i] = sinfo.idxs.at(i);
 
 
 
 
926
  }
927
  }
928
 
@@ -932,12 +1221,48 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
932
  }
933
 
934
  const uint32_t n_tokens = ubatch->n_tokens;
 
935
 
936
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
937
  int64_t * data = (int64_t *) dst->data;
938
 
939
- for (int64_t i = 0; i < n_tokens; ++i) {
940
- data[i] = sinfo.idxs.at(i);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
941
  }
942
  }
943
 
@@ -947,7 +1272,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
947
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
948
  float * data = (float *) dst->data;
949
 
950
- const int64_t n_kv = dst->ne[0];
 
 
 
 
 
 
 
 
 
951
 
952
  // Use only the previous KV cells of the correct sequence for each token of the ubatch.
953
  // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -961,70 +1295,57 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
961
  // xxxxx-----
962
  // xxxxx-----
963
  // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
 
964
  for (uint32_t h = 0; h < 1; ++h) {
965
- for (uint32_t i = 0; i < n_tokens; ++i) {
966
- const llama_seq_id seq_id = ubatch->seq_id[i][0];
 
967
 
968
- const llama_pos p1 = ubatch->pos[i];
969
 
970
- for (uint32_t j = 0; j < n_kv; ++j) {
971
- float f = 0.0f;
972
 
973
- bool masked = false;
974
 
975
- if (cells.is_empty(j)) {
976
- masked = true;
977
- } else {
978
- const llama_pos p0 = cells.pos_get(j);
 
 
979
 
980
  // mask the token if not the same sequence
981
- masked = masked || (!cells.seq_has(j, seq_id));
 
 
 
 
982
 
983
  // mask future tokens
984
- masked = masked || (causal_attn && p0 > p1);
 
 
985
 
986
  // apply SWA if any
987
- masked = masked || (is_masked_swa(p0, p1));
988
-
989
- if (!masked && hparams.use_alibi) {
990
- f = -std::abs(p0 - p1);
991
  }
992
- }
993
-
994
- if (masked) {
995
- f = -INFINITY;
996
- }
997
-
998
- data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
999
- }
1000
- }
1001
 
1002
- // mask padded tokens
1003
- if (data) {
1004
- for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
1005
- for (uint32_t j = 0; j < n_kv; ++j) {
1006
- data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
1007
  }
1008
  }
1009
  }
1010
  }
1011
  }
1012
 
1013
- void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
1014
- GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1015
-
1016
- int32_t * data = (int32_t *) dst->data;
1017
-
1018
- for (uint32_t i = 0; i < cells.size(); ++i) {
1019
- data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
1020
- }
1021
- }
1022
-
1023
  void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1024
  const int64_t n_tokens = ubatch->n_tokens;
1025
 
 
 
 
1026
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1027
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
1028
 
1029
  int32_t * data = (int32_t *) dst->data;
1030
 
@@ -1129,7 +1450,7 @@ public:
1129
 
1130
  void set_input(const llama_ubatch * ubatch) override;
1131
 
1132
- ggml_tensor * k_shift; // I32 [kv_size]
1133
 
1134
  const llama_kv_cache_unified * kv_self;
1135
  };
@@ -1142,20 +1463,20 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
1142
  }
1143
  }
1144
 
1145
- llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
1146
- const llama_cparams & cparams,
1147
- ggml_context * ctx,
1148
- ggml_cgraph * gf) const {
1149
- auto res = std::make_unique<llm_graph_result>();
1150
 
1151
  const auto & n_embd_head_k = hparams.n_embd_head_k;
1152
  //const auto & n_embd_head_v = hparams.n_embd_head_v;
1153
 
1154
  auto inp = std::make_unique<llm_graph_input_k_shift>(this);
1155
 
1156
- inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
1157
  ggml_set_input(inp->k_shift);
1158
 
 
 
1159
  for (const auto & layer : layers) {
1160
  const uint32_t il = layer.il;
1161
 
@@ -1169,7 +1490,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
1169
 
1170
  ggml_tensor * k =
1171
  ggml_view_3d(ctx, layer.k,
1172
- n_embd_head_k, n_head_kv, cells.size(),
1173
  ggml_row_size(layer.k->type, n_embd_head_k),
1174
  ggml_row_size(layer.k->type, n_embd_k_gqa),
1175
  0);
@@ -1181,18 +1502,24 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
1181
 
1182
  res->add_input(std::move(inp));
1183
 
1184
- return res;
1185
  }
1186
 
1187
- llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
1188
- const llama_cparams & cparams,
1189
- ggml_context * ctx,
1190
- ggml_cgraph * gf,
1191
- const defrag_info & dinfo) const {
1192
- auto res = std::make_unique<llm_graph_result>();
 
 
 
 
1193
 
1194
  const auto & ids = dinfo.ids;
1195
 
 
 
1196
  #if 0
1197
  // CPU defrag
1198
  //
@@ -1329,10 +1656,14 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
1329
  //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
1330
  #endif
1331
 
1332
- return res;
1333
  }
1334
 
1335
  llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
 
 
 
 
1336
  const uint32_t n_layer = layers.size();
1337
 
1338
  const uint32_t n_kv = cells.used_max_p1();
@@ -1478,64 +1809,94 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
1478
  }
1479
 
1480
  void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
1481
- std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
1482
- uint32_t cell_count = 0;
1483
 
1484
- // Count the number of cells with the specified seq_id
1485
- // Find all the ranges of cells with this seq id (or all, when -1)
1486
- uint32_t cell_range_begin = cells.size();
1487
 
1488
- for (uint32_t i = 0; i < cells.size(); ++i) {
1489
- if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
1490
- ++cell_count;
1491
- if (cell_range_begin == cells.size()) {
1492
- cell_range_begin = i;
1493
- }
1494
- } else {
1495
- if (cell_range_begin != cells.size()) {
1496
- cell_ranges.emplace_back(cell_range_begin, i);
1497
- cell_range_begin = cells.size();
 
 
 
 
 
 
 
 
 
1498
  }
1499
  }
1500
- }
1501
 
1502
- if (cell_range_begin != cells.size()) {
1503
- cell_ranges.emplace_back(cell_range_begin, cells.size());
1504
- }
1505
 
1506
- // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
1507
- uint32_t cell_count_check = 0;
1508
- for (const auto & range : cell_ranges) {
1509
- cell_count_check += range.second - range.first;
1510
- }
1511
- GGML_ASSERT(cell_count == cell_count_check);
1512
 
1513
- io.write(&cell_count, sizeof(cell_count));
1514
 
1515
- state_write_meta(io, cell_ranges, seq_id);
1516
- state_write_data(io, cell_ranges);
 
 
 
 
 
 
1517
  }
1518
 
1519
  void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
1520
- uint32_t cell_count;
1521
- io.read_to(&cell_count, sizeof(cell_count));
1522
 
1523
- bool res = true;
1524
- res = res && state_read_meta(io, cell_count, seq_id);
1525
- res = res && state_read_data(io, cell_count);
 
 
 
 
 
 
 
 
 
 
1526
 
1527
- if (!res) {
1528
- if (seq_id == -1) {
1529
- clear(true);
1530
- } else {
1531
- seq_rm(seq_id, -1, -1);
 
 
 
 
 
 
 
 
1532
  }
1533
- throw std::runtime_error("failed to restore kv cache");
1534
  }
1535
  }
1536
 
1537
- void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
1538
- for (const auto & range : cell_ranges) {
 
 
1539
  for (uint32_t i = range.first; i < range.second; ++i) {
1540
  std::vector<llama_seq_id> seq_ids;
1541
 
@@ -1560,7 +1921,9 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::
1560
  }
1561
  }
1562
 
1563
- void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
 
 
1564
  const uint32_t v_trans = this->v_trans ? 1 : 0;
1565
  const uint32_t n_layer = layers.size();
1566
 
@@ -1576,19 +1939,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
1576
 
1577
  const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1578
 
 
 
1579
  // Write key type
1580
- const int32_t k_type_i = (int32_t)layer.k->type;
1581
  io.write(&k_type_i, sizeof(k_type_i));
1582
 
1583
  // Write row size of key
1584
- const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
1585
  io.write(&k_size_row, sizeof(k_size_row));
1586
 
1587
  // Read each range of cells of k_size length each into tmp_buf and write out
1588
- for (const auto & range : cell_ranges) {
1589
  const size_t range_size = range.second - range.first;
1590
  const size_t buf_size = range_size * k_size_row;
1591
- io.write_tensor(layer.k, range.first * k_size_row, buf_size);
1592
  }
1593
  }
1594
 
@@ -1598,19 +1963,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
1598
 
1599
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1600
 
 
 
1601
  // Write value type
1602
- const int32_t v_type_i = (int32_t)layer.v->type;
1603
  io.write(&v_type_i, sizeof(v_type_i));
1604
 
1605
  // Write row size of value
1606
- const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
1607
  io.write(&v_size_row, sizeof(v_size_row));
1608
 
1609
  // Read each range of cells of v_size length each into tmp_buf and write out
1610
- for (const auto & range : cell_ranges) {
1611
  const size_t range_size = range.second - range.first;
1612
  const size_t buf_size = range_size * v_size_row;
1613
- io.write_tensor(layer.v, range.first * v_size_row, buf_size);
1614
  }
1615
  }
1616
  } else {
@@ -1622,12 +1989,14 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
1622
 
1623
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1624
 
 
 
1625
  // Write value type
1626
- const int32_t v_type_i = (int32_t)layer.v->type;
1627
  io.write(&v_type_i, sizeof(v_type_i));
1628
 
1629
  // Write element size
1630
- const uint32_t v_size_el = ggml_type_size(layer.v->type);
1631
  io.write(&v_size_el, sizeof(v_size_el));
1632
 
1633
  // Write GQA embedding size
@@ -1636,27 +2005,31 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
1636
  // For each row, we get the element values of each cell
1637
  for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1638
  // Read each range of cells of v_size_el length each into tmp_buf and write out
1639
- for (const auto & range : cell_ranges) {
1640
  const size_t range_size = range.second - range.first;
1641
  const size_t src_offset = (range.first + j * kv_size) * v_size_el;
1642
  const size_t buf_size = range_size * v_size_el;
1643
- io.write_tensor(layer.v, src_offset, buf_size);
1644
  }
1645
  }
1646
  }
1647
  }
1648
  }
1649
 
1650
- bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
 
 
 
1651
  if (dest_seq_id != -1) {
1652
  // single sequence
1653
-
1654
  seq_rm(dest_seq_id, -1, -1);
1655
 
1656
  llama_batch_allocr balloc(hparams.n_pos_per_embd());
1657
 
1658
  llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
1659
 
 
 
1660
  for (uint32_t i = 0; i < cell_count; ++i) {
1661
  llama_pos pos;
1662
  uint32_t n_seq_id;
@@ -1693,6 +2066,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1693
  // keep the head at the old position because we will read the KV data into it in state_read_data()
1694
  head = head_cur;
1695
 
 
 
1696
  // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
1697
  // Assume that this is one contiguous block of cells
1698
  GGML_ASSERT(head_cur + cell_count <= cells.size());
@@ -1738,7 +2113,10 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1738
  return true;
1739
  }
1740
 
1741
- bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
 
 
 
1742
  uint32_t v_trans;
1743
  uint32_t n_layer;
1744
 
@@ -1766,10 +2144,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1766
 
1767
  const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1768
 
 
 
1769
  // Read type of key
1770
  int32_t k_type_i_ref;
1771
  io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
1772
- const int32_t k_type_i = (int32_t) layer.k->type;
1773
  if (k_type_i != k_type_i_ref) {
1774
  LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
1775
  return false;
@@ -1778,7 +2158,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1778
  // Read row size of key
1779
  uint64_t k_size_row_ref;
1780
  io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
1781
- const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa);
1782
  if (k_size_row != k_size_row_ref) {
1783
  LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
1784
  return false;
@@ -1786,7 +2166,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1786
 
1787
  if (cell_count) {
1788
  // Read and set the keys for the whole cell range
1789
- ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
1790
  }
1791
  }
1792
 
@@ -1796,10 +2176,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1796
 
1797
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1798
 
 
 
1799
  // Read type of value
1800
  int32_t v_type_i_ref;
1801
  io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1802
- const int32_t v_type_i = (int32_t)layer.v->type;
1803
  if (v_type_i != v_type_i_ref) {
1804
  LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1805
  return false;
@@ -1808,7 +2190,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1808
  // Read row size of value
1809
  uint64_t v_size_row_ref;
1810
  io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
1811
- const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa);
1812
  if (v_size_row != v_size_row_ref) {
1813
  LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
1814
  return false;
@@ -1816,7 +2198,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1816
 
1817
  if (cell_count) {
1818
  // Read and set the values for the whole cell range
1819
- ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
1820
  }
1821
  }
1822
  } else {
@@ -1826,10 +2208,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1826
 
1827
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1828
 
 
 
1829
  // Read type of value
1830
  int32_t v_type_i_ref;
1831
  io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1832
- const int32_t v_type_i = (int32_t)layer.v->type;
1833
  if (v_type_i != v_type_i_ref) {
1834
  LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1835
  return false;
@@ -1838,7 +2222,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1838
  // Read element size of value
1839
  uint32_t v_size_el_ref;
1840
  io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
1841
- const size_t v_size_el = ggml_type_size(layer.v->type);
1842
  if (v_size_el != v_size_el_ref) {
1843
  LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
1844
  return false;
@@ -1856,7 +2240,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1856
  // For each row in the transposed matrix, read the values for the whole cell range
1857
  for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1858
  const size_t dst_offset = (head + j * cells.size()) * v_size_el;
1859
- ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1860
  }
1861
  }
1862
  }
@@ -1875,18 +2259,26 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1875
  llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
1876
  n_kv = kv->get_size();
1877
 
 
 
1878
  // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
1879
  sinfos.resize(1);
1880
- sinfos[0].idxs.resize(1);
1881
- sinfos[0].idxs[0] = 0;
 
 
 
 
 
1882
  }
1883
 
1884
  llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1885
  llama_kv_cache_unified * kv,
1886
  llama_context * lctx,
1887
  bool do_shift,
1888
- defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
1889
- if (!do_shift && this->dinfo.empty()) {
 
1890
  status = LLAMA_MEMORY_STATUS_NO_UPDATE;
1891
  }
1892
  }
@@ -1914,7 +2306,7 @@ bool llama_kv_cache_unified_context::apply() {
1914
 
1915
  // no ubatches -> this is a KV cache update
1916
  if (ubatches.empty()) {
1917
- kv->update(lctx, do_shift, dinfo);
1918
 
1919
  return true;
1920
  }
@@ -1940,12 +2332,16 @@ uint32_t llama_kv_cache_unified_context::get_n_kv() const {
1940
  return n_kv;
1941
  }
1942
 
 
 
 
 
1943
  ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
1944
- return kv->get_k(ctx, il, n_kv);
1945
  }
1946
 
1947
  ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
1948
- return kv->get_v(ctx, il, n_kv);
1949
  }
1950
 
1951
  ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
 
23
  ggml_type type_v,
24
  bool v_trans,
25
  bool offload,
26
+ bool unified,
27
  uint32_t kv_size,
28
  uint32_t n_seq_max,
29
  uint32_t n_pad,
30
  uint32_t n_swa,
31
  llama_swa_type swa_type) :
32
  model(model), hparams(model.hparams), v_trans(v_trans),
33
+ n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
34
 
35
  GGML_ASSERT(kv_size % n_pad == 0);
36
 
 
46
  auto it = ctx_map.find(buft);
47
  if (it == ctx_map.end()) {
48
  ggml_init_params params = {
49
+ /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
50
  /*.mem_buffer =*/ NULL,
51
  /*.no_alloc =*/ true,
52
  };
 
65
  return it->second;
66
  };
67
 
68
+ GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
69
 
70
+ v_heads.resize(n_stream);
71
+ for (uint32_t s = 0; s < n_stream; ++s) {
72
+ v_heads[s] = 0;
73
+ }
74
+
75
+ v_cells.resize(n_stream);
76
+ for (uint32_t s = 0; s < n_stream; ++s) {
77
+ v_cells[s].resize(kv_size);
78
+ }
79
+
80
+ // by default, all sequence ids are mapped to the 0th stream
81
+ seq_to_stream.resize(LLAMA_MAX_SEQ, 0);
82
+
83
+ if (n_stream > 1) {
84
+ seq_to_stream.resize(n_stream, 0);
85
+ for (uint32_t s = 0; s < n_stream; ++s) {
86
+ seq_to_stream[s] = s;
87
+ }
88
+ }
89
+
90
+ // [TAG_V_CACHE_VARIABLE]
91
+ if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
92
+ LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
93
+ __func__, hparams.n_embd_v_gqa_max());
94
+ }
95
 
96
  for (uint32_t il = 0; il < n_layer_cache; il++) {
97
  if (filter && !filter(il)) {
 
99
  continue;
100
  }
101
 
102
+ // [TAG_V_CACHE_VARIABLE]
103
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
104
+ const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
105
 
106
  const char * dev_name = "CPU";
107
 
 
124
  ggml_tensor * k;
125
  ggml_tensor * v;
126
 
127
+ k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
128
+ v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
129
 
130
  ggml_format_name(k, "cache_k_l%d", il);
131
  ggml_format_name(v, "cache_v_l%d", il);
132
 
133
+ std::vector<ggml_tensor *> k_stream;
134
+ std::vector<ggml_tensor *> v_stream;
135
+
136
+ for (uint32_t s = 0; s < n_stream; ++s) {
137
+ k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
138
+ v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
139
+ }
140
+
141
  map_layer_ids[il] = layers.size();
142
+
143
+ layers.push_back({ il, k, v, k_stream, v_stream, });
144
  }
145
 
146
  // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
 
183
  const size_t memory_size_k = size_k_bytes();
184
  const size_t memory_size_v = size_v_bytes();
185
 
186
+ LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
187
+ (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
188
  ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
189
  ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
190
  }
 
193
  debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
194
 
195
  const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
196
+ supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : 0;
197
+
198
+ if (!supports_set_rows) {
199
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14363
200
+ GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support");
201
+ }
202
 
203
  if (!supports_set_rows) {
204
  LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
 
206
  }
207
 
208
  void llama_kv_cache_unified::clear(bool data) {
209
+ for (uint32_t s = 0; s < n_stream; ++s) {
210
+ v_cells[s].reset();
211
+ v_heads[s] = 0;
212
+ }
213
 
214
  if (data) {
215
  for (auto & buf : bufs) {
 
219
  }
220
 
221
  bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
222
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
223
+
224
+ auto & cells = v_cells[seq_to_stream[seq_id]];
225
+ auto & head = v_heads[seq_to_stream[seq_id]];
226
+
227
  uint32_t new_head = cells.size();
228
 
229
  if (p0 < 0) {
 
270
  }
271
 
272
  void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
273
+ GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
274
+ GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
275
+
276
+ const auto s0 = seq_to_stream[seq_id_src];
277
+ const auto s1 = seq_to_stream[seq_id_dst];
278
+
279
+ if (s0 == s1) {
280
+ // since both sequences are in the same stream, no data copy is necessary
281
+ // we just have to update the cells metadata
282
+
283
+ auto & cells = v_cells[s0];
284
+
285
+ if (seq_id_src == seq_id_dst) {
286
+ return;
287
+ }
288
+
289
+ if (p0 < 0) {
290
+ p0 = 0;
291
+ }
292
+
293
+ if (p1 < 0) {
294
+ p1 = std::numeric_limits<llama_pos>::max();
295
+ }
296
+
297
+ for (uint32_t i = 0; i < cells.size(); ++i) {
298
+ if (!cells.pos_in(i, p0, p1)) {
299
+ continue;
300
+ }
301
+
302
+ if (cells.seq_has(i, seq_id_src)) {
303
+ cells.seq_add(i, seq_id_dst);
304
+ }
305
+ }
306
+
307
  return;
308
  }
309
 
310
+ // cross-stream sequence copies require copying the actual buffer data
311
+
312
+ bool is_full = true;
313
+
314
+ if (p0 > 0 && p0 + 1 < (int) get_size()) {
315
+ is_full = false;
316
  }
317
 
318
+ if (p1 > 0 && p1 + 1 < (int) get_size()) {
319
+ is_full = false;
320
  }
321
 
322
+ GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
323
+
324
+ // enqueue the copy operation - the buffer copy will be performed during the next update
325
+ sc_info.ssrc.push_back(s0);
326
+ sc_info.sdst.push_back(s1);
327
 
328
+ v_cells[s1].reset();
329
+ for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
330
+ if (v_cells[s0].seq_has(i, seq_id_src)) {
331
+ llama_pos pos = v_cells[s0].pos_get(i);
332
+ llama_pos shift = v_cells[s0].get_shift(i);
333
+
334
+ if (shift != 0) {
335
+ pos -= shift;
336
+ assert(pos >= 0);
337
+ }
338
+
339
+ v_cells[s1].pos_set(i, pos);
340
+ v_cells[s1].seq_add(i, seq_id_dst);
341
+
342
+ if (shift != 0) {
343
+ v_cells[s1].pos_add(i, shift);
344
+ }
345
  }
346
  }
347
+
348
+ v_heads[s1] = v_heads[s0];
349
+
350
+ //for (uint32_t s = 0; s < n_stream; ++s) {
351
+ // LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
352
+ //}
353
  }
354
 
355
  void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
356
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
357
+
358
+ auto & cells = v_cells[seq_to_stream[seq_id]];
359
+ auto & head = v_heads[seq_to_stream[seq_id]];
360
+
361
  uint32_t new_head = cells.size();
362
 
363
  for (uint32_t i = 0; i < cells.size(); ++i) {
 
375
  }
376
 
377
  void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
378
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
379
+
380
+ auto & cells = v_cells[seq_to_stream[seq_id]];
381
+ auto & head = v_heads[seq_to_stream[seq_id]];
382
+
383
  if (shift == 0) {
384
  return;
385
  }
 
419
  }
420
 
421
  void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
422
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
423
+
424
+ auto & cells = v_cells[seq_to_stream[seq_id]];
425
+
426
  if (d == 1) {
427
  return;
428
  }
 
452
  }
453
 
454
  llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
455
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
456
+
457
+ const auto & cells = v_cells[seq_to_stream[seq_id]];
458
+
459
  return cells.seq_pos_min(seq_id);
460
  }
461
 
462
  llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
463
+ GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
464
+
465
+ const auto & cells = v_cells[seq_to_stream[seq_id]];
466
+
467
  return cells.seq_pos_max(seq_id);
468
  }
469
 
 
478
 
479
  std::vector<llama_ubatch> ubatches;
480
  while (true) {
481
+ auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
482
 
483
  if (ubatch.n_tokens == 0) {
484
  break;
 
514
  defrag_info dinfo;
515
 
516
  // see if we need to defrag
517
+ if (n_stream == 1) {
518
+ // note : for now do not consider defrag for n_stream > 1
519
+ const auto & cells = v_cells[seq_to_stream[0]];
520
+
521
  bool do_defrag = optimize;
522
 
523
  const auto thold = lctx->get_cparams().defrag_thold;
 
541
  }
542
  }
543
 
544
+ return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
545
  }
546
 
547
  llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
548
  llama_kv_cache_unified::slot_info_vec_t res;
549
 
550
+ struct state_t {
 
 
551
  slot_info sinfo; // slot info for the ubatch
552
 
553
+ std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
554
+
555
+ std::vector<llama_kv_cells_unified> v_cells; // copy of the old cells, before placing the ubatch
556
  };
557
 
558
  // remember the old state of the cells so we can restore it in the end
559
+ std::vector<state_t> states;
560
 
561
  bool success = true;
562
 
 
575
  res.push_back(sinfo_new);
576
 
577
  // store the old state of the cells in the recovery stack
578
+ {
579
+ state_t state = { sinfo_new, v_heads, {} };
580
+
581
+ for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
582
+ auto & cells = v_cells[sinfo_new.strm[s]];
583
+
584
+ state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
585
+ }
586
+
587
+ states.push_back(std::move(state));
588
+ }
589
 
590
  // now emplace the ubatch
591
  apply_ubatch(sinfo_new, ubatch);
592
  }
593
 
594
+ GGML_ASSERT(!states.empty() || !success);
595
+
596
  // iterate backwards and restore the cells to their original state
597
  for (auto it = states.rbegin(); it != states.rend(); ++it) {
598
+ const auto & sinfo = it->sinfo;
599
+
600
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
601
+ auto & cells = v_cells[sinfo.strm[s]];
602
+ auto & head = v_heads[sinfo.strm[s]];
603
+
604
+ cells.set(sinfo.idxs[s], it->v_cells[s]);
605
+ head = it->v_heads_old[s];
606
+ }
607
  }
608
 
609
  if (!success) {
 
613
  return res;
614
  }
615
 
616
+ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
617
  bool updated = false;
618
 
619
  auto * sched = lctx->get_sched();
620
 
621
+ if (!sc_info.empty()) {
622
+ assert(n_stream > 1 && "stream copy should never happen with a single stream");
623
+
624
+ llama_synchronize(lctx);
625
+
626
+ const size_t n_copy = sc_info.ssrc.size();
627
+
628
+ for (size_t i = 0; i < n_copy; ++i) {
629
+ const auto ssrc = sc_info.ssrc[i];
630
+ const auto sdst = sc_info.sdst[i];
631
+
632
+ assert(ssrc < n_stream);
633
+ assert(sdst < n_stream);
634
+
635
+ LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);
636
+
637
+ assert(ssrc != sdst);
638
+
639
+ for (uint32_t il = 0; il < layers.size(); ++il) {
640
+ const auto & layer = layers[il];
641
+
642
+ ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
643
+ ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
644
+ }
645
+ }
646
+ }
647
+
648
  if (do_shift) {
649
  if (!get_can_shift()) {
650
  GGML_ABORT("The current KV cache / model configuration does not support K-shift");
 
656
  if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
657
  ggml_backend_sched_reset(sched);
658
 
659
+ auto * res = lctx->get_gf_res_reserve();
660
 
661
+ res->reset();
 
 
 
 
662
 
663
+ auto * gf = build_graph_shift(res, lctx);
664
  if (!ggml_backend_sched_alloc_graph(sched, gf)) {
665
  LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
666
  return updated;
 
676
  updated = true;
677
  }
678
 
679
+ for (uint32_t s = 0; s < n_stream; ++s) {
680
+ auto & cells = v_cells[s];
681
+
682
+ cells.reset_shift();
683
+ }
684
  }
685
 
686
  if (!dinfo.empty()) {
687
  LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
688
 
689
+ // note: for now do not consider defrag for n_stream > 1
690
+ auto & cells = v_cells[seq_to_stream[0]];
691
+ auto & head = v_heads[seq_to_stream[0]];
692
+
693
  // apply moves:
694
  {
695
  const auto n_kv = dinfo.ids.size();
 
710
 
711
  ggml_backend_sched_reset(sched);
712
 
713
+ auto * res = lctx->get_gf_res_reserve();
714
 
715
+ res->reset();
 
 
 
 
716
 
717
+ auto * gf = build_graph_defrag(res, lctx, dinfo);
718
  if (!ggml_backend_sched_alloc_graph(sched, gf)) {
719
  LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
720
  return updated;
 
734
  }
735
 
736
  llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
737
+ if (debug > 0) {
738
+ const auto & cells = v_cells[seq_to_stream[1]];
739
 
740
+ const uint32_t head_cur = v_heads[seq_to_stream[1]];
741
 
742
+ LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
743
+ __func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
 
 
 
 
 
 
 
 
 
 
 
744
 
745
  if ((debug == 2 && n_swa > 0) || debug > 2) {
746
  std::string ss;
 
797
  }
798
  }
799
 
800
+ uint32_t n_tokens = ubatch.n_tokens;
801
+ uint32_t n_seqs = 1;
802
+
803
+ if (n_stream > 1) {
804
+ GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);
805
 
806
+ n_seqs = ubatch.n_seqs_unq;
807
+ n_tokens = n_tokens / n_seqs;
808
+ }
809
 
810
+ slot_info res = {
811
+ /*.s0 =*/ LLAMA_MAX_SEQ,
812
+ /*.s1 =*/ 0,
813
+ /*.strm =*/ { },
814
+ /*.idxs =*/ { },
815
+ };
816
 
817
+ res.resize(n_seqs);
818
 
819
+ for (uint32_t s = 0; s < n_seqs; ++s) {
820
+ const auto seq_id = ubatch.seq_id_unq[s];
821
 
822
+ if (n_stream > 1) {
823
+ GGML_ASSERT(ubatch.n_seq_id[s*n_tokens] == 1);
824
+ GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
825
+ }
826
+
827
+ res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
828
+ res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);
829
+
830
+ res.strm[s] = seq_to_stream[seq_id];
831
+ res.idxs[s].reserve(n_tokens);
832
+
833
+ const auto & cells = v_cells[seq_to_stream[seq_id]];
834
+
835
+ uint32_t head_cur = v_heads[seq_to_stream[seq_id]];
836
+
837
+ // if we have enough unused cells before the current head ->
838
+ // better to start searching from the beginning of the cache, hoping to fill it
839
+ if (head_cur > cells.get_used() + 2*n_tokens) {
840
  head_cur = 0;
 
841
  }
842
 
843
+ if (n_tokens > cells.size()) {
844
+ LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
845
+ return { };
846
+ }
847
+
848
+ uint32_t n_tested = 0;
849
+
850
+ // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
851
+ // for non-continuous slots, we test the tokens one by one
852
+ const uint32_t n_test = cont ? n_tokens : 1;
853
 
854
+ while (true) {
855
+ if (head_cur + n_test > cells.size()) {
856
+ n_tested += cells.size() - head_cur;
857
+ head_cur = 0;
858
+ continue;
859
+ }
860
 
861
+ for (uint32_t i = 0; i < n_test; i++) {
862
+ const auto idx = head_cur;
 
 
 
 
 
863
 
864
+ head_cur++;
865
+ n_tested++;
866
 
867
+ //const llama_pos pos = ubatch.pos[i];
868
+ //const llama_seq_id seq_id = ubatch.seq_id[i][0];
 
 
 
869
 
870
+ // can we use this cell? either:
871
+ // - the cell is empty
872
+ // - the cell is occupied only by one sequence:
873
+ // - (disabled) mask causally, if the sequence is the same as the one we are inserting
874
+ // - mask SWA, using current max pos for that sequence in the cache
875
+ // always insert in the cell with minimum pos
876
+ bool can_use = cells.is_empty(idx);
877
 
878
+ if (!can_use && cells.seq_count(idx) == 1) {
879
+ const llama_pos pos_cell = cells.pos_get(idx);
880
+
881
+ // (disabled) causal mask
882
+ // note: it's better to purge any "future" tokens beforehand
883
+ //if (cells.seq_has(idx, seq_id)) {
884
+ // can_use = pos_cell >= pos;
885
+ //}
886
+
887
+ if (!can_use) {
888
+ const llama_seq_id seq_id_cell = cells.seq_get(idx);
889
+
890
+ // SWA mask
891
+ if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
892
+ can_use = true;
893
+ }
894
  }
895
  }
 
896
 
897
+ if (can_use) {
898
+ res.idxs[s].push_back(idx);
899
+ } else {
900
+ if (cont) {
901
+ break;
902
+ }
903
+ }
904
+ }
905
 
906
+ if (res.idxs[s].size() == n_tokens) {
 
 
907
  break;
908
  }
 
909
 
910
+ if (cont) {
911
+ res.idxs[s].clear();
912
+ }
913
 
914
+ if (n_tested >= cells.size()) {
915
+ //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
916
+ return { };
917
+ }
918
  }
919
 
920
+ // we didn't find a suitable slot - return empty result
921
+ if (res.idxs[s].size() < n_tokens) {
922
  return { };
923
  }
924
  }
925
 
926
+ assert(res.s1 >= res.s0);
 
 
 
927
 
928
  return res;
929
  }
 
932
  // keep track of the max sequence position that we would overwrite with this ubatch
933
  // for non-SWA cache, this would always be empty
934
  llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
935
+ for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
936
  seq_pos_max_rm[s] = -1;
937
  }
938
 
939
+ assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());
940
 
941
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
942
+ for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
943
+ const uint32_t i = s*sinfo.size() + ii;
944
 
945
+ auto & cells = v_cells[sinfo.strm[s]];
 
946
 
947
+ const auto idx = sinfo.idxs[s][ii];
 
948
 
949
+ if (!cells.is_empty(idx)) {
950
+ assert(cells.seq_count(idx) == 1);
951
 
952
+ const llama_seq_id seq_id = cells.seq_get(idx);
953
+ const llama_pos pos = cells.pos_get(idx);
954
 
955
+ seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
956
+
957
+ cells.rm(idx);
958
+ }
959
 
960
+ cells.pos_set(idx, ubatch.pos[i]);
961
+
962
+ for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
963
+ cells.seq_add(idx, ubatch.seq_id[i][s]);
964
+ }
965
  }
966
  }
967
 
968
  // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
969
  // will be present in the cache. so we have to purge any position which is less than those we would overwrite
970
  // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
971
+ for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
972
  if (seq_pos_max_rm[s] == -1) {
973
  continue;
974
  }
975
 
976
+ GGML_ASSERT(s < seq_to_stream.size());
977
+
978
+ auto & cells = v_cells[seq_to_stream[s]];
979
+
980
  if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
981
  LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
982
  __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
 
986
  }
987
 
988
  // move the head at the end of the slot
989
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
990
+ auto & head = v_heads[sinfo.strm[s]];
991
+
992
+ head = sinfo.idxs[s].back() + 1;
993
+ }
994
  }
995
 
996
  bool llama_kv_cache_unified::get_can_shift() const {
 
998
  }
999
 
1000
  uint32_t llama_kv_cache_unified::get_size() const {
1001
+ const auto & cells = v_cells[seq_to_stream[0]];
1002
+
1003
  return cells.size();
1004
  }
1005
 
1006
+ uint32_t llama_kv_cache_unified::get_n_stream() const {
1007
+ return n_stream;
1008
+ }
1009
+
1010
  bool llama_kv_cache_unified::get_has_shift() const {
1011
+ bool result = false;
1012
+
1013
+ for (uint32_t s = 0; s < n_stream; ++s) {
1014
+ result |= v_cells[s].get_has_shift();
1015
+ }
1016
+
1017
+ return result;
1018
  }
1019
 
1020
  uint32_t llama_kv_cache_unified::get_n_kv() const {
1021
+ uint32_t result = 0;
1022
+
1023
+ for (uint32_t s = 0; s < n_stream; ++s) {
1024
+ const auto & cells = v_cells[s];
1025
+
1026
+ result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
1027
+ }
1028
+
1029
+ return result;
1030
  }
1031
 
1032
+ bool llama_kv_cache_unified::get_supports_set_rows() const {
1033
+ return supports_set_rows;
1034
+ }
1035
+
1036
+ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
1037
  const int32_t ikv = map_layer_ids.at(il);
1038
 
1039
  auto * k = layers[ikv].k;
1040
 
1041
+ const uint64_t kv_size = get_size();
1042
+ const uint64_t n_embd_k_gqa = k->ne[0];
1043
+
1044
+ assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
1045
+
1046
+ const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
1047
+
1048
+ return ggml_view_4d(ctx, k,
1049
+ hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
1050
  ggml_row_size(k->type, hparams.n_embd_head_k),
1051
+ ggml_row_size(k->type, n_embd_k_gqa),
1052
+ ggml_row_size(k->type, n_embd_k_gqa*kv_size),
1053
+ ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
1054
  }
1055
 
1056
+ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
1057
  const int32_t ikv = map_layer_ids.at(il);
1058
 
1059
  auto * v = layers[ikv].v;
1060
 
1061
+ const uint64_t kv_size = get_size();
1062
+ const uint64_t n_embd_v_gqa = v->ne[0];
1063
+
1064
+ // [TAG_V_CACHE_VARIABLE]
1065
+ assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
1066
+
1067
+ const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
1068
+
1069
  if (!v_trans) {
1070
  // note: v->nb[1] <= v->nb[2]
1071
+ return ggml_view_4d(ctx, v,
1072
+ hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
1073
+ ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
1074
+ ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
1075
+ ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
1076
+ ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
1077
  }
1078
 
1079
  // note: v->nb[1] > v->nb[2]
1080
+ return ggml_view_4d(ctx, v,
1081
+ n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
1082
+ ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
1083
+ ggml_row_size(v->type, kv_size), // v->nb[2]
1084
+ ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
1085
+ ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
1086
  }
1087
 
1088
  ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
 
1096
  k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
1097
 
1098
  if (k_idxs && supports_set_rows) {
1099
+ if (k->ne[2] > 1) {
1100
+ k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
1101
+ }
1102
+
1103
  return ggml_set_rows(ctx, k, k_cur, k_idxs);
1104
  }
1105
 
1106
  // TODO: fallback to old ggml_cpy() method for backwards compatibility
1107
  // will be removed when ggml_set_rows() is adopted by all backends
1108
 
1109
+ GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
1110
+
1111
  ggml_tensor * k_view = ggml_view_1d(ctx, k,
1112
  n_tokens*n_embd_k_gqa,
1113
  ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
 
1120
 
1121
  auto * v = layers[ikv].v;
1122
 
1123
+ const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
1124
+ const int64_t n_tokens = v_cur->ne[2];
1125
 
1126
  v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
1127
 
1128
  if (v_idxs && supports_set_rows) {
1129
  if (!v_trans) {
1130
+ if (v->ne[2] > 1) {
1131
+ v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
1132
+ }
1133
+
1134
  return ggml_set_rows(ctx, v, v_cur, v_idxs);
1135
  }
1136
 
1137
+ // [TAG_V_CACHE_VARIABLE]
1138
+ if (n_embd_v_gqa < v->ne[0]) {
1139
+ v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
1140
+ }
1141
 
1142
+ // the row becomes a single element
1143
+ ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
1144
 
1145
+ v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
 
 
 
1146
 
 
 
 
 
1147
  return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
1148
  }
1149
 
1150
  // TODO: fallback to old ggml_cpy() method for backwards compatibility
1151
  // will be removed when ggml_set_rows() is adopted by all backends
1152
 
1153
+ GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
1154
+
1155
  ggml_tensor * v_view = nullptr;
1156
 
1157
  if (!v_trans) {
 
1182
  ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
1183
  const uint32_t n_tokens = ubatch.n_tokens;
1184
 
1185
+ ggml_tensor * v_idxs;
1186
+
1187
+ if (!v_trans) {
1188
+ v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
1189
+ } else {
1190
+ v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
1191
+ }
1192
 
1193
  ggml_set_input(v_idxs);
1194
 
 
1201
  }
1202
 
1203
  const uint32_t n_tokens = ubatch->n_tokens;
1204
+ GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
1205
 
1206
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1207
  int64_t * data = (int64_t *) dst->data;
1208
 
1209
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
1210
+ const int64_t offs = sinfo.strm[s]*get_size();
1211
+
1212
+ for (uint32_t i = 0; i < sinfo.size(); ++i) {
1213
+ data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
1214
+ }
1215
  }
1216
  }
1217
 
 
1221
  }
1222
 
1223
  const uint32_t n_tokens = ubatch->n_tokens;
1224
+ GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
1225
 
1226
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1227
  int64_t * data = (int64_t *) dst->data;
1228
 
1229
+ if (!v_trans) {
1230
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
1231
+ const int64_t offs = sinfo.strm[s]*get_size();
1232
+
1233
+ for (uint32_t i = 0; i < sinfo.size(); ++i) {
1234
+ data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
1235
+ }
1236
+ }
1237
+ } else {
1238
+ // note: the V cache is transposed when not using flash attention
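  // illustration of the destination index computed below (hypothetical kv_size = 8, n_embd_v_gqa = 4):
  //   element j of cell idx in stream strm lands at strm*kv_size*n_embd_v_gqa + j*kv_size + idx
  //   e.g. strm = 1, idx = 5, j = 2  ->  1*8*4 + 2*8 + 5 = 53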
1239
+ const int64_t kv_size = get_size();
1240
+
1241
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
1242
+
1243
+ for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
1244
+ const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa;
1245
+
1246
+ for (uint32_t i = 0; i < sinfo.size(); ++i) {
1247
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1248
+ data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
1249
+ }
1250
+ }
1251
+ }
1252
+ }
1253
+ }
1254
+
1255
+ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
1256
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1257
+
1258
+ int32_t * data = (int32_t *) dst->data;
1259
+
1260
+ for (uint32_t s = 0; s < n_stream; ++s) {
1261
+ const auto & cells = v_cells[s];
1262
+
1263
+ for (uint32_t i = 0; i < cells.size(); ++i) {
1264
+ data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
1265
+ }
1266
  }
1267
  }
1268
 
 
1272
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1273
  float * data = (float *) dst->data;
1274
 
1275
+ const int64_t n_kv = dst->ne[0];
1276
+ const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
1277
+
1278
+ GGML_ASSERT(n_tokens%n_stream == 0);
1279
+
1280
+ // n_tps == n_tokens_per_stream
1281
+ const int64_t n_tps = n_tokens/n_stream;
1282
+ const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
1283
+
1284
+ std::fill(data, data + ggml_nelements(dst), -INFINITY);
1285
 
1286
  // Use only the previous KV cells of the correct sequence for each token of the ubatch.
1287
  // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
 
1295
  // xxxxx-----
1296
  // xxxxx-----
1297
  // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
1298
+ // TODO: optimize this section
1299
  for (uint32_t h = 0; h < 1; ++h) {
1300
+ for (uint32_t s = 0; s < n_stream; ++s) {
1301
+ for (uint32_t ii = 0; ii < n_tps; ++ii) {
1302
+ const uint32_t i = s*n_tps + ii;
1303
 
1304
+ const llama_seq_id seq_id = ubatch->seq_id[i][0];
1305
 
1306
+ const auto & cells = v_cells[seq_to_stream[seq_id]];
 
1307
 
1308
+ const llama_pos p1 = ubatch->pos[i];
1309
 
1310
+ const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
1311
+
1312
+ for (uint32_t j = 0; j < n_kv; ++j) {
1313
+ if (cells.is_empty(j)) {
1314
+ continue;
1315
+ }
1316
 
1317
  // mask the token if not the same sequence
1318
+ if (!cells.seq_has(j, seq_id)) {
1319
+ continue;
1320
+ }
1321
+
1322
+ const llama_pos p0 = cells.pos_get(j);
1323
 
1324
  // mask future tokens
1325
+ if (causal_attn && p0 > p1) {
1326
+ continue;
1327
+ }
1328
 
1329
  // apply SWA if any
1330
+ if (is_masked_swa(p0, p1)) {
1331
+ continue;
 
 
1332
  }
 
 
 
 
 
 
 
 
 
1333
 
1334
+ data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
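  // entries skipped above keep the -INFINITY from the std::fill and are therefore masked out;
  // with ALiBi the kept entries carry the -|p0 - p1| bias instead of 0.0f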
 
 
 
 
1335
  }
1336
  }
1337
  }
1338
  }
1339
  }
1340
 
 
 
 
 
 
 
 
 
 
 
1341
  void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1342
  const int64_t n_tokens = ubatch->n_tokens;
1343
 
1344
+ GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
1345
+ const auto & cells = v_cells[0];
1346
+
1347
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
1348
+ GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
1349
 
1350
  int32_t * data = (int32_t *) dst->data;
1351
 
 
1450
 
1451
  void set_input(const llama_ubatch * ubatch) override;
1452
 
1453
+ ggml_tensor * k_shift; // I32 [kv_size*n_stream]
1454
 
1455
  const llama_kv_cache_unified * kv_self;
1456
  };
 
1463
  }
1464
  }
1465
 
1466
+ ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
1467
+ auto * ctx = res->get_ctx();
1468
+ auto * gf = res->get_gf();
 
 
1469
 
1470
  const auto & n_embd_head_k = hparams.n_embd_head_k;
1471
  //const auto & n_embd_head_v = hparams.n_embd_head_v;
1472
 
1473
  auto inp = std::make_unique<llm_graph_input_k_shift>(this);
1474
 
1475
+ inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
1476
  ggml_set_input(inp->k_shift);
1477
 
1478
+ const auto & cparams = lctx->get_cparams();
1479
+
1480
  for (const auto & layer : layers) {
1481
  const uint32_t il = layer.il;
1482
 
 
1490
 
1491
  ggml_tensor * k =
1492
  ggml_view_3d(ctx, layer.k,
1493
+ n_embd_head_k, n_head_kv, get_size()*n_stream,
1494
  ggml_row_size(layer.k->type, n_embd_head_k),
1495
  ggml_row_size(layer.k->type, n_embd_k_gqa),
1496
  0);
 
1502
 
1503
  res->add_input(std::move(inp));
1504
 
1505
+ return gf;
1506
  }
1507
 
1508
+ ggml_cgraph * llama_kv_cache_unified::build_graph_defrag(
1509
+ llm_graph_result * res,
1510
+ llama_context * lctx,
1511
+ const defrag_info & dinfo) const {
1512
+ auto * ctx = res->get_ctx();
1513
+ auto * gf = res->get_gf();
1514
+
1515
+ GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
1516
+
1517
+ const auto & cells = v_cells[0];
1518
 
1519
  const auto & ids = dinfo.ids;
1520
 
1521
+ const auto & cparams = lctx->get_cparams();
1522
+
1523
  #if 0
1524
  // CPU defrag
1525
  //
 
1656
  //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
1657
  #endif
1658
 
1659
+ return gf;
1660
  }
1661
 
1662
  llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
1663
+ GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
1664
+
1665
+ const auto & cells = v_cells[0];
1666
+
1667
  const uint32_t n_layer = layers.size();
1668
 
1669
  const uint32_t n_kv = cells.used_max_p1();
 
1809
  }
1810
 
1811
  void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
1812
+ io.write(&n_stream, sizeof(n_stream));
 
1813
 
1814
+ for (uint32_t s = 0; s < n_stream; ++s) {
1815
+ cell_ranges_t cr { s, {} };
 
1816
 
1817
+ uint32_t cell_count = 0;
1818
+
1819
+ const auto & cells = v_cells[s];
1820
+
1821
+ // Count the number of cells with the specified seq_id
1822
+ // Find all the ranges of cells with this seq id (or all, when -1)
1823
+ uint32_t cell_range_begin = cells.size();
1824
+
1825
+ for (uint32_t i = 0; i < cells.size(); ++i) {
1826
+ if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
1827
+ ++cell_count;
1828
+ if (cell_range_begin == cells.size()) {
1829
+ cell_range_begin = i;
1830
+ }
1831
+ } else {
1832
+ if (cell_range_begin != cells.size()) {
1833
+ cr.data.emplace_back(cell_range_begin, i);
1834
+ cell_range_begin = cells.size();
1835
+ }
1836
  }
1837
  }
 
1838
 
1839
+ if (cell_range_begin != cells.size()) {
1840
+ cr.data.emplace_back(cell_range_begin, cells.size());
1841
+ }
1842
 
1843
+ // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
1844
+ uint32_t cell_count_check = 0;
1845
+ for (const auto & range : cr.data) {
1846
+ cell_count_check += range.second - range.first;
1847
+ }
1848
+ GGML_ASSERT(cell_count == cell_count_check);
1849
 
1850
+ io.write(&cell_count, sizeof(cell_count));
1851
 
1852
+ // skip empty streams
1853
+ if (cell_count == 0) {
1854
+ continue;
1855
+ }
1856
+
1857
+ state_write_meta(io, cr, seq_id);
1858
+ state_write_data(io, cr);
1859
+ }
1860
  }
1861
 
1862
  void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
1863
+ GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
 
1864
 
1865
+ uint32_t n_stream_cur;
1866
+ io.read_to(&n_stream_cur, sizeof(n_stream_cur));
1867
+ if (n_stream_cur != n_stream) {
1868
+ throw std::runtime_error("n_stream mismatch");
1869
+ }
1870
+
1871
+ for (uint32_t s = 0; s < n_stream; ++s) {
1872
+ uint32_t cell_count;
1873
+ io.read_to(&cell_count, sizeof(cell_count));
1874
+
1875
+ if (cell_count == 0) {
1876
+ continue;
1877
+ }
1878
 
1879
+ const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
1880
+
1881
+ bool res = true;
1882
+ res = res && state_read_meta(io, strm, cell_count, seq_id);
1883
+ res = res && state_read_data(io, strm, cell_count);
1884
+
1885
+ if (!res) {
1886
+ if (seq_id == -1) {
1887
+ clear(true);
1888
+ } else {
1889
+ seq_rm(seq_id, -1, -1);
1890
+ }
1891
+ throw std::runtime_error("failed to restore kv cache");
1892
  }
 
1893
  }
1894
  }
1895
 
1896
+ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
1897
+ const auto & cells = v_cells[cr.strm];
1898
+
1899
+ for (const auto & range : cr.data) {
1900
  for (uint32_t i = range.first; i < range.second; ++i) {
1901
  std::vector<llama_seq_id> seq_ids;
1902
 
 
1921
  }
1922
  }
1923
 
1924
+ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
1925
+ const auto & cells = v_cells[cr.strm];
1926
+
1927
  const uint32_t v_trans = this->v_trans ? 1 : 0;
1928
  const uint32_t n_layer = layers.size();
1929
 
 
1939
 
1940
  const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
1941
 
1942
+ auto * k = layer.k_stream[cr.strm];
1943
+
1944
  // Write key type
1945
+ const int32_t k_type_i = (int32_t) k->type;
1946
  io.write(&k_type_i, sizeof(k_type_i));
1947
 
1948
  // Write row size of key
1949
+ const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
1950
  io.write(&k_size_row, sizeof(k_size_row));
1951
 
1952
  // Read each range of cells of k_size length each into tmp_buf and write out
1953
+ for (const auto & range : cr.data) {
1954
  const size_t range_size = range.second - range.first;
1955
  const size_t buf_size = range_size * k_size_row;
1956
+ io.write_tensor(k, range.first * k_size_row, buf_size);
1957
  }
1958
  }
1959
 
 
1963
 
1964
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1965
 
1966
+ auto * v = layer.v_stream[cr.strm];
1967
+
1968
  // Write value type
1969
+ const int32_t v_type_i = (int32_t) v->type;
1970
  io.write(&v_type_i, sizeof(v_type_i));
1971
 
1972
  // Write row size of value
1973
+ const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
1974
  io.write(&v_size_row, sizeof(v_size_row));
1975
 
1976
  // Read each range of cells of v_size length each into tmp_buf and write out
1977
+ for (const auto & range : cr.data) {
1978
  const size_t range_size = range.second - range.first;
1979
  const size_t buf_size = range_size * v_size_row;
1980
+ io.write_tensor(v, range.first * v_size_row, buf_size);
1981
  }
1982
  }
1983
  } else {
 
1989
 
1990
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
1991
 
1992
+ auto * v = layer.v_stream[cr.strm];
1993
+
1994
  // Write value type
1995
+ const int32_t v_type_i = (int32_t) v->type;
1996
  io.write(&v_type_i, sizeof(v_type_i));
1997
 
1998
  // Write element size
1999
+ const uint32_t v_size_el = ggml_type_size(v->type);
2000
  io.write(&v_size_el, sizeof(v_size_el));
2001
 
2002
  // Write GQA embedding size
 
2005
  // For each row, we get the element values of each cell
2006
  for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
2007
  // Read each range of cells of v_size_el length each into tmp_buf and write out
2008
+ for (const auto & range : cr.data) {
2009
  const size_t range_size = range.second - range.first;
2010
  const size_t src_offset = (range.first + j * kv_size) * v_size_el;
2011
  const size_t buf_size = range_size * v_size_el;
2012
+ io.write_tensor(v, src_offset, buf_size);
2013
  }
2014
  }
2015
  }
2016
  }
2017
  }
2018
 
2019
+ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
2020
+ auto & cells = v_cells[strm];
2021
+ auto & head = v_heads[strm];
2022
+
2023
  if (dest_seq_id != -1) {
2024
  // single sequence
 
2025
  seq_rm(dest_seq_id, -1, -1);
2026
 
2027
  llama_batch_allocr balloc(hparams.n_pos_per_embd());
2028
 
2029
  llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
2030
 
2031
+ ubatch.seq_id_unq[0] = dest_seq_id;
2032
+
2033
  for (uint32_t i = 0; i < cell_count; ++i) {
2034
  llama_pos pos;
2035
  uint32_t n_seq_id;
 
2066
  // keep the head at the old position because we will read the KV data into it in state_read_data()
2067
  head = head_cur;
2068
 
2069
+ LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
2070
+
2071
  // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
2072
  // Assume that this is one contiguous block of cells
2073
  GGML_ASSERT(head_cur + cell_count <= cells.size());
 
2113
  return true;
2114
  }
2115
 
2116
+ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
2117
+ auto & cells = v_cells[strm];
2118
+ auto & head = v_heads[strm];
2119
+
2120
  uint32_t v_trans;
2121
  uint32_t n_layer;
2122
 
 
2144
 
2145
  const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
2146
 
2147
+ auto * k = layer.k_stream[strm];
2148
+
2149
  // Read type of key
2150
  int32_t k_type_i_ref;
2151
  io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
2152
+ const int32_t k_type_i = (int32_t) k->type;
2153
  if (k_type_i != k_type_i_ref) {
2154
  LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
2155
  return false;
 
2158
  // Read row size of key
2159
  uint64_t k_size_row_ref;
2160
  io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
2161
+ const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
2162
  if (k_size_row != k_size_row_ref) {
2163
  LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
2164
  return false;
 
2166
 
2167
  if (cell_count) {
2168
  // Read and set the keys for the whole cell range
2169
+ ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
2170
  }
2171
  }
2172
 
 
2176
 
2177
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
2178
 
2179
+ auto * v = layer.v_stream[strm];
2180
+
2181
  // Read type of value
2182
  int32_t v_type_i_ref;
2183
  io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
2184
+ const int32_t v_type_i = (int32_t) v->type;
2185
  if (v_type_i != v_type_i_ref) {
2186
  LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
2187
  return false;
 
2190
  // Read row size of value
2191
  uint64_t v_size_row_ref;
2192
  io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
2193
+ const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
2194
  if (v_size_row != v_size_row_ref) {
2195
  LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
2196
  return false;
 
2198
 
2199
  if (cell_count) {
2200
  // Read and set the values for the whole cell range
2201
+ ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
2202
  }
2203
  }
2204
  } else {
 
2208
 
2209
  const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
2210
 
2211
+ auto * v = layer.v_stream[strm];
2212
+
2213
  // Read type of value
2214
  int32_t v_type_i_ref;
2215
  io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
2216
+ const int32_t v_type_i = (int32_t) v->type;
2217
  if (v_type_i != v_type_i_ref) {
2218
  LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
2219
  return false;
 
2222
  // Read element size of value
2223
  uint32_t v_size_el_ref;
2224
  io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
2225
+ const size_t v_size_el = ggml_type_size(v->type);
2226
  if (v_size_el != v_size_el_ref) {
2227
  LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
2228
  return false;
 
2240
  // For each row in the transposed matrix, read the values for the whole cell range
2241
  for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
2242
  const size_t dst_offset = (head + j * cells.size()) * v_size_el;
2243
+ ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
2244
  }
2245
  }
2246
  }
 
2259
  llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
2260
  n_kv = kv->get_size();
2261
 
2262
+ const uint32_t n_stream = kv->get_n_stream();
2263
+
2264
  // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
2265
  sinfos.resize(1);
2266
+ sinfos[0].s0 = 0;
2267
+ sinfos[0].s1 = n_stream - 1;
2268
+ sinfos[0].idxs.resize(n_stream);
2269
+ for (uint32_t s = 0; s < n_stream; ++s) {
2270
+ sinfos[0].strm.push_back(s);
2271
+ sinfos[0].idxs[s].resize(1, 0);
2272
+ }
2273
  }
2274
 
2275
  llama_kv_cache_unified_context::llama_kv_cache_unified_context(
2276
  llama_kv_cache_unified * kv,
2277
  llama_context * lctx,
2278
  bool do_shift,
2279
+ defrag_info dinfo,
2280
+ stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) {
2281
+ if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) {
2282
  status = LLAMA_MEMORY_STATUS_NO_UPDATE;
2283
  }
2284
  }
 
2306
 
2307
  // no ubatches -> this is a KV cache update
2308
  if (ubatches.empty()) {
2309
+ kv->update(lctx, do_shift, dinfo, sc_info);
2310
 
2311
  return true;
2312
  }
 
2332
  return n_kv;
2333
  }
2334
 
2335
+ bool llama_kv_cache_unified_context::get_supports_set_rows() const {
2336
+ return kv->get_supports_set_rows();
2337
+ }
2338
+
2339
  ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
2340
+ return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
2341
  }
2342
 
2343
  ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
2344
+ return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
2345
  }
2346
 
2347
  ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
examples/talk-llama/llama-kv-cache-unified.h CHANGED
@@ -35,16 +35,50 @@ public:
35
  std::vector<uint32_t> ids;
36
  };
37
 
 
 
 
 
 
 
 
 
 
 
38
  // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
39
  // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
40
  struct slot_info {
41
  // data for ggml_set_rows
42
  using idx_vec_t = std::vector<uint32_t>;
43
 
44
- idx_vec_t idxs;
 
 
 
 
 
45
 
46
  uint32_t head() const {
47
- return idxs.at(0);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  }
49
 
50
  bool empty() const {
@@ -54,9 +88,6 @@ public:
54
  void clear() {
55
  idxs.clear();
56
  }
57
-
58
- // TODO: implement
59
- //std::vector<idx_vec_t> seq_idxs;
60
  };
61
 
62
  using slot_info_vec_t = std::vector<slot_info>;
@@ -68,6 +99,7 @@ public:
68
  ggml_type type_v,
69
  bool v_trans,
70
  bool offload,
 
71
  uint32_t kv_size,
72
  uint32_t n_seq_max,
73
  uint32_t n_pad,
@@ -111,7 +143,8 @@ public:
111
  // llama_kv_cache_unified specific API
112
  //
113
 
114
- uint32_t get_size() const;
 
115
 
116
  bool get_has_shift() const;
117
 
@@ -121,9 +154,12 @@ public:
121
 
122
  uint32_t get_n_kv() const;
123
 
 
 
 
124
  // get views of the current state of the cache
125
- ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
126
- ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
127
 
128
  // store k_cur and v_cur in the cache based on the provided head location
129
  ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
@@ -137,7 +173,7 @@ public:
137
  // return empty vector on failure
138
  slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
139
 
140
- bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
141
 
142
  // find a slot of kv cells that can hold the ubatch
143
  // if cont == true, then the slot must be continuous
@@ -157,8 +193,9 @@ public:
157
  void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
158
  void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
159
 
 
 
160
  void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
161
- void set_input_k_shift (ggml_tensor * dst) const;
162
  void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
163
 
164
  private:
@@ -172,15 +209,15 @@ private:
172
 
173
  ggml_tensor * k;
174
  ggml_tensor * v;
 
 
 
175
  };
176
 
177
  bool v_trans = true; // the value tensor is transposed
178
 
179
- // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
180
- // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
181
- uint32_t head = 0;
182
-
183
  const uint32_t n_seq_max = 1;
 
184
 
185
  // required padding
186
  const uint32_t n_pad = 1;
@@ -193,14 +230,24 @@ private:
193
 
194
  // env: LLAMA_SET_ROWS (temporary)
195
  // ref: https://github.com/ggml-org/llama.cpp/pull/14285
196
- int supports_set_rows = false;
197
 
198
  const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
199
 
200
  std::vector<ggml_context_ptr> ctxs;
201
  std::vector<ggml_backend_buffer_ptr> bufs;
202
 
203
- llama_kv_cells_unified cells;
 
 
 
 
 
 
 
 
 
 
204
 
205
  std::vector<kv_layer> layers;
206
 
@@ -226,29 +273,34 @@ private:
226
  float freq_base,
227
  float freq_scale) const;
228
 
229
- llm_graph_result_ptr build_graph_shift(
230
- const llama_cparams & cparams,
231
- ggml_context * ctx,
232
- ggml_cgraph * gf) const;
233
 
234
- llm_graph_result_ptr build_graph_defrag(
235
- const llama_cparams & cparams,
236
- ggml_context * ctx,
237
- ggml_cgraph * gf,
238
  const defrag_info & dinfo) const;
239
 
240
- void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
241
- void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
242
 
243
- bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
244
- bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 
 
 
 
 
 
245
  };
246
 
247
  class llama_kv_cache_unified_context : public llama_memory_context_i {
248
  public:
249
  // some shorthands
250
- using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
251
- using defrag_info = llama_kv_cache_unified::defrag_info;
 
252
 
253
  // used for errors
254
  llama_kv_cache_unified_context(llama_memory_status status);
@@ -262,7 +314,8 @@ public:
262
  llama_kv_cache_unified * kv,
263
  llama_context * lctx,
264
  bool do_shift,
265
- defrag_info dinfo);
 
266
 
267
  // used to create a batch processing context from a batch
268
  llama_kv_cache_unified_context(
@@ -288,6 +341,9 @@ public:
288
 
289
  uint32_t get_n_kv() const;
290
 
 
 
 
291
  // get views of the current state of the cache
292
  ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
293
  ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
@@ -320,6 +376,8 @@ private:
320
 
321
  defrag_info dinfo;
322
 
 
 
323
  //
324
  // batch processing context
325
  //
 
35
  std::vector<uint32_t> ids;
36
  };
37
 
38
+ struct stream_copy_info {
39
+ bool empty() const {
40
+ assert(ssrc.size() == sdst.size());
41
+ return ssrc.empty();
42
+ }
43
+
44
+ std::vector<uint32_t> ssrc;
45
+ std::vector<uint32_t> sdst;
46
+ };
47
+
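  // illustration (hypothetical sequence ids, one stream per sequence):
  //   seq_cp(0, 2, -1, -1) enqueues { ssrc = { 0 }, sdst = { 2 } };
  //   the per-layer K/V buffer copies are then applied on the next update()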
48
  // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
49
  // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
50
  struct slot_info {
51
  // data for ggml_set_rows
52
  using idx_vec_t = std::vector<uint32_t>;
53
 
54
+ // number of streams: ns = s1 - s0 + 1
55
+ llama_seq_id s0;
56
+ llama_seq_id s1;
57
+
58
+ std::vector<llama_seq_id> strm; // [ns]
59
+ std::vector<idx_vec_t> idxs; // [ns]
60
 
61
  uint32_t head() const {
62
+ GGML_ASSERT(idxs.size() == 1);
63
+ GGML_ASSERT(!idxs[0].empty());
64
+
65
+ return idxs[0][0];
66
+ }
67
+
68
+ void resize(size_t n) {
69
+ strm.resize(n);
70
+ idxs.resize(n);
71
+ }
72
+
73
+ size_t size() const {
74
+ GGML_ASSERT(idxs.size() == strm.size());
75
+ GGML_ASSERT(!idxs.empty());
76
+
77
+ return idxs[0].size();
78
+ }
79
+
80
+ size_t n_stream() const {
81
+ return strm.size();
82
  }
83
 
84
  bool empty() const {
 
88
  void clear() {
89
  idxs.clear();
90
  }
 
 
 
91
  };
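  // illustration (hypothetical cell indices, 3-token ubatch across streams 1 and 2):
  //   slot_info { s0 = 1, s1 = 2, strm = { 1, 2 }, idxs = { { 4, 5, 6 }, { 9, 10, 11 } } }
  //   -> n_stream() == 2, size() == 3, token i of stream s goes to cell idxs[s][i]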
92
 
93
  using slot_info_vec_t = std::vector<slot_info>;
 
99
  ggml_type type_v,
100
  bool v_trans,
101
  bool offload,
102
+ bool unified,
103
  uint32_t kv_size,
104
  uint32_t n_seq_max,
105
  uint32_t n_pad,
 
143
  // llama_kv_cache_unified specific API
144
  //
145
 
146
+ uint32_t get_size() const;
147
+ uint32_t get_n_stream() const;
148
 
149
  bool get_has_shift() const;
150
 
 
154
 
155
  uint32_t get_n_kv() const;
156
 
157
+ // TODO: temporary
158
+ bool get_supports_set_rows() const;
159
+
160
  // get views of the current state of the cache
161
+ ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
162
+ ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
163
 
164
  // store k_cur and v_cur in the cache based on the provided head location
165
  ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
 
173
  // return empty vector on failure
174
  slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
175
 
176
+ bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
177
 
178
  // find a slot of kv cells that can hold the ubatch
179
  // if cont == true, then the slot must be continuous
 
193
  void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
194
  void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
195
 
196
+ void set_input_k_shift(ggml_tensor * dst) const;
197
+
198
  void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
 
199
  void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
200
 
201
  private:
 
209
 
210
  ggml_tensor * k;
211
  ggml_tensor * v;
212
+
213
+ std::vector<ggml_tensor *> k_stream;
214
+ std::vector<ggml_tensor *> v_stream;
215
  };
216
 
217
  bool v_trans = true; // the value tensor is transposed
218
 
 
 
 
 
219
  const uint32_t n_seq_max = 1;
220
+ const uint32_t n_stream = 1;
221
 
222
  // required padding
223
  const uint32_t n_pad = 1;
 
230
 
231
  // env: LLAMA_SET_ROWS (temporary)
232
  // ref: https://github.com/ggml-org/llama.cpp/pull/14285
233
+ bool supports_set_rows = false;
234
 
235
  const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
236
 
237
  std::vector<ggml_context_ptr> ctxs;
238
  std::vector<ggml_backend_buffer_ptr> bufs;
239
 
240
+ // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
241
+ // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
242
+ std::vector<uint32_t> v_heads;
243
+
244
+ std::vector<llama_kv_cells_unified> v_cells;
245
+
246
+ // maps from a sequence id to a stream id
247
+ std::vector<uint32_t> seq_to_stream;
248
+
249
+ // pending stream copies that will be applied during the next update
250
+ stream_copy_info sc_info;
251
 
252
  std::vector<kv_layer> layers;
253
 
 
273
  float freq_base,
274
  float freq_scale) const;
275
 
276
+ ggml_cgraph * build_graph_shift(
277
+ llm_graph_result * res,
278
+ llama_context * lctx) const;
 
279
 
280
+ ggml_cgraph * build_graph_defrag(
281
+ llm_graph_result * res,
282
+ llama_context * lctx,
 
283
  const defrag_info & dinfo) const;
284
 
285
+ struct cell_ranges_t {
286
+ uint32_t strm;
287
 
288
+ std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
289
+ };
290
+
291
+ void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
292
+ void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
293
+
294
+ bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
295
+ bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
296
  };
297
 
298
  class llama_kv_cache_unified_context : public llama_memory_context_i {
299
  public:
300
  // some shorthands
301
+ using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
302
+ using defrag_info = llama_kv_cache_unified::defrag_info;
303
+ using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
304
 
305
  // used for errors
306
  llama_kv_cache_unified_context(llama_memory_status status);
 
314
  llama_kv_cache_unified * kv,
315
  llama_context * lctx,
316
  bool do_shift,
317
+ defrag_info dinfo,
318
+ stream_copy_info sc_info);
319
 
320
  // used to create a batch processing context from a batch
321
  llama_kv_cache_unified_context(
 
341
 
342
  uint32_t get_n_kv() const;
343
 
344
+ // TODO: temporary
345
+ bool get_supports_set_rows() const;
346
+
347
  // get views of the current state of the cache
348
  ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
349
  ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
 
376
 
377
  defrag_info dinfo;
378
 
379
+ stream_copy_info sc_info;
380
+
381
  //
382
  // batch processing context
383
  //
examples/talk-llama/llama-memory-hybrid.cpp CHANGED
@@ -38,6 +38,7 @@ llama_memory_hybrid::llama_memory_hybrid(
38
  type_v,
39
  v_trans,
40
  offload,
 
41
  kv_size,
42
  n_seq_max,
43
  n_pad,
 
38
  type_v,
39
  v_trans,
40
  offload,
41
+ 1,
42
  kv_size,
43
  n_seq_max,
44
  n_pad,
examples/talk-llama/llama-memory-recurrent.cpp CHANGED
@@ -446,7 +446,7 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
446
  // A slot should always be contiguous.
447
 
448
  // can only process batches with an equal number of new tokens in each sequence
449
- GGML_ASSERT(ubatch.equal_seqs);
450
 
451
  int32_t min = size - 1;
452
  int32_t max = 0;
@@ -768,6 +768,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
768
  // Iterate and write all the keys first, each row is a cell
769
  // Get whole range at a time
770
  for (uint32_t il = 0; il < n_layer; ++il) {
 
 
771
 
772
  // Write key type
773
  const int32_t r_type_i = (int32_t)r_l[il]->type;
@@ -787,6 +789,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
787
 
788
  if (!s_trans) {
789
  for (uint32_t il = 0; il < n_layer; ++il) {
 
 
790
 
791
  // Write value type
792
  const int32_t s_type_i = (int32_t)s_l[il]->type;
@@ -807,6 +811,9 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
807
  // When v is transposed, we also need the element size and get the element ranges from each row
808
  const uint32_t mem_size = size;
809
  for (uint32_t il = 0; il < n_layer; ++il) {
 
 
 
810
  const uint32_t n_embd_s = hparams.n_embd_s();
811
 
812
  // Write value type
@@ -951,6 +958,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
951
 
952
  // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
953
  for (uint32_t il = 0; il < n_layer; ++il) {
 
 
954
 
955
  // Read type of key
956
  int32_t r_type_i_ref;
@@ -978,11 +987,14 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
978
 
979
  if (!s_trans) {
980
  for (uint32_t il = 0; il < n_layer; ++il) {
 
 
981
 
982
  // Read type of value
983
  int32_t s_type_i_ref;
984
  io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
985
  const int32_t s_type_i = (int32_t)s_l[il]->type;
 
986
  if (s_type_i != s_type_i_ref) {
987
  LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
988
  return false;
@@ -1005,6 +1017,9 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
1005
  } else {
1006
  // For each layer, read the values for each cell (transposed)
1007
  for (uint32_t il = 0; il < n_layer; ++il) {
 
 
 
1008
  const uint32_t n_embd_s = hparams.n_embd_s();
1009
 
1010
  // Read type of value
 
446
  // A slot should always be contiguous.
447
 
448
  // can only process batches with an equal number of new tokens in each sequence
449
+ GGML_ASSERT(ubatch.equal_seqs());
450
 
451
  int32_t min = size - 1;
452
  int32_t max = 0;
 
768
  // Iterate and write all the keys first, each row is a cell
769
  // Get whole range at a time
770
  for (uint32_t il = 0; il < n_layer; ++il) {
771
+ // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
772
+ if (r_l[il] == nullptr) continue;
773
 
774
  // Write key type
775
  const int32_t r_type_i = (int32_t)r_l[il]->type;
 
789
 
790
  if (!s_trans) {
791
  for (uint32_t il = 0; il < n_layer; ++il) {
792
+ // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
793
+ if (s_l[il] == nullptr) continue;
794
 
795
  // Write value type
796
  const int32_t s_type_i = (int32_t)s_l[il]->type;
 
811
  // When v is transposed, we also need the element size and get the element ranges from each row
812
  const uint32_t mem_size = size;
813
  for (uint32_t il = 0; il < n_layer; ++il) {
814
+ // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
815
+ if (s_l[il] == nullptr) continue;
816
+
817
  const uint32_t n_embd_s = hparams.n_embd_s();
818
 
819
  // Write value type
 
958
 
959
  // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
960
  for (uint32_t il = 0; il < n_layer; ++il) {
961
+ // skip null layers
962
+ if (r_l[il] == nullptr) continue;
963
 
964
  // Read type of key
965
  int32_t r_type_i_ref;
 
987
 
988
  if (!s_trans) {
989
  for (uint32_t il = 0; il < n_layer; ++il) {
990
+ // skip null layers
991
+ if (s_l[il] == nullptr) continue;
992
 
993
  // Read type of value
994
  int32_t s_type_i_ref;
995
  io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
996
  const int32_t s_type_i = (int32_t)s_l[il]->type;
997
+
998
  if (s_type_i != s_type_i_ref) {
999
  LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
1000
  return false;
 
1017
  } else {
1018
  // For each layer, read the values for each cell (transposed)
1019
  for (uint32_t il = 0; il < n_layer; ++il) {
1020
+ // skip null layers
1021
+ if (s_l[il] == nullptr) continue;
1022
+
1023
  const uint32_t n_embd_s = hparams.n_embd_s();
1024
 
1025
  // Read type of value
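A minimal standalone sketch of the pattern introduced above, with made-up names: in hybrid models only the recurrent layers own "r_l"/"s_l" tensors, so the session writer and reader must skip the null entries with the same rule, otherwise the serialized stream goes out of sync.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    // per-layer state tensors; attention-only layers of a hybrid model are null
    const std::vector<const char *> s_l = { "layer0", nullptr, "layer2" };

    // write: emit one record per non-null layer
    std::vector<uint32_t> stream;
    for (uint32_t il = 0; il < s_l.size(); ++il) {
        if (s_l[il] == nullptr) continue; // skip null layers
        stream.push_back(il);
    }

    // read: consume records with the same skip rule, so the indices line up
    size_t cursor = 0;
    for (uint32_t il = 0; il < s_l.size(); ++il) {
        if (s_l[il] == nullptr) continue; // skip null layers
        assert(stream[cursor++] == il);
    }
    assert(cursor == stream.size());
    return 0;
}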
examples/talk-llama/llama-model.cpp CHANGED
The diff for this file is too large to render. See raw diff
 
examples/talk-llama/llama-model.h CHANGED
@@ -99,8 +99,10 @@ enum llm_type {
99
  LLM_TYPE_17B_16E, // llama4 Scout
100
  LLM_TYPE_17B_128E, // llama4 Maverick
101
  LLM_TYPE_A13B,
 
102
  LLM_TYPE_30B_A3B,
103
  LLM_TYPE_235B_A22B,
 
104
  LLM_TYPE_E2B,
105
  LLM_TYPE_E4B,
106
  };
@@ -452,10 +454,7 @@ struct llama_model {
452
  llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
453
 
454
  // TODO: move this to new llm_arch_model_i interface
455
- llm_graph_result_ptr build_graph(
456
- const llm_graph_params & params,
457
- ggml_cgraph * gf,
458
- llm_graph_type type) const;
459
 
460
  private:
461
  struct impl;
 
99
  LLM_TYPE_17B_16E, // llama4 Scout
100
  LLM_TYPE_17B_128E, // llama4 Maverick
101
  LLM_TYPE_A13B,
102
+ LLM_TYPE_21B_A3B, // Ernie MoE small
103
  LLM_TYPE_30B_A3B,
104
  LLM_TYPE_235B_A22B,
105
+ LLM_TYPE_300B_A47B, // Ernie MoE big
106
  LLM_TYPE_E2B,
107
  LLM_TYPE_E4B,
108
  };
 
454
  llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
455
 
456
  // TODO: move this to new llm_arch_model_i interface
457
+ ggml_cgraph * build_graph(const llm_graph_params & params) const;
 
 
 
458
 
459
  private:
460
  struct impl;
examples/talk-llama/llama-quant.cpp CHANGED
@@ -884,8 +884,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
884
  if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
885
  if (qtype != new_type) {
886
  LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
887
- new_type = qtype;
888
- break; // if two or more types are specified for the tensor, first match wins
889
  }
890
  }
891
  }
 
884
  if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
885
  if (qtype != new_type) {
886
  LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
887
+ new_type = qtype; // if two or more types are specified for the same tensor, the last match wins
 
888
  }
889
  }
890
  }
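A minimal standalone sketch of the new override behaviour, with made-up patterns and type names: since the early break is gone, every matching pattern is applied in order and the last match wins, mirroring the loop above.

#include <cstdio>
#include <regex>
#include <string>
#include <utility>
#include <vector>

int main() {
    // hypothetical override list: tensor-name pattern -> quant type name
    const std::vector<std::pair<std::string, std::string>> overrides = {
        { "ffn_.*",     "q4_K" },
        { "ffn_down.*", "q6_K" }, // also matches below -> overrides the previous one
    };

    const std::string tensor_name = "blk.0.ffn_down.weight";

    std::string new_type = "q4_0"; // default pick
    for (const auto & [tname, qtype] : overrides) {
        if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
            if (qtype != new_type) {
                printf("(overriding %s) ", new_type.c_str());
                new_type = qtype; // no break: the last matching pattern wins
            }
        }
    }
    printf("-> %s for %s\n", new_type.c_str(), tensor_name.c_str());
}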
examples/talk-llama/llama-vocab.cpp CHANGED
@@ -11,6 +11,7 @@
11
  #include <cassert>
12
  #include <cctype>
13
  #include <cfloat>
 
14
  #include <cstdarg>
15
  #include <cstring>
16
  #include <forward_list>
@@ -404,6 +405,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
404
  "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
405
  };
406
  break;
 
 
 
 
 
 
 
407
  case LLAMA_VOCAB_PRE_TYPE_SUPERBPE:
408
  regex_exprs = {
409
  "\\p{N}+",
@@ -1196,6 +1204,284 @@ private:
1196
  const llm_tokenizer_rwkv & tokenizer;
1197
  };
1198
 
1199
  //
1200
  // impl
1201
  //
@@ -1499,6 +1785,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1499
  special_unk_id = LLAMA_TOKEN_NULL;
1500
  special_sep_id = LLAMA_TOKEN_NULL;
1501
  special_pad_id = LLAMA_TOKEN_NULL;
1502
  } else {
1503
  throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
1504
  }
@@ -1629,6 +1925,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1629
  } else if (
1630
  tokenizer_pre == "exaone") {
1631
  pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE;
 
 
 
1632
  } else if (
1633
  tokenizer_pre == "chameleon") {
1634
  pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON;
@@ -1665,6 +1964,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
1665
  tokenizer_pre == "hunyuan") {
1666
  pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
1667
  clean_spaces = false;
 
 
 
 
1668
  } else {
1669
  throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
1670
  }
@@ -2145,13 +2448,14 @@ enum llama_vocab_type llama_vocab::impl::get_type() const {
2145
 
2146
  std::string llama_vocab::impl::type_name() const{
2147
  switch (type) {
2148
- case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
2149
- case LLAMA_VOCAB_TYPE_SPM: return "SPM";
2150
- case LLAMA_VOCAB_TYPE_BPE: return "BPE";
2151
- case LLAMA_VOCAB_TYPE_WPM: return "WPM";
2152
- case LLAMA_VOCAB_TYPE_UGM: return "UGM";
2153
- case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
2154
- default: return "unknown";
 
2155
  }
2156
  }
2157
 
@@ -2234,6 +2538,9 @@ void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) {
2234
  case LLAMA_VOCAB_TYPE_RWKV:
2235
  tokenizer = std::make_unique<llm_tokenizer_rwkv>(vocab);
2236
  break;
 
 
 
2237
  default:
2238
  GGML_ABORT("unsupported vocab type");
2239
  }
@@ -2566,6 +2873,23 @@ std::vector<llama_token> llama_vocab::impl::tokenize(
2566
  if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
2567
  std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
2568
 
2569
  #ifdef PRETOKENIZERDEBUG
2570
  LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
2571
  #endif
@@ -2664,6 +2988,24 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t
2664
  memcpy(buf, result.data(), result.size());
2665
  return (int)result.size();
2666
  }
2667
  default:
2668
  GGML_ABORT("fatal error");
2669
  }
@@ -2908,6 +3250,12 @@ llama_token llama_vocab::byte_to_token(uint8_t ch) const {
2908
  case LLAMA_VOCAB_TYPE_BPE: {
2909
  return pimpl->token_to_id.at(unicode_byte_to_utf8(ch));
2910
  }
 
 
 
 
 
 
2911
  default:
2912
  GGML_ABORT("fatal error");
2913
  }
@@ -3009,6 +3357,10 @@ llama_token llama_vocab::token_fim_sep() const {
3009
  return pimpl->special_fim_sep_id;
3010
  }
3011
 
 
 
 
 
3012
  bool llama_vocab::get_add_space_prefix() const {
3013
  return pimpl->add_space_prefix;
3014
  }
@@ -3249,6 +3601,10 @@ llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) {
3249
  return vocab->token_fim_sep();
3250
  }
3251
 
 
 
 
 
3252
  // deprecated
3253
  const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) {
3254
  return llama_vocab_get_text(vocab, token);
@@ -3385,4 +3741,3 @@ int32_t llama_detokenize(
3385
  bool unparse_special) {
3386
  return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
3387
  }
3388
-
 
11
  #include <cassert>
12
  #include <cctype>
13
  #include <cfloat>
14
+ #include <cmath>
15
  #include <cstdarg>
16
  #include <cstring>
17
  #include <forward_list>
 
405
  "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
406
  };
407
  break;
408
+ case LLAMA_VOCAB_PRE_TYPE_KIMI_K2:
409
+ regex_exprs = {
410
+ // K2 trigger pattern - this will activate the custom K2 handler in unicode.cpp
411
+ // The custom handler implements all K2 patterns with proper Han character exclusion
412
+ "\\p{Han}+",
413
+ };
414
+ break;
415
  case LLAMA_VOCAB_PRE_TYPE_SUPERBPE:
416
  regex_exprs = {
417
  "\\p{N}+",
 
1204
  const llm_tokenizer_rwkv & tokenizer;
1205
  };
1206
 
1207
+ struct llm_tokenizer_plamo2 : llm_tokenizer {
1208
+ llm_tokenizer_plamo2(const llama_vocab & vocab) {
1209
+ build(vocab);
1210
+ }
1211
+
1212
+ void build(const llama_vocab & vocab) {
1213
+ // Reset internal structures
1214
+ tokens_.clear();
1215
+ bytes_.assign(256, 0);
1216
+ to_suffix_id_.clear();
1217
+ table_.clear();
1218
+
1219
+ // Build token list and byte mapping
1220
+ std::unordered_map<std::string, float> suffix_to_score;
1221
+ std::unordered_map<std::string, llama_token> token_to_id;
1222
+
1223
+ for (size_t token_id = 0; token_id < vocab.n_tokens(); ++token_id) {
1224
+ const auto & entry = vocab.get_token_data(token_id);
1225
+ tokens_.push_back(entry.text);
1226
+ token_to_id[entry.text] = static_cast<llama_token>(token_id);
1227
+
1228
+ // Handle byte tokens
1229
+ if (vocab.is_byte(token_id)) {
1230
+ if (entry.text.length() == 6 && entry.text.substr(0, 3) == "<0x" && entry.text.back() == '>') {
1231
+ std::string hex_str = entry.text.substr(3, 2);
1232
+ int byte_val = std::stoi(hex_str, nullptr, 16);
1233
+ bytes_[byte_val] = static_cast<llama_token>(token_id);
1234
+ }
1235
+ continue;
1236
+ }
1237
+
1238
+ // Add token and all its suffixes to suffix_to_score
1239
+ suffix_to_score[entry.text] = entry.score;
1240
+
1241
+ // Extract suffixes character by character (UTF-8 aware)
1242
+ std::vector<uint32_t> cpts = unicode_cpts_from_utf8(entry.text);
1243
+ for (size_t i = 1; i < cpts.size(); ++i) {
1244
+ std::string suffix;
1245
+ for (size_t j = i; j < cpts.size(); ++j) {
1246
+ suffix += unicode_cpt_to_utf8(cpts[j]);
1247
+ }
1248
+ if (suffix_to_score.find(suffix) == suffix_to_score.end()) {
1249
+ suffix_to_score[suffix] = std::numeric_limits<float>::quiet_NaN();
1250
+ }
1251
+ }
1252
+ }
1253
+
1254
+ // Check that all byte tokens are set
1255
+ for (int i = 0; i < 256; ++i) {
1256
+ if (bytes_[i] == 0) {
1257
+ throw std::runtime_error("Byte token for <0x" + std::to_string(i) + "> is not set");
1258
+ }
1259
+ }
1260
+
1261
+ // Build suffix list in lexicographical order of reversed strings
1262
+ std::vector<std::string> suffixes;
1263
+ for (const auto & pair : suffix_to_score) {
1264
+ suffixes.push_back(pair.first);
1265
+ }
1266
+ suffixes.push_back(""); // Empty suffix
1267
+
1268
+ std::sort(suffixes.begin(), suffixes.end(), [](const std::string & a, const std::string & b) {
1269
+ std::string rev_a(a.rbegin(), a.rend());
1270
+ std::string rev_b(b.rbegin(), b.rend());
1271
+ return rev_a < rev_b;
1272
+ });
1273
+
1274
+ // Build suffix_to_id and to_suffix_id_
1275
+ std::unordered_map<std::string, int32_t> suffix_to_id;
1276
+ int32_t num_pieces = 0;
1277
+
1278
+ for (const auto & suffix : suffixes) {
1279
+ suffix_to_id[suffix] = num_pieces;
1280
+ if (!suffix.empty()) {
1281
+ std::vector<uint32_t> cpts = unicode_cpts_from_utf8(suffix);
1282
+
1283
+ std::string remaining;
1284
+ for (size_t i = 1; i < cpts.size(); ++i) {
1285
+ remaining += unicode_cpt_to_utf8(cpts[i]);
1286
+ }
1287
+
1288
+ int64_t piece_code = (static_cast<int64_t>(cpts[0]) << 32) | suffix_to_id[remaining];
1289
+ to_suffix_id_[piece_code] = num_pieces;
1290
+
1291
+ // Count number of pieces for this suffix
1292
+ int32_t pieces_for_suffix = 1; // sentinel row
1293
+ for (int32_t piece_length = static_cast<int32_t>(cpts.size()); piece_length > 0; --piece_length) {
1294
+ std::string piece;
1295
+ for (int32_t i = 0; i < piece_length; ++i) {
1296
+ piece += unicode_cpt_to_utf8(cpts[i]);
1297
+ }
1298
+ if (suffix_to_score.find(piece) != suffix_to_score.end()) {
1299
+ pieces_for_suffix++;
1300
+ }
1301
+ }
1302
+ num_pieces += pieces_for_suffix;
1303
+ } else {
1304
+ num_pieces++; // Empty suffix contributes one piece (sentinel row)
1305
+ }
1306
+ }
1307
+
1308
+ // Build flattened table
1309
+ table_.resize(num_pieces, std::vector<int32_t>(4, 0));
1310
+ int32_t table_idx = 0;
1311
+
1312
+ for (const auto & suffix : suffixes) {
1313
+ // Add all prefixes of the suffix to the table (in decreasing order of length)
1314
+ std::vector<uint32_t> cpts = unicode_cpts_from_utf8(suffix);
1315
+ for (int32_t piece_length = static_cast<int32_t>(cpts.size()); piece_length > 0; --piece_length) {
1316
+ std::string piece;
1317
+ for (int32_t i = 0; i < piece_length; ++i) {
1318
+ piece += unicode_cpt_to_utf8(cpts[i]);
1319
+ }
1320
+
1321
+ auto score_it = suffix_to_score.find(piece);
1322
+ if (score_it == suffix_to_score.end()) {
1323
+ continue;
1324
+ }
1325
+
1326
+ table_[table_idx][TABLE_PIECE_LENGTH] = piece_length;
1327
+ auto token_it = token_to_id.find(piece);
1328
+ table_[table_idx][TABLE_TOKEN_ID] = (token_it != token_to_id.end()) ? token_it->second : -1;
1329
+
1330
+ float score = score_it->second;
1331
+ table_[table_idx][TABLE_SCORE] = std::isfinite(score) ?
1332
+ static_cast<int32_t>(std::round(score * 1e4)) : INVALID_SCORE;
1333
+ table_[table_idx][TABLE_PIECE_ID] = suffix_to_id[piece];
1334
+
1335
+ table_idx++;
1336
+ }
1337
+
1338
+ // Add sentinel row
1339
+ table_[table_idx][TABLE_PIECE_LENGTH] = 1;
1340
+ table_[table_idx][TABLE_TOKEN_ID] = -1;
1341
+ table_[table_idx][TABLE_SCORE] = UNKNOWN_SCORE;
1342
+ table_idx++;
1343
+ }
1344
+ }
1345
+
1346
+ std::vector<llama_token> encode(const std::string & text) const {
1347
+ std::vector<uint32_t> unicode_data = unicode_cpts_from_utf8(text);
1348
+ // Skip the first code point if it is a BOM (Byte Order Mark)
1349
+ if (!unicode_data.empty() && unicode_data[0] == 0xFEFF) {
1350
+ unicode_data.erase(unicode_data.begin());
1351
+ }
1352
+
1353
+ if (unicode_data.empty()) {
1354
+ return {};
1355
+ }
1356
+
1357
+ const size_t data_len = unicode_data.size();
1358
+
1359
+ // Initialize scores array (dynamic programming)
1360
+ std::vector<int64_t> scores(data_len + 1, static_cast<int64_t>(1) << 60);
1361
+ scores[data_len] = 0;
1362
+
1363
+ // Path array to track best tokenization
1364
+ std::vector<std::vector<int32_t>> path(data_len + 1, std::vector<int32_t>(3, 0));
1365
+
1366
+ int32_t suffix_id = 0;
1367
+
1368
+ // Process from end to beginning
1369
+ for (int i = static_cast<int>(data_len) - 1; i >= 0; --i) {
1370
+ uint32_t c = unicode_data[i];
1371
+
1372
+ // Find next suffix ID
1373
+ for (size_t p = suffix_id; p < table_.size(); ++p) {
1374
+ int64_t piece_code = (static_cast<int64_t>(c) << 32) | table_[p][TABLE_PIECE_ID];
1375
+ auto it = to_suffix_id_.find(piece_code);
1376
+ suffix_id = (it != to_suffix_id_.end()) ? it->second : 0;
1377
+
1378
+ if (suffix_id > 0 || table_[p][TABLE_SCORE] == UNKNOWN_SCORE) {
1379
+ break;
1380
+ }
1381
+ }
1382
+
1383
+ // Update best path
1384
+ for (size_t p = suffix_id; p < table_.size(); ++p) {
1385
+ int32_t score = table_[p][TABLE_SCORE];
1386
+ if (score > INVALID_SCORE) {
1387
+ int32_t piece_length = table_[p][TABLE_PIECE_LENGTH];
1388
+ int64_t s = scores[i + piece_length] - score;
1389
+
1390
+ if (s < scores[i]) {
1391
+ scores[i] = s;
1392
+ path[i][PATH_TOKEN_LENGTH] = piece_length;
1393
+ path[i][PATH_TOKEN_ID] = table_[p][TABLE_TOKEN_ID];
1394
+ path[i][PATH_NUM_TOKENS] = path[i + piece_length][PATH_NUM_TOKENS] + 1;
1395
+
1396
+ if (score == UNKNOWN_SCORE) {
1397
+ // Add UTF-8 byte count
1398
+ path[i][PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000);
1399
+ }
1400
+ }
1401
+ }
1402
+
1403
+ if (score == UNKNOWN_SCORE) {
1404
+ break;
1405
+ }
1406
+ }
1407
+ }
1408
+
1409
+ // Decode the best path
1410
+ std::vector<llama_token> token_ids;
1411
+ token_ids.reserve(path[0][PATH_NUM_TOKENS]);
1412
+
1413
+ int pos = 0;
1414
+ while (pos < static_cast<int>(data_len)) {
1415
+ if (path[pos][PATH_TOKEN_ID] >= 0) {
1416
+ token_ids.push_back(path[pos][PATH_TOKEN_ID]);
1417
+ } else {
1418
+ // Fall back to byte tokens
1419
+ uint32_t c = unicode_data[pos];
1420
+ int s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000);
1421
+
1422
+ for (int i = 0; i < s; ++i) {
1423
+ uint8_t b;
1424
+ if (s == 1) {
1425
+ b = c;
1426
+ } else {
1427
+ if (i == 0) {
1428
+ b = (0xF00 >> s) & 0xFF;
1429
+ } else {
1430
+ b = 0x80;
1431
+ }
1432
+ }
1433
+ token_ids.push_back(bytes_[b | ((c >> ((s - i - 1) * 6)) & 0x3F)]);
1434
+ }
1435
+ }
1436
+
1437
+ assert(path[pos][PATH_TOKEN_LENGTH] > 0);
1438
+ pos += path[pos][PATH_TOKEN_LENGTH];
1439
+ }
1440
+
1441
+ return token_ids;
1442
+ }
1443
+ private:
1444
+ // Constants for table structure
1445
+ static constexpr int32_t TABLE_PIECE_LENGTH = 0;
1446
+ static constexpr int32_t TABLE_TOKEN_ID = 1;
1447
+ static constexpr int32_t TABLE_SCORE = 2;
1448
+ static constexpr int32_t TABLE_PIECE_ID = 3;
1449
+
1450
+ // Constants for path array
1451
+ static constexpr int32_t PATH_TOKEN_LENGTH = 0;
1452
+ static constexpr int32_t PATH_TOKEN_ID = 1;
1453
+ static constexpr int32_t PATH_NUM_TOKENS = 2;
1454
+
1455
+ // Score constants
1456
+ static constexpr int32_t INVALID_SCORE = -20000000;
1457
+ static constexpr int32_t UNKNOWN_SCORE = -10000000;
1458
+
1459
+ // List of tokens in the vocabulary
1460
+ std::vector<std::string> tokens_;
1461
+
1462
+ // Mapping from byte code point to token ID (for byte fallback)
1463
+ std::vector<llama_token> bytes_;
1464
+
1465
+ // Mapping from piece code to suffix ID
1466
+ std::unordered_map<int64_t, int32_t> to_suffix_id_;
1467
+
1468
+ // Flattened table representing the Trie structure
1469
+ // Each row contains: [piece_length, token_id, score, piece_id]
1470
+ std::vector<std::vector<int32_t>> table_;
1471
+ };
1472
+
1473
+ struct llm_tokenizer_plamo2_session {
1474
+ llm_tokenizer_plamo2_session(const llm_tokenizer_plamo2 & tokenizer) : tokenizer(tokenizer) {}
1475
+
1476
+ void tokenize(const std::string & text, std::vector<llama_token> & output) {
1477
+ std::vector<llama_token> tokens = tokenizer.encode(text);
1478
+ output.insert(output.end(), tokens.begin(), tokens.end());
1479
+ }
1480
+
1481
+ private:
1482
+ const llm_tokenizer_plamo2 & tokenizer;
1483
+ };
1484
+
1485
  //
1486
  // impl
1487
  //
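A minimal standalone sketch of the byte-fallback branch of encode() above, assuming a made-up helper name: a code point with no token of its own is split into its UTF-8 bytes with the same bit tricks, and each byte would then be looked up through its <0xXX> byte token.

#include <cstdint>
#include <cstdio>
#include <vector>

// split a code point into the UTF-8 bytes that the byte-fallback path emits
static std::vector<uint8_t> utf8_bytes_for_fallback(uint32_t c) {
    const int s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000); // byte count
    std::vector<uint8_t> out;
    for (int i = 0; i < s; ++i) {
        uint8_t b;
        if (s == 1) {
            b = (uint8_t) c;             // ASCII: single byte
        } else if (i == 0) {
            b = (0xF00 >> s) & 0xFF;     // leading-byte prefix: 0xC0 / 0xE0 / 0xF0
        } else {
            b = 0x80;                    // continuation-byte prefix
        }
        out.push_back(b | ((c >> ((s - i - 1) * 6)) & 0x3F));
    }
    return out;
}

int main() {
    for (uint8_t b : utf8_bytes_for_fallback(0x3042)) { // U+3042 -> E3 81 82
        printf("<0x%02X> ", b);                         // the <0xXX> byte tokens looked up via bytes_[]
    }
    printf("\n");
}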
 
1785
  special_unk_id = LLAMA_TOKEN_NULL;
1786
  special_sep_id = LLAMA_TOKEN_NULL;
1787
  special_pad_id = LLAMA_TOKEN_NULL;
1788
+ } else if (tokenizer_model == "plamo2") {
1789
+ type = LLAMA_VOCAB_TYPE_PLAMO2;
1790
+
1791
+ // PLaMo-2 default special tokens (these will be overridden by model config)
1792
+ special_bos_id = 1; // <|plamo:bos|>
1793
+ special_eos_id = 2; // <|plamo:eos|>
1794
+ special_unk_id = 0; // <|plamo:unk|>
1795
+ special_sep_id = LLAMA_TOKEN_NULL;
1796
+ special_pad_id = 3; // <|plamo:pad|>
1797
+ special_mask_id = LLAMA_TOKEN_NULL;
1798
  } else {
1799
  throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
1800
  }
 
1925
  } else if (
1926
  tokenizer_pre == "exaone") {
1927
  pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE;
1928
+ } else if (
1929
+ tokenizer_pre == "exaone4") {
1930
+ pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
1931
  } else if (
1932
  tokenizer_pre == "chameleon") {
1933
  pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON;
 
1964
  tokenizer_pre == "hunyuan") {
1965
  pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
1966
  clean_spaces = false;
1967
+ } else if (
1968
+ tokenizer_pre == "kimi-k2") {
1969
+ pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
1970
+ clean_spaces = false;
1971
  } else {
1972
  throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
1973
  }
 
2448
 
2449
  std::string llama_vocab::impl::type_name() const{
2450
  switch (type) {
2451
+ case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
2452
+ case LLAMA_VOCAB_TYPE_SPM: return "SPM";
2453
+ case LLAMA_VOCAB_TYPE_BPE: return "BPE";
2454
+ case LLAMA_VOCAB_TYPE_WPM: return "WPM";
2455
+ case LLAMA_VOCAB_TYPE_UGM: return "UGM";
2456
+ case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
2457
+ case LLAMA_VOCAB_TYPE_PLAMO2: return "PLaMo2";
2458
+ default: return "unknown";
2459
  }
2460
  }
2461
 
 
2538
  case LLAMA_VOCAB_TYPE_RWKV:
2539
  tokenizer = std::make_unique<llm_tokenizer_rwkv>(vocab);
2540
  break;
2541
+ case LLAMA_VOCAB_TYPE_PLAMO2:
2542
+ tokenizer = std::make_unique<llm_tokenizer_plamo2>(vocab);
2543
+ break;
2544
  default:
2545
  GGML_ABORT("unsupported vocab type");
2546
  }
 
2873
  if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
2874
  std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
2875
 
2876
+ #ifdef PRETOKENIZERDEBUG
2877
+ LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
2878
+ #endif
2879
+
2880
+ session.tokenize(text, output);
2881
+ } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
2882
+ output.push_back(fragment.token);
2883
+ }
2884
+ }
2885
+ } break;
2886
+ case LLAMA_VOCAB_TYPE_PLAMO2:
2887
+ {
2888
+ llm_tokenizer_plamo2_session session(*static_cast<const llm_tokenizer_plamo2 *>(tokenizer.get()));
2889
+ for (const auto & fragment : fragment_buffer) {
2890
+ if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
2891
+ std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
2892
+
2893
  #ifdef PRETOKENIZERDEBUG
2894
  LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
2895
  #endif
 
2988
  memcpy(buf, result.data(), result.size());
2989
  return (int)result.size();
2990
  }
2991
+ case LLAMA_VOCAB_TYPE_PLAMO2: {
2992
+ // PLaMo-2 uses token handling similar to BPE/SPM
2993
+ if (vocab.is_byte(token)) {
2994
+ // Handle byte tokens like <0xXX>
2995
+ if (token_text.length() == 6 && token_text.substr(0, 3) == "<0x" && token_text.back() == '>') {
2996
+ int hex_val = std::stoi(token_text.substr(3, 2), nullptr, 16);
2997
+ if (length < 1) {
2998
+ return -1;
2999
+ }
3000
+ buf[0] = static_cast<char>(hex_val);
3001
+ return 1;
3002
+ }
3003
+ }
3004
+
3005
+ // Normal token - just copy the text
3006
+ std::string result = token_text;
3007
+ return _try_copy(result.data(), result.size());
3008
+ }
3009
  default:
3010
  GGML_ABORT("fatal error");
3011
  }
 
3250
  case LLAMA_VOCAB_TYPE_BPE: {
3251
  return pimpl->token_to_id.at(unicode_byte_to_utf8(ch));
3252
  }
3253
+ case LLAMA_VOCAB_TYPE_PLAMO2: {
3254
+ // PLaMo-2 uses byte tokens in format <0xXX>
3255
+ char hex_str[8];
3256
+ snprintf(hex_str, sizeof(hex_str), "<0x%02X>", ch);
3257
+ return pimpl->token_to_id.at(hex_str);
3258
+ }
3259
  default:
3260
  GGML_ABORT("fatal error");
3261
  }
 
3357
  return pimpl->special_fim_sep_id;
3358
  }
3359
 
3360
+ llama_token llama_vocab::token_mask() const {
3361
+ return pimpl->special_mask_id;
3362
+ }
3363
+
3364
  bool llama_vocab::get_add_space_prefix() const {
3365
  return pimpl->add_space_prefix;
3366
  }
 
3601
  return vocab->token_fim_sep();
3602
  }
3603
 
3604
+ llama_token llama_vocab_mask(const struct llama_vocab* vocab) {
3605
+ return vocab->token_mask();
3606
+ }
3607
+
3608
  // deprecated
3609
  const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) {
3610
  return llama_vocab_get_text(vocab, token);
 
3741
  bool unparse_special) {
3742
  return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
3743
  }
 
examples/talk-llama/llama-vocab.h CHANGED
@@ -45,6 +45,7 @@ enum llama_vocab_pre_type {
45
  LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
46
  LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
47
  LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
 
48
  };
49
 
50
  struct LLM_KV;
@@ -100,6 +101,7 @@ struct llama_vocab {
100
  llama_token token_sep() const;
101
  llama_token token_nl () const;
102
  llama_token token_pad() const;
 
103
 
104
  llama_token token_prefix() const;
105
  llama_token token_middle() const;
 
45
  LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
46
  LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
47
  LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
48
+ LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37,
49
  };
50
 
51
  struct LLM_KV;
 
101
  llama_token token_sep() const;
102
  llama_token token_nl () const;
103
  llama_token token_pad() const;
104
+ llama_token token_mask() const;
105
 
106
  llama_token token_prefix() const;
107
  llama_token token_middle() const;
examples/talk-llama/llama.h CHANGED
@@ -71,12 +71,13 @@ extern "C" {
71
  typedef int32_t llama_seq_id;
72
 
73
  enum llama_vocab_type {
74
- LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
75
- LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
76
- LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
77
- LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
78
- LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
79
- LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
 
80
  };
81
 
82
  enum llama_rope_type {
@@ -334,6 +335,9 @@ extern "C" {
334
  bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
335
  // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
336
  // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
 
 
 
337
  };
338
 
339
  // model quantization parameters
@@ -724,7 +728,7 @@ extern "C" {
724
  // - lazily on next llama_decode()
725
  // p0 < 0 : [0, p1]
726
  // p1 < 0 : [p0, inf)
727
- DEPRECATED(void llama_kv_self_seq_div(
728
  struct llama_context * ctx,
729
  llama_seq_id seq_id,
730
  llama_pos p0,
@@ -952,6 +956,7 @@ extern "C" {
952
  // in the order they have appeared in the batch.
953
  // Rows: number of tokens for which llama_batch.logits[i] != 0
954
  // Cols: n_vocab
 
955
  LLAMA_API float * llama_get_logits(struct llama_context * ctx);
956
 
957
  // Logits for the ith token. For positive indices, Equivalent to:
@@ -966,6 +971,7 @@ extern "C" {
966
  // in the order they have appeared in the batch.
967
  // shape: [n_outputs*n_embd]
968
  // Otherwise, returns NULL.
 
969
  LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
970
 
971
  // Get the embeddings for the ith token. For positive indices, Equivalent to:
@@ -1004,6 +1010,7 @@ extern "C" {
1004
  LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
1005
  LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
1006
  LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
 
1007
 
1008
  LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
1009
  LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
@@ -1389,6 +1396,7 @@ extern "C" {
1389
 
1390
  int32_t n_p_eval;
1391
  int32_t n_eval;
 
1392
  };
1393
 
1394
  struct llama_perf_sampler_data {
 
71
  typedef int32_t llama_seq_id;
72
 
73
  enum llama_vocab_type {
74
+ LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
75
+ LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
76
+ LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
77
+ LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
78
+ LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
79
+ LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
80
+ LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming
81
  };
82
 
83
  enum llama_rope_type {
 
335
  bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
336
  // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
337
  // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
338
+ bool kv_unified; // use a unified buffer across the input sequences when computing the attention
339
+ // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
340
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14363
341
  };
342
 
343
  // model quantization parameters
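A minimal usage sketch of the new flag, assuming the existing llama_context_default_params()/llama_init_from_model() API; the other values are arbitrary for the example.

#include "llama.h"

static llama_context * make_unified_ctx(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.n_seq_max  = 4;
    cparams.kv_unified = true; // one KV buffer shared across the input sequences
    return llama_init_from_model(model, cparams);
}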
 
728
  // - lazily on next llama_decode()
729
  // p0 < 0 : [0, p1]
730
  // p1 < 0 : [p0, inf)
731
+ DEPRECATED(LLAMA_API void llama_kv_self_seq_div(
732
  struct llama_context * ctx,
733
  llama_seq_id seq_id,
734
  llama_pos p0,
 
956
  // in the order they have appeared in the batch.
957
  // Rows: number of tokens for which llama_batch.logits[i] != 0
958
  // Cols: n_vocab
959
+ // TODO: deprecate in favor of llama_get_logits_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
960
  LLAMA_API float * llama_get_logits(struct llama_context * ctx);
961
 
962
  // Logits for the ith token. For positive indices, Equivalent to:
 
971
  // in the order they have appeared in the batch.
972
  // shape: [n_outputs*n_embd]
973
  // Otherwise, returns NULL.
974
+ // TODO: deprecate in favor of llama_get_embeddings_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
975
  LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
976
 
977
  // Get the embeddings for the ith token. For positive indices, Equivalent to:
 
1010
  LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
1011
  LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
1012
  LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
1013
+ LLAMA_API llama_token llama_vocab_mask(const struct llama_vocab * vocab); // mask
1014
 
1015
  LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
1016
  LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
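A minimal usage sketch of the new getter, assuming a loaded model; vocabularies without a mask token are expected to return LLAMA_TOKEN_NULL, like the other special-token getters.

#include "llama.h"

static bool has_mask_token(const llama_model * model) {
    const llama_vocab * vocab = llama_model_get_vocab(model);
    return llama_vocab_mask(vocab) != LLAMA_TOKEN_NULL;
}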
 
1396
 
1397
  int32_t n_p_eval;
1398
  int32_t n_eval;
1399
+ int32_t n_reused; // number of times a ggml compute graph has been reused
1400
  };
1401
 
1402
  struct llama_perf_sampler_data {
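A minimal usage sketch reading the new counter through the existing perf API, assuming a live context:

#include "llama.h"
#include <cstdio>

static void print_graph_reuse(const llama_context * ctx) {
    const llama_perf_context_data data = llama_perf_context(ctx);
    printf("reused compute graphs: %d\n", data.n_reused);
}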
examples/talk-llama/unicode.cpp CHANGED
@@ -557,6 +557,178 @@ static std::vector<size_t> unicode_regex_split_stl(const std::string & text, con
557
  return bpe_offsets;
558
  }
559
 
560
  static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
561
  std::vector<size_t> bpe_offsets;
562
 
@@ -567,6 +739,9 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
567
  regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
568
 
569
  bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
 
 
 
570
  }
571
 
572
  return bpe_offsets;
@@ -672,6 +847,38 @@ uint32_t unicode_tolower(uint32_t cpt) {
672
  return cpt; // Return the original code point if no lowercase mapping is found
673
  }
674
 
675
  std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
676
  // unicode categories
677
  static const std::map<std::string, int> k_ucat_enum = {
 
557
  return bpe_offsets;
558
  }
559
 
560
+ // K2 system regex patterns (from tokenization_kimi.py):
561
+ // [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+
562
+ static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector<size_t> & offsets) {
563
+ std::vector<size_t> bpe_offsets;
564
+ bpe_offsets.reserve(offsets.size());
565
+
566
+ const auto cpts = unicode_cpts_from_utf8(text);
567
+
568
+ size_t start = 0;
569
+ for (auto offset : offsets) {
570
+ const size_t offset_ini = start;
571
+ const size_t offset_end = start + offset;
572
+ assert(offset_end <= cpts.size());
573
+ start = offset_end;
574
+
575
+ static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
576
+ auto _get_cpt = [&] (const size_t pos) -> uint32_t {
577
+ return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
578
+ };
579
+
580
+ auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
581
+ return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
582
+ };
583
+
584
+ size_t _prev_end = offset_ini;
585
+ auto _add_token = [&] (const size_t end) -> size_t {
586
+ assert(_prev_end <= end && end <= offset_end);
587
+ size_t len = end - _prev_end;
588
+ if (len > 0) {
589
+ bpe_offsets.push_back(len);
590
+ }
591
+ _prev_end = end;
592
+ return len;
593
+ };
594
+
595
+ for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
596
+ const uint32_t cpt = _get_cpt(pos);
597
+ const auto flags = _get_flags(pos);
598
+
599
+ // Pattern 1: [\p{Han}]+ (Chinese characters)
600
+ if (unicode_cpt_is_han(cpt)) {
601
+ while (unicode_cpt_is_han(_get_cpt(pos))) {
602
+ pos++;
603
+ }
604
+ _add_token(pos);
605
+ continue;
606
+ }
607
+
608
+ // Pattern 2 & 3: Letter words excluding Han characters with optional contractions
609
+ // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?:'s|'t|'re|'ve|'m|'ll|'d)?
610
+ // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?:'s|'t|'re|'ve|'m|'ll|'d)?
611
+ // Check if current char is a letter OR if current char could be a leading char and next char is a letter
612
+ bool is_letter_pattern = (flags.is_letter && !unicode_cpt_is_han(cpt)) ||
613
+ (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number) &&
614
+ _get_flags(pos + 1).is_letter && !unicode_cpt_is_han(_get_cpt(pos + 1)));
615
+
616
+ if (is_letter_pattern) {
617
+ // Handle optional leading non-letter/non-number character
618
+ bool has_leading_char = false;
619
+ if (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number)) {
620
+ has_leading_char = true;
621
+ pos++;
622
+ }
623
+
624
+ // Match letter sequence (excluding Han characters)
625
+ bool has_letters = false;
626
+ while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
627
+ has_letters = true;
628
+ pos++;
629
+ }
630
+
631
+ // Only proceed if we found letters (after potentially skipping leading char)
632
+ if (has_letters || (!has_leading_char && _get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos)))) {
633
+ if (!has_letters) pos++; // consume the first letter if we didn't already
634
+
635
+ // Continue consuming letters
636
+ while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
637
+ pos++;
638
+ }
639
+
640
+ // Check for optional contractions (?:'s|'t|'re|'ve|'m|'ll|'d)
641
+ if (_get_cpt(pos) == '\'' && pos + 1 < offset_end) {
642
+ uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1));
643
+ if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
644
+ pos += 2;
645
+ } else if (pos + 2 < offset_end) {
646
+ uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2));
647
+ if ((cpt_next == 'r' && cpt_next_next == 'e') ||
648
+ (cpt_next == 'v' && cpt_next_next == 'e') ||
649
+ (cpt_next == 'l' && cpt_next_next == 'l')) {
650
+ pos += 3;
651
+ }
652
+ }
653
+ }
654
+
655
+ _add_token(pos);
656
+ continue;
657
+ } else if (has_leading_char) {
658
+ // We consumed a leading char but found no letters, backtrack
659
+ pos--;
660
+ }
661
+ }
662
+
663
+ // Pattern 4: \p{N}{1,3} (numbers 1-3 digits)
664
+ if (flags.is_number) {
665
+ size_t ini = pos;
666
+ while (_get_flags(pos).is_number) {
667
+ if (++pos - ini >= 3) {
668
+ _add_token(pos);
669
+ ini = pos;
670
+ }
671
+ }
672
+ _add_token(pos);
673
+ continue;
674
+ }
675
+
676
+ // Pattern 5: ?[^\s\p{L}\p{N}]+[\r\n]* (optional space + non-word chars + optional newlines)
677
+ auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags);
678
+ if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
679
+ pos += (cpt == ' ');
680
+ while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
681
+ flags2 = _get_flags(++pos);
682
+ }
683
+ // Match optional [\r\n]*
684
+ uint32_t cpt2 = _get_cpt(pos);
685
+ while (cpt2 == '\r' || cpt2 == '\n') {
686
+ cpt2 = _get_cpt(++pos);
687
+ }
688
+ _add_token(pos);
689
+ continue;
690
+ }
691
+
692
+ // Count whitespace characters
693
+ size_t num_whitespaces = 0;
694
+ size_t last_end_r_or_n = 0;
695
+ while (_get_flags(pos + num_whitespaces).is_whitespace) {
696
+ uint32_t cpt2 = _get_cpt(pos + num_whitespaces);
697
+ if (cpt2 == '\r' || cpt2 == '\n') {
698
+ last_end_r_or_n = pos + num_whitespaces + 1;
699
+ }
700
+ num_whitespaces++;
701
+ }
702
+
703
+ // Pattern 6: \s*[\r\n]+ (whitespace with newlines)
704
+ if (last_end_r_or_n > 0) {
705
+ pos = last_end_r_or_n;
706
+ _add_token(pos);
707
+ continue;
708
+ }
709
+
710
+ // Pattern 7: \s+(?!\S) (trailing whitespace)
711
+ if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) {
712
+ pos += num_whitespaces - 1;
713
+ _add_token(pos);
714
+ continue;
715
+ }
716
+
717
+ // Pattern 8: \s+ (general whitespace)
718
+ if (num_whitespaces > 0) {
719
+ pos += num_whitespaces;
720
+ _add_token(pos);
721
+ continue;
722
+ }
723
+
724
+ // No matches - consume single character
725
+ _add_token(++pos);
726
+ }
727
+ }
728
+
729
+ return bpe_offsets;
730
+ }
731
+
732
  static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
733
  std::vector<size_t> bpe_offsets;
734
 
 
739
  regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
740
 
741
  bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
742
+ } else if (regex_expr == "\\p{Han}+") {
743
+ // K2's first pattern - handle all K2 patterns together
744
+ bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);
745
  }
746
 
747
  return bpe_offsets;
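A minimal standalone sketch of the dispatch above: the single "\p{Han}+" expression registered for the Kimi-K2 pre-tokenizer acts as a sentinel, so unicode_regex_split() hands the whole text to the custom handler instead of the generic std::regex fallback. The sample string and the expected grouping are illustrative.

#include "unicode.h"

#include <cstdio>
#include <string>
#include <vector>

int main() {
    const std::vector<std::string> pieces = unicode_regex_split("Hello世界 123!", { "\\p{Han}+" });
    for (const std::string & p : pieces) {
        printf("[%s]", p.c_str()); // Han runs, Latin words, digit groups, whitespace and punctuation come back as separate pieces
    }
    printf("\n");
}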
 
847
  return cpt; // Return the original code point if no lowercase mapping is found
848
  }
849
 
850
+ bool unicode_cpt_is_han(uint32_t cpt) {
851
+ // Han character ranges (Chinese/CJK characters)
852
+ // CJK Unified Ideographs (most common)
853
+ if (cpt >= 0x4E00 && cpt <= 0x9FFF) return true;
854
+
855
+ // CJK Extension A
856
+ if (cpt >= 0x3400 && cpt <= 0x4DBF) return true;
857
+
858
+ // CJK Extension B
859
+ if (cpt >= 0x20000 && cpt <= 0x2A6DF) return true;
860
+
861
+ // CJK Extension C
862
+ if (cpt >= 0x2A700 && cpt <= 0x2B73F) return true;
863
+
864
+ // CJK Extension D
865
+ if (cpt >= 0x2B740 && cpt <= 0x2B81F) return true;
866
+
867
+ // CJK Extension E
868
+ if (cpt >= 0x2B820 && cpt <= 0x2CEAF) return true;
869
+
870
+ // CJK Extension F
871
+ if (cpt >= 0x2CEB0 && cpt <= 0x2EBEF) return true;
872
+
873
+ // CJK Compatibility Ideographs
874
+ if (cpt >= 0xF900 && cpt <= 0xFAFF) return true;
875
+
876
+ // CJK Compatibility Ideographs Supplement
877
+ if (cpt >= 0x2F800 && cpt <= 0x2FA1F) return true;
878
+
879
+ return false;
880
+ }
881
+
882
  std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
883
  // unicode categories
884
  static const std::map<std::string, int> k_ucat_enum = {
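A minimal standalone sketch exercising the new helper; it is a plain range check, so it is cheap to call per code point inside the splitter above.

#include "unicode.h"

#include <cassert>

int main() {
    assert( unicode_cpt_is_han(0x4E2D)); // '中', CJK Unified Ideographs
    assert( unicode_cpt_is_han(0x3400)); // start of CJK Extension A
    assert(!unicode_cpt_is_han('A'));    // Latin letters are not Han
    assert(!unicode_cpt_is_han(0x3042)); // 'あ' (Hiragana) is not Han
    return 0;
}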
examples/talk-llama/unicode.h CHANGED
@@ -63,4 +63,6 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8);
63
 
64
  uint32_t unicode_tolower(uint32_t cpt);
65
 
 
 
66
  std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);
 
63
 
64
  uint32_t unicode_tolower(uint32_t cpt);
65
 
66
+ bool unicode_cpt_is_han(uint32_t cpt);
67
+
68
  std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);