Michael Podvitskiy committed:

ggml : add abort_callback for cpu backend (ggml/725)

* a way to use abort_callback with the cpu backend
* whisper update
- ggml-backend.c +22 -4
- ggml-backend.h +3 -2
- ggml.c +1 -1
- ggml.h +7 -2
- whisper.cpp +4 -4
- whisper.h +1 -6
ggml-backend.c
CHANGED
@@ -653,6 +653,9 @@ struct ggml_backend_cpu_context {
     int n_threads;
     void * work_data;
     size_t work_size;
+
+    ggml_abort_callback abort_callback;
+    void *              abort_callback_data;
 };
 
 GGML_CALL static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
@@ -691,6 +694,9 @@ GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(gg
         cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
     }
 
+    cpu_plan->cplan.abort_callback      = cpu_ctx->abort_callback;
+    cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+
     return cpu_plan;
 }
@@ -721,9 +727,11 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
         cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size);
         cpu_ctx->work_size = cplan.work_size;
     }
-
     cplan.work_data = cpu_ctx->work_data;
 
+    cplan.abort_callback      = cpu_ctx->abort_callback;
+    cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+
     ggml_graph_compute(cgraph, &cplan);
     return true;
 }
@@ -759,9 +767,11 @@ static struct ggml_backend_i cpu_backend_i = {
 ggml_backend_t ggml_backend_cpu_init(void) {
     struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context));
 
-    ctx->n_threads = GGML_DEFAULT_N_THREADS;
-    ctx->work_data = NULL;
-    ctx->work_size = 0;
+    ctx->n_threads           = GGML_DEFAULT_N_THREADS;
+    ctx->work_data           = NULL;
+    ctx->work_size           = 0;
+    ctx->abort_callback      = NULL;
+    ctx->abort_callback_data = NULL;
 
     ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend));
 
@@ -783,6 +793,14 @@ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
     ctx->n_threads = n_threads;
 }
 
+void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+    ctx->abort_callback      = abort_callback;
+    ctx->abort_callback_data = abort_callback_data;
+}
+
 GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
     return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
 }
ggml-backend.h
CHANGED
@@ -83,8 +83,9 @@ extern "C" {
 
     GGML_API ggml_backend_t ggml_backend_cpu_init(void);
 
-    GGML_API GGML_CALL bool ggml_backend_is_cpu          (ggml_backend_t backend);
-    GGML_API           void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
+    GGML_API GGML_CALL bool ggml_backend_is_cpu                (ggml_backend_t backend);
+    GGML_API           void ggml_backend_cpu_set_n_threads     (ggml_backend_t backend_cpu, int n_threads);
+    GGML_API           void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
 
     // Create a backend buffer from an existing pointer
     GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
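
For orientation, here is a minimal caller-side sketch of the new CPU-backend abort API. The compute_deadline struct, the abort_after_deadline helper, and the 5-second budget are illustrative assumptions, not part of this commit; only the ggml_backend_* calls come from ggml-backend.h.

#include "ggml-backend.h"

#include <stdbool.h>
#include <time.h>

// Illustrative user data: a wall-clock deadline (not part of the commit).
struct compute_deadline {
    time_t end;
};

// Returning true asks ggml_graph_compute to stop at the next node boundary.
static bool abort_after_deadline(void * data) {
    const struct compute_deadline * d = (const struct compute_deadline *) data;
    return time(NULL) > d->end;
}

int main(void) {
    ggml_backend_t backend = ggml_backend_cpu_init();

    struct compute_deadline deadline = { .end = time(NULL) + 5 }; // ~5 s budget
    ggml_backend_cpu_set_abort_callback(backend, abort_after_deadline, &deadline);

    // ... build a graph and run it with ggml_backend_graph_compute(backend, graph) ...

    ggml_backend_free(backend);
    return 0;
}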
ggml.c
CHANGED
@@ -16560,7 +16560,7 @@ struct ggml_compute_state_shared {
     atomic_int node_n;    // active graph node
     atomic_int node_task; // active graph node task phase
 
-    bool (*abort_callback)(void * data); // abort ggml_graph_compute when true
+    ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
     void * abort_callback_data;
 };
 
ggml.h
CHANGED
@@ -567,6 +567,11 @@ extern "C" {
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
 
+    // Abort callback
+    // If not NULL, called before ggml computation
+    // If it returns true, the computation is aborted
+    typedef bool (*ggml_abort_callback)(void * data);
+
     // the compute plan that needs to be prepared for ggml_graph_compute()
     // since https://github.com/ggerganov/ggml/issues/287
     struct ggml_cplan {
@@ -576,8 +581,8 @@ extern "C" {
         int n_threads;
 
         // abort ggml_graph_compute when true
-        bool (*abort_callback)(void * data);
-        void * abort_callback_data;
+        ggml_abort_callback abort_callback;
+        void *              abort_callback_data;
     };
 
     enum ggml_cgraph_eval_order {
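
The new ggml_abort_callback typedef can also be wired into a compute plan directly when calling ggml_graph_compute without a backend. A rough sketch follows, assuming a graph gf built elsewhere and an application-owned g_stop flag (both assumptions, not from the commit):

#include "ggml.h"

#include <stdbool.h>
#include <stddef.h>

// Application-owned flag, e.g. flipped from a signal handler or another thread.
static volatile bool g_stop = false;

static bool stop_requested(void * data) {
    (void) data;
    return g_stop; // true => ggml_graph_compute aborts
}

static void compute_with_abort(struct ggml_cgraph * gf, int n_threads) {
    struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);

    plan.abort_callback      = stop_requested;
    plan.abort_callback_data = NULL;

    // Work-buffer allocation for plan.work_data is omitted here;
    // see ggml_graph_compute_helper in whisper.cpp below for the full pattern.
    ggml_graph_compute(gf, &plan);
}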
whisper.cpp
CHANGED
@@ -156,11 +156,11 @@ static bool ggml_graph_compute_helper(
         struct ggml_cgraph * graph,
         std::vector<uint8_t> & buf,
         int n_threads,
-        whisper_abort_callback abort_callback,
+        ggml_abort_callback abort_callback,
         void * abort_callback_data) {
     struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
 
-    plan.abort_callback = abort_callback;
+    plan.abort_callback      = abort_callback;
     plan.abort_callback_data = abort_callback_data;
 
     if (plan.work_size > 0) {
@@ -2130,7 +2130,7 @@ static bool whisper_encode_internal(
         whisper_state & wstate,
         const int mel_offset,
         const int n_threads,
-        whisper_abort_callback abort_callback,
+        ggml_abort_callback abort_callback,
         void * abort_callback_data) {
     const int64_t t_start_us = ggml_time_us();
 
@@ -2561,7 +2561,7 @@ static bool whisper_decode_internal(
         whisper_state & wstate,
         const whisper_batch & batch,
         const int n_threads,
-        whisper_abort_callback abort_callback,
+        ggml_abort_callback abort_callback,
         void * abort_callback_data) {
     const int64_t t_start_us = ggml_time_us();
 
whisper.h
CHANGED
@@ -412,11 +412,6 @@ extern "C" {
     // If it returns false, the computation is aborted
     typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
 
-    // Abort callback
-    // If not NULL, called before ggml computation
-    // If it returns true, the computation is aborted
-    typedef bool (*whisper_abort_callback)(void * user_data);
-
     // Logits filter callback
     // Can be used to modify the logits before sampling
     // If not NULL, called after applying temperature to logits
@@ -513,7 +508,7 @@ extern "C" {
         void * encoder_begin_callback_user_data;
 
         // called each time before ggml computation starts
-        whisper_abort_callback abort_callback;
+        ggml_abort_callback abort_callback;
         void * abort_callback_user_data;
 
         // called by each decoder to filter obtained logits
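
On the whisper.cpp side, the updated params can be used roughly as follows. The cancel_requested helper, the g_cancel flag, and the transcribe wrapper are illustrative assumptions; abort_callback, abort_callback_user_data, whisper_full, and whisper_full_default_params come from whisper.h.

#include "whisper.h"

#include <stdbool.h>
#include <stddef.h>

// Application-owned flag; setting it to true cancels the next ggml computation.
static volatile bool g_cancel = false;

static bool cancel_requested(void * user_data) {
    (void) user_data;
    return g_cancel;
}

// Illustrative wrapper: ctx, samples and n_samples come from the caller.
static int transcribe(struct whisper_context * ctx, const float * samples, int n_samples) {
    struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    params.abort_callback           = cancel_requested;
    params.abort_callback_user_data = NULL;

    return whisper_full(ctx, params, samples, n_samples);
}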