talk-llama : only copy used KV cache in get / set state (#890)
---------
Co-authored-by: ejones <[email protected]>
- examples/talk-llama/llama.cpp +78 -21
- examples/talk-llama/llama.h +3 -2
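
Summary: instead of memcpy-ing the entire KV cache buffer, llama_copy_state_data / llama_set_state_data now build a small ggml graph that copies only the kv_ntok tokens actually in use through 3-D views. The serialized state is therefore usually smaller than the cache buffer, llama_get_state_size() becomes an upper bound, and the internal asserts change from == to <=. A minimal caller-side sketch of the resulting pattern (illustrative, not part of the commit; written against the function signatures visible in the hunks below):

    #include <cstdint>
    #include <vector>
    #include "llama.h"

    // Save: allocate the maximum possible size, then keep only what was actually written.
    static std::vector<uint8_t> save_state(llama_context * ctx) {
        std::vector<uint8_t> buf(llama_get_state_size(ctx));          // upper bound
        const size_t written = llama_copy_state_data(ctx, buf.data()); // actual size, written <= buf.size()
        buf.resize(written);
        return buf;
    }

    // Restore: hand the serialized bytes back to the context.
    static void restore_state(llama_context * ctx, const std::vector<uint8_t> & buf) {
        llama_set_state_data(ctx, buf.data());
    }
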
examples/talk-llama/llama.cpp
CHANGED
@@ -1270,6 +1270,9 @@ static bool llama_eval_internal(
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
+    // update kv token count
+    lctx.model.kv_self.n = n_past + N;
+
     // extract logits
     {
         auto & logits_out = lctx.logits;

@@ -2386,7 +2389,7 @@ void llama_set_rng_seed(struct llama_context * ctx, int seed) {
     ctx->rng.seed(seed);
 }
 
-// Returns the size of the state
+// Returns the *maximum* size of the state
 size_t llama_get_state_size(struct llama_context * ctx) {
     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
     // for reference, std::mt19937(1337) serializes to 6701 bytes.

@@ -2465,21 +2468,51 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
 
     // copy kv cache
     {
+        const auto & kv_self = ctx->model.kv_self;
+        const auto & hparams = ctx->model.hparams;
+        const int    n_layer = hparams.n_layer;
+        const int    n_embd  = hparams.n_embd;
+        const int    n_ctx   = hparams.n_ctx;
+
+        const size_t kv_size = kv_self.buf.size;
         const int    kv_ntok = llama_get_kv_cache_token_count(ctx);
 
         memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
         memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
 
         if (kv_size) {
+            const size_t elt_size = ggml_element_size(kv_self.k);
+            char buffer[4096];
+            ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
+            ggml_cgraph gf{};
+            gf.n_threads = 1;
+
+            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            kout3d->data = out;
+            out += ggml_nbytes(kout3d);
+
+            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            vout3d->data = out;
+            out += ggml_nbytes(vout3d);
+
+            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
+                n_embd, kv_ntok, n_layer,
+                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
+
+            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
+                kv_ntok, n_embd, n_layer,
+                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
+
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
+            ggml_graph_compute(cpy_ctx, &gf);
         }
     }
 
     const size_t written  = out - dest;
+    const size_t max_size = llama_get_state_size(ctx);
 
+    LLAMA_ASSERT(written <= max_size);
 
     return written;
 }
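
The views above are sized by kv_ntok (the number of tokens actually evaluated) rather than by n_ctx, which is where the savings come from. A rough back-of-the-envelope comparison, using made-up hyperparameters rather than anything from this commit:

    #include <cstddef>

    // Hypothetical model: 32 layers, 4096-dim embeddings, 512-token context, f16 KV entries.
    constexpr std::size_t n_layer  = 32;
    constexpr std::size_t n_embd   = 4096;
    constexpr std::size_t n_ctx    = 512;
    constexpr std::size_t elt_size = 2;   // bytes per f16 element
    constexpr std::size_t kv_ntok  = 64;  // tokens evaluated so far

    // Whole K+V buffer (what was copied before) vs. only the used slots (what the new graph copies).
    constexpr std::size_t full_kv = 2 * n_layer * n_embd * n_ctx   * elt_size; // 268435456 bytes (256 MiB)
    constexpr std::size_t used_kv = 2 * n_layer * n_embd * kv_ntok * elt_size; //  33554432 bytes ( 32 MiB)

    static_assert(used_kv <= full_kv, "only the used part of the cache is serialized");
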
@@ -2537,6 +2570,12 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
 
     // set kv cache
     {
+        const auto & kv_self = ctx->model.kv_self;
+        const auto & hparams = ctx->model.hparams;
+        const int    n_layer = hparams.n_layer;
+        const int    n_embd  = hparams.n_embd;
+        const int    n_ctx   = hparams.n_ctx;
+
         size_t kv_size;
         int kv_ntok;

@@ -2544,15 +2583,33 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
 
         if (kv_size) {
+            LLAMA_ASSERT(kv_self.buf.size == kv_size);
+
+            const size_t elt_size = ggml_element_size(kv_self.k);
+            char buffer[4096];
+            ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
+            ggml_cgraph gf{};
+            gf.n_threads = 1;
+
+            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            kin3d->data = (void *) in;
+            in += ggml_nbytes(kin3d);
+
+            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            vin3d->data = (void *) in;
+            in += ggml_nbytes(vin3d);
 
+            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
+                n_embd, kv_ntok, n_layer,
+                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
 
+            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
+                kv_ntok, n_embd, n_layer,
+                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
 
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
+            ggml_graph_compute(cpy_ctx, &gf);
 
         }

@@ -2560,9 +2617,9 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
     }
 
     const size_t nread    = in - src;
+    const size_t max_size = llama_get_state_size(ctx);
 
+    LLAMA_ASSERT(nread <= max_size);
 
     return nread;
 }

@@ -2733,14 +2790,14 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
     // restore the context state
     {
         const size_t n_state_size_cur = file.size - file.tell();
+        const size_t n_state_size_max = llama_get_state_size(ctx);
 
+        if (n_state_size_cur > n_state_size_max) {
+            fprintf(stderr, "%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
             return false;
         }
 
+        std::vector<uint8_t> state_data(n_state_size_max);
         file.read_raw(state_data.data(), n_state_size_cur);
 
         llama_set_state_data(ctx, state_data.data());

@@ -2763,12 +2820,12 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
 
     // save the context state
     {
+        const size_t n_state_size_max = llama_get_state_size(ctx);
 
+        std::vector<uint8_t> state_data(n_state_size_max);
+        const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data());
 
+        file.write_raw(state_data.data(), n_state_size_cur);
     }
 
     return true;
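
The session-file functions touched by the last two hunks persist a prompt's evaluated tokens and context state to disk. A sketch of the round trip (illustrative only: the llama_save_session_file / llama_load_session_file signatures and the llama_n_ctx helper are assumed from the llama.h of this vintage and are not shown in the excerpt above):

    #include <cstdio>
    #include <vector>
    #include "llama.h"

    // Write the evaluated tokens plus the (now compacted) context state to disk.
    static bool cache_session(llama_context * ctx, const std::vector<llama_token> & tokens, const char * path) {
        return llama_save_session_file(ctx, path, tokens.data(), tokens.size());
    }

    // Read the tokens and state back; returns false if the file is missing or invalid.
    static bool load_session(llama_context * ctx, std::vector<llama_token> & tokens, const char * path) {
        size_t n_loaded = 0;
        tokens.resize(llama_n_ctx(ctx));
        if (!llama_load_session_file(ctx, path, tokens.data(), tokens.size(), &n_loaded)) {
            fprintf(stderr, "failed to load session from '%s'\n", path);
            return false;
        }
        tokens.resize(n_loaded);
        return true;
    }
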
examples/talk-llama/llama.h
CHANGED
@@ -23,7 +23,7 @@
 #define LLAMA_FILE_MAGIC 'ggjt'
 #define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml'
 #define LLAMA_SESSION_MAGIC 'ggsn'
+#define LLAMA_SESSION_VERSION 1
 
 #ifdef __cplusplus
 extern "C" {

@@ -127,7 +127,8 @@ extern "C" {
     // Sets the current rng seed.
    LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);
 
+    // Returns the maximum size in bytes of the state (rng, logits, embedding
+    // and kv_cache) - will often be smaller after compacting tokens
     LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);
 
     // Copies the state to the specified destination address.