Commit 773c85f · unverified · 1 parent: d94de9a
Luis Herrera (evanqjones) committed

talk-llama : only copy used KV cache in get / set state (#890)

---------

Co-authored-by: ejones <[email protected]>
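
The change below stops serializing the whole KV cache buffer and instead copies only the `kv_ntok` tokens that have actually been evaluated, so `llama_get_state_size()` now reports an upper bound rather than an exact size. A minimal caller-side sketch of the resulting contract (the `save_state` helper is hypothetical, not part of this commit; the llama API calls are the ones shown in the diff):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

#include "llama.h"

// Hypothetical helper: allocate for the maximum state size, but persist only
// the bytes that llama_copy_state_data() reports as actually written.
static bool save_state(struct llama_context * ctx, const char * path) {
    const size_t max_size = llama_get_state_size(ctx);             // upper bound
    std::vector<uint8_t> buf(max_size);
    const size_t written = llama_copy_state_data(ctx, buf.data()); // actual size

    FILE * f = std::fopen(path, "wb");
    if (!f) {
        return false;
    }
    const bool ok = std::fwrite(buf.data(), 1, written, f) == written;
    std::fclose(f);
    return ok;
}
```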

examples/talk-llama/llama.cpp CHANGED
@@ -1270,6 +1270,9 @@ static bool llama_eval_internal(
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
+    // update kv token count
+    lctx.model.kv_self.n = n_past + N;
+
     // extract logits
     {
         auto & logits_out = lctx.logits;
@@ -2386,7 +2389,7 @@ void llama_set_rng_seed(struct llama_context * ctx, int seed) {
     ctx->rng.seed(seed);
 }
 
-// Returns the size of the state
+// Returns the *maximum* size of the state
 size_t llama_get_state_size(struct llama_context * ctx) {
     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
     // for reference, std::mt19937(1337) serializes to 6701 bytes.
@@ -2465,21 +2468,51 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
 
     // copy kv cache
     {
-        const size_t kv_size = ctx->model.kv_self.buf.size;
+        const auto & kv_self = ctx->model.kv_self;
+        const auto & hparams = ctx->model.hparams;
+        const int    n_layer = hparams.n_layer;
+        const int    n_embd  = hparams.n_embd;
+        const int    n_ctx   = hparams.n_ctx;
+
+        const size_t kv_size = kv_self.buf.size;
         const int    kv_ntok = llama_get_kv_cache_token_count(ctx);
 
         memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
         memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
 
         if (kv_size) {
-            memcpy(out, ctx->model.kv_self.buf.addr, kv_size); out += kv_size;
+            const size_t elt_size = ggml_element_size(kv_self.k);
+            char buffer[4096];
+            ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
+            ggml_cgraph gf{};
+            gf.n_threads = 1;
+
+            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            kout3d->data = out;
+            out += ggml_nbytes(kout3d);
+
+            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            vout3d->data = out;
+            out += ggml_nbytes(vout3d);
+
+            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
+                n_embd, kv_ntok, n_layer,
+                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
+
+            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
+                kv_ntok, n_embd, n_layer,
+                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
+
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
+            ggml_graph_compute(cpy_ctx, &gf);
        }
    }
 
    const size_t written = out - dest;
-    const size_t expected = llama_get_state_size(ctx);
+    const size_t max_size = llama_get_state_size(ctx);
 
-    LLAMA_ASSERT(written == expected);
+    LLAMA_ASSERT(written <= max_size);
 
    return written;
 }
@@ -2537,6 +2570,12 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
 
     // set kv cache
     {
+        const auto & kv_self = ctx->model.kv_self;
+        const auto & hparams = ctx->model.hparams;
+        const int    n_layer = hparams.n_layer;
+        const int    n_embd  = hparams.n_embd;
+        const int    n_ctx   = hparams.n_ctx;
+
         size_t kv_size;
         int kv_ntok;
 
@@ -2544,15 +2583,33 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
 
         if (kv_size) {
-            LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
+            LLAMA_ASSERT(kv_self.buf.size == kv_size);
+
+            const size_t elt_size = ggml_element_size(kv_self.k);
+            char buffer[4096];
+            ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
+            ggml_cgraph gf{};
+            gf.n_threads = 1;
+
+            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            kin3d->data = (void *) in;
+            in += ggml_nbytes(kin3d);
+
+            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            vin3d->data = (void *) in;
+            in += ggml_nbytes(vin3d);
 
-            void * k_data = ctx->model.kv_self.k->data; // remember data pointers
-            void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
+            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
+                n_embd, kv_ntok, n_layer,
+                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
 
-            memcpy(ctx->model.kv_self.buf.addr, in, kv_size); in += kv_size;
+            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
+                kv_ntok, n_embd, n_layer,
+                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
 
-            ctx->model.kv_self.k->data = k_data; // restore correct data pointers
-            ctx->model.kv_self.v->data = v_data;
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
+            ggml_graph_compute(cpy_ctx, &gf);
 
         }
 
@@ -2560,9 +2617,9 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
     }
 
     const size_t nread = in - src;
-    const size_t expected = llama_get_state_size(ctx);
+    const size_t max_size = llama_get_state_size(ctx);
 
-    LLAMA_ASSERT(nread == expected);
+    LLAMA_ASSERT(nread <= max_size);
 
     return nread;
 }
@@ -2733,14 +2790,14 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
     // restore the context state
     {
         const size_t n_state_size_cur = file.size - file.tell();
-        const size_t n_state_size_exp = llama_get_state_size(ctx);
+        const size_t n_state_size_max = llama_get_state_size(ctx);
 
-        if (n_state_size_cur != n_state_size_exp) {
-            fprintf(stderr, "%s : the state size in session file didn't match! expected %zu, got %zu\n", __func__, n_state_size_exp, n_state_size_cur);
+        if (n_state_size_cur > n_state_size_max) {
+            fprintf(stderr, "%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
             return false;
         }
 
-        std::vector<uint8_t> state_data(n_state_size_cur);
+        std::vector<uint8_t> state_data(n_state_size_max);
         file.read_raw(state_data.data(), n_state_size_cur);
 
         llama_set_state_data(ctx, state_data.data());
@@ -2763,12 +2820,12 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
 
     // save the context state
     {
-        const size_t n_state_size = llama_get_state_size(ctx);
+        const size_t n_state_size_max = llama_get_state_size(ctx);
 
-        std::vector<uint8_t> state_data(n_state_size);
-        llama_copy_state_data(ctx, state_data.data());
+        std::vector<uint8_t> state_data(n_state_size_max);
+        const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data());
 
-        file.write_raw(state_data.data(), n_state_size);
+        file.write_raw(state_data.data(), n_state_size_cur);
     }
 
     return true;
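
The copy path above builds a tiny ggml graph that views K as `n_embd x kv_ntok x n_layer` and V as `kv_ntok x n_embd x n_layer`, so the serialized KV data grows with the number of cached tokens instead of with the full `n_ctx`. A rough back-of-the-envelope sketch of that saving (illustrative numbers only; `elt_size` stands in for `ggml_element_size(kv_self.k)`):

```cpp
#include <cstddef>
#include <cstdio>

// Illustrative size estimate, not code from the commit: the compacted K/V copy
// scales with kv_ntok, while the old memcpy of kv_self.buf scaled with n_ctx.
int main() {
    const size_t n_layer  = 32;   // assumed example hyperparameters
    const size_t n_embd   = 4096;
    const size_t n_ctx    = 2048;
    const size_t kv_ntok  = 64;   // tokens evaluated so far
    const size_t elt_size = 2;    // e.g. 2 bytes per f16 element

    const size_t compacted = 2 * n_layer * n_embd * kv_ntok * elt_size; // K + V, used part only
    const size_t full      = 2 * n_layer * n_embd * n_ctx   * elt_size; // whole cache buffer

    std::printf("compacted: %zu bytes, full: %zu bytes\n", compacted, full);
    return 0;
}
```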
examples/talk-llama/llama.h CHANGED
@@ -23,7 +23,7 @@
 #define LLAMA_FILE_MAGIC 'ggjt'
 #define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml'
 #define LLAMA_SESSION_MAGIC 'ggsn'
-#define LLAMA_SESSION_VERSION 0
+#define LLAMA_SESSION_VERSION 1
 
 #ifdef __cplusplus
 extern "C" {
@@ -127,7 +127,8 @@ extern "C" {
     // Sets the current rng seed.
    LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);
 
-    // Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
+    // Returns the maximum size in bytes of the state (rng, logits, embedding
+    // and kv_cache) - will often be smaller after compacting tokens
    LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);
 
     // Copies the state to the specified destination address.
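
Because sessions written with the compacted layout are no longer byte-compatible with the old full-buffer format, the header bumps `LLAMA_SESSION_VERSION` from 0 to 1. On the restore side, callers follow the same pattern as `llama_load_session_file` in the diff: size the buffer to the maximum state size, read however many bytes the file actually contains, and hand them to `llama_set_state_data`. A minimal sketch (the `load_state` helper is hypothetical, not part of this commit):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

#include "llama.h"

// Hypothetical restore helper mirroring the llama_load_session_file logic from
// the diff: the file may hold fewer bytes than llama_get_state_size() because
// only the used KV cache entries were saved.
static bool load_state(struct llama_context * ctx, const char * path) {
    FILE * f = std::fopen(path, "rb");
    if (!f) {
        return false;
    }

    std::fseek(f, 0, SEEK_END);
    const size_t n_cur = (size_t) std::ftell(f);
    std::fseek(f, 0, SEEK_SET);

    const size_t n_max = llama_get_state_size(ctx);
    if (n_cur > n_max) {
        std::fclose(f);
        return false; // state blob is too big for this context
    }

    std::vector<uint8_t> buf(n_max);
    const bool ok = std::fread(buf.data(), 1, n_cur, f) == n_cur;
    std::fclose(f);

    if (ok) {
        llama_set_state_data(ctx, buf.data());
    }
    return ok;
}
```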