ggml/examples: add backend support for numerical optimization (ggml/949)
* CUDA eval works
* stochastic gradient descent op
* Adam except decay
* CUDA CROSS_ENTROPY_LOSS_BACK
* CUDA mnist-fc training works
* backend CLI arg
* refactor gguf load
* remove sched from opt_step_adam
* implement l1 regularization (weight decay)
* extra call to add optimizer
* initialize gradients with ggml_graph_reset
* gradient accumulation
* increment iter per eval instead of epoch
* adjust backend interfaces
* fix ggml_graph_reset without backend
* fix ggml graph export/import
* fixup
* rename
* revert ggml_opt changes
* more general CUDA repeat_back
* update documentation, fix CNN
* validation split
* add clarifying comment
* optimize PyTorch training
* adjust buffer size, thread count
* fix 0.0f validation split
* Update examples/mnist/mnist-common.cpp
Co-authored-by: Georgi Gerganov <[email protected]>
* fix gradient accumulation
* tensor flag for accumulators -> tensor hash set
* Update include/ggml.h
Co-authored-by: slaren <[email protected]>
* Update tests/test-backend-ops.cpp
Co-authored-by: slaren <[email protected]>
* Update tests/test-backend-ops.cpp
Co-authored-by: slaren <[email protected]>
* fix test prints
* Update src/ggml-backend.c
Co-authored-by: Georgi Gerganov <[email protected]>
* better CUDA support for noncontiguous out_prod
* add comment
---------
Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: slaren <[email protected]>
- ggml/include/ggml-backend.h +2 -1
- ggml/include/ggml.h +32 -8
- ggml/src/ggml-backend-impl.h +10 -9
- ggml/src/ggml-backend.c +25 -0
- ggml/src/ggml-cann.cpp +1 -0
- ggml/src/ggml-cuda.cu +39 -1
- ggml/src/ggml-cuda/binbcast.cu +67 -0
- ggml/src/ggml-cuda/binbcast.cuh +2 -0
- ggml/src/ggml-cuda/cross-entropy-loss.cu +60 -0
- ggml/src/ggml-cuda/cross-entropy-loss.cuh +2 -0
- ggml/src/ggml-cuda/opt-step-adamw.cu +80 -0
- ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
- ggml/src/ggml-cuda/out-prod.cu +52 -0
- ggml/src/ggml-cuda/out-prod.cuh +3 -0
- ggml/src/ggml-cuda/unary.cu +29 -0
- ggml/src/ggml-cuda/unary.cuh +3 -0
- ggml/src/ggml-kompute.cpp +1 -0
- ggml/src/ggml-metal.m +1 -0
- ggml/src/ggml-rpc.cpp +1 -0
- ggml/src/ggml-sycl.cpp +1 -0
- ggml/src/ggml-vulkan.cpp +1 -0
- ggml/src/ggml.c +339 -87
ggml/include/ggml-backend.h

@@ -66,6 +66,7 @@ extern "C" {
     // "offset" refers to the offset of the tensor data for setting/getting data
     GGML_API GGML_CALL void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
     GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+    GGML_API GGML_CALL void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);

     GGML_API void ggml_backend_synchronize(ggml_backend_t backend);

@@ -122,7 +123,7 @@ extern "C" {
     // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way

     GGML_API size_t ggml_backend_reg_get_count(void);
-    GGML_API size_t ggml_backend_reg_find_by_name(const char * name);
+    GGML_API size_t ggml_backend_reg_find_by_name(const char * name); // returns index of backend with name, or SIZE_MAX if not found
     GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional)
     GGML_API const char * ggml_backend_reg_get_name(size_t i);
     GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
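The ggml_backend_tensor_memset declaration above is what lets training code clear device-resident tensors (gradients, AdamW momenta) in place instead of uploading a zero-filled host buffer. A minimal usage sketch, assuming the tensor has already been allocated in a backend buffer (the helper name below is hypothetical, not part of this commit):

#include "ggml.h"
#include "ggml-backend.h"

// Zero a backend tensor in place, e.g. a gradient or an optimizer momentum
// accumulator; offset 0 and size ggml_nbytes() cover the whole tensor.
static void zero_backend_tensor(struct ggml_tensor * t) {
    ggml_backend_tensor_memset(t, 0, 0, ggml_nbytes(t));
}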
ggml/include/ggml.h

@@ -533,6 +533,7 @@ extern "C" {

         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+        GGML_OP_OPT_STEP_ADAMW,

         GGML_OP_COUNT,
     };

@@ -569,10 +570,12 @@ extern "C" {
         GGML_LOG_LEVEL_DEBUG = 5
     };

+    // this tensor...
     enum ggml_tensor_flag {
-        GGML_TENSOR_FLAG_INPUT  = 1,
-        GGML_TENSOR_FLAG_OUTPUT = 2,
-        GGML_TENSOR_FLAG_PARAM  = 4,
+        GGML_TENSOR_FLAG_INPUT  = 1, // ...is an input for the GGML compute graph
+        GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
+        GGML_TENSOR_FLAG_PARAM  = 4, // ...contains trainable parameters
+        GGML_TENSOR_FLAG_LOSS   = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };

     // ggml object

@@ -2080,17 +2083,38 @@ extern "C" {
             struct ggml_tensor * b,
             struct ggml_tensor * c);

+    // AdamW optimizer step
+    // Paper: https://arxiv.org/pdf/1711.05101v3.pdf
+    // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
+    GGML_API struct ggml_tensor * ggml_opt_step_adamw(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 alpha,
+            float                 beta1,
+            float                 beta2,
+            float                 eps,
+            float                 wd); // weight decay
+
     //
     // automatic differentiation
     //

-    GGML_API void ggml_set_param(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * tensor);
+    GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API void ggml_set_loss(struct ggml_tensor * tensor);


     GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
-    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
+    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep);
+
+    GGML_API void ggml_build_opt_adamw(
+            struct ggml_context * ctx,
+            struct ggml_cgraph  * gf,
+            struct ggml_cgraph  * gb,
+            float                 alpha,
+            float                 beta1,
+            float                 beta2,
+            float                 eps,
+            float                 wd); // weight decay

     // graph allocation in a context
     GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false

@@ -2098,7 +2122,7 @@ extern "C" {
     GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
     GGML_API struct ggml_cgraph ggml_graph_view (struct ggml_cgraph * cgraph, int i0, int i1);
     GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
-    GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
+    GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
     GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);

     GGML_API size_t ggml_graph_overhead(void);
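Read together, the declarations above define the intended training flow: flag parameter tensors, flag a loss tensor, expand the backward graph (optionally accumulating gradients across evaluations), append AdamW optimizer steps, and use ggml_graph_reset to zero gradients/momenta and seed the loss gradient with 1. A rough sketch of how the calls compose, based only on the signatures shown here; the surrounding model construction and the hyperparameter values are placeholders, not taken from the commit:

#include "ggml.h"

// `weights` is a tensor to be trained, `loss` is the scalar to minimize;
// both are assumed to have been created in `ctx` by model-building code.
static void build_training_graphs(struct ggml_context * ctx,
                                  struct ggml_tensor  * weights,
                                  struct ggml_tensor  * loss,
                                  struct ggml_cgraph ** gf_out,
                                  struct ggml_cgraph ** gb_out) {
    ggml_set_param(ctx, weights); // allocate a gradient for the weights
    ggml_set_loss(loss);          // mark the tensor whose gradient is seeded to 1

    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
    ggml_build_forward_expand(gf, loss);

    struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
    ggml_build_backward_expand(ctx, gf, gb, /*accumulate =*/ false, /*keep =*/ false);

    // append one GGML_OP_OPT_STEP_ADAMW node per parameter tensor
    ggml_build_opt_adamw(ctx, gf, gb, /*alpha =*/ 1e-3f, /*beta1 =*/ 0.9f,
                         /*beta2 =*/ 0.999f, /*eps =*/ 1e-8f, /*wd =*/ 0.0f);

    ggml_graph_reset(gb); // grads + momenta to 0, loss grad to 1

    *gf_out = gf;
    *gb_out = gb;
}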
ggml/src/ggml-backend-impl.h

@@ -38,15 +38,16 @@ extern "C" {
     typedef void * ggml_backend_buffer_context_t;

     struct ggml_backend_buffer_i {
-        const char * (*GGML_CALL get_name)      (ggml_backend_buffer_t buffer);
-        void         (*GGML_CALL free_buffer)   (ggml_backend_buffer_t buffer);
-        void *       (*GGML_CALL get_base)      (ggml_backend_buffer_t buffer);
-        void         (*GGML_CALL init_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
-        void         (*GGML_CALL set_tensor)    (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-        void         (*GGML_CALL get_tensor)    (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
-        bool         (*GGML_CALL cpy_tensor)    (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
-        void         (*GGML_CALL clear)         (ggml_backend_buffer_t buffer, uint8_t value);
-        void         (*GGML_CALL reset)         (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
+        const char * (*GGML_CALL get_name)      (ggml_backend_buffer_t buffer);
+        void         (*GGML_CALL free_buffer)   (ggml_backend_buffer_t buffer);
+        void *       (*GGML_CALL get_base)      (ggml_backend_buffer_t buffer);
+        void         (*GGML_CALL init_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+        void         (*GGML_CALL memset_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
+        void         (*GGML_CALL set_tensor)    (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void         (*GGML_CALL get_tensor)    (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+        bool         (*GGML_CALL cpy_tensor)    (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
+        void         (*GGML_CALL clear)         (ggml_backend_buffer_t buffer, uint8_t value);
+        void         (*GGML_CALL reset)         (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
     };

     struct ggml_backend_buffer {
ggml/src/ggml-backend.c

@@ -246,6 +246,22 @@ GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void *
     buf->iface.get_tensor(buf, tensor, data, offset, size);
 }

+GGML_API GGML_CALL void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    GGML_ASSERT(buf != NULL && "tensor buffer not set");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+
+    if (!size) {
+        return;
+    }
+
+    GGML_ASSERT(buf->iface.memset_tensor != NULL && "memset not supported by backend buffer");
+
+    buf->iface.memset_tensor(buf, tensor, value, offset, size);
+}
+
 void ggml_backend_synchronize(ggml_backend_t backend) {
     if (backend->iface.synchronize == NULL) {
         return;

@@ -569,6 +585,12 @@ GGML_CALL static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t
     free(buffer->context);
 }

+GGML_CALL static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    memset((char *)tensor->data + offset, value, size);
+
+    GGML_UNUSED(buffer);
+}
+
 GGML_CALL static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     memcpy((char *)tensor->data + offset, data, size);

@@ -600,6 +622,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
     /* .free_buffer   = */ ggml_backend_cpu_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_cpu_buffer_get_base,
     /* .init_tensor   = */ NULL, // no initialization required
+    /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
     /* .set_tensor    = */ ggml_backend_cpu_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_cpu_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_cpu_buffer_cpy_tensor,

@@ -613,6 +636,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
     /* .free_buffer   = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
     /* .get_base      = */ ggml_backend_cpu_buffer_get_base,
     /* .init_tensor   = */ NULL, // no initialization required
+    /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
     /* .set_tensor    = */ ggml_backend_cpu_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_cpu_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_cpu_buffer_cpy_tensor,

@@ -980,6 +1004,7 @@ static struct ggml_backend_buffer_i ggml_backend_multi_buffer_context_interface(
     /* .free_buffer   = */ ggml_backend_multi_buffer_free_buffer,
     /* .get_base      = */ NULL,
     /* .init_tensor   = */ NULL,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ NULL,
     /* .get_tensor    = */ NULL,
     /* .cpy_tensor    = */ NULL,
ggml/src/ggml-cann.cpp

@@ -1036,6 +1036,7 @@ static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
     /* .free_buffer   = */ ggml_backend_cann_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_cann_buffer_get_base,
     /* .init_tensor   = */ ggml_backend_cann_buffer_init_tensor,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ ggml_backend_cann_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_cann_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_cann_buffer_cpy_tensor,
ggml/src/ggml-cuda.cu

@@ -21,6 +21,8 @@
 #include "ggml-cuda/mmq.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
+#include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
 #include "ggml-cuda/quantize.cuh"

@@ -493,6 +495,14 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
     }
 }

+GGML_CALL static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread));
+    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
 GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;

@@ -544,6 +554,7 @@ static ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
     /* .free_buffer   = */ ggml_backend_cuda_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_cuda_buffer_get_base,
     /* .init_tensor   = */ ggml_backend_cuda_buffer_init_tensor,
+    /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor,
     /* .set_tensor    = */ ggml_backend_cuda_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_cuda_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_cuda_buffer_cpy_tensor,

@@ -860,6 +871,7 @@ static struct ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
     /* .free_buffer   = */ ggml_backend_cuda_split_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_cuda_split_buffer_get_base,
     /* .init_tensor   = */ ggml_backend_cuda_split_buffer_init_tensor,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ ggml_backend_cuda_split_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_cuda_split_buffer_get_tensor,
     /* .cpy_tensor    = */ NULL,

@@ -2168,6 +2180,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_REPEAT:
             ggml_cuda_op_repeat(ctx, dst);
             break;
+        case GGML_OP_REPEAT_BACK:
+            ggml_cuda_op_repeat_back(ctx, dst);
+            break;
         case GGML_OP_GET_ROWS:
             ggml_cuda_op_get_rows(ctx, dst);
             break;

@@ -2201,6 +2216,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_UNARY_OP_NEG:
                     ggml_cuda_op_neg(ctx, dst);
                     break;
+                case GGML_UNARY_OP_STEP:
+                    ggml_cuda_op_step(ctx, dst);
+                    break;
                 case GGML_UNARY_OP_GELU:
                     ggml_cuda_op_gelu(ctx, dst);
                     break;

@@ -2267,6 +2285,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_MUL_MAT_ID:
             ggml_cuda_mul_mat_id(ctx, dst);
             break;
+        case GGML_OP_OUT_PROD:
+            ggml_cuda_out_prod(ctx, dst);
+            break;
         case GGML_OP_SCALE:
             ggml_cuda_op_scale(ctx, dst);
             break;

@@ -2324,6 +2345,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
            ggml_cuda_cross_entropy_loss(ctx, dst);
            break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            ggml_cuda_cross_entropy_loss_back(ctx, dst);
+            break;
+        case GGML_OP_OPT_STEP_ADAMW:
+            ggml_cuda_opt_step_adamw(ctx, dst);
+            break;
         default:
             return false;
     }

@@ -2757,6 +2784,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_STEP:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:

@@ -2809,6 +2837,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                     return false;
                 }
             } break;
+        case GGML_OP_OUT_PROD:
+            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
         case GGML_OP_GET_ROWS:
             {
                 switch (op->src[0]->type) {

@@ -2865,6 +2895,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             } break;
         case GGML_OP_DUP:
         case GGML_OP_REPEAT:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+            } break;
+        case GGML_OP_REPEAT_BACK:
+            return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;

@@ -2931,9 +2967,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             }
             return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
                 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         case GGML_OP_CROSS_ENTROPY_LOSS:
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+        case GGML_OP_OPT_STEP_ADAMW:
             return true;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
     }
ggml/src/ggml-cuda/binbcast.cu

@@ -1,4 +1,5 @@
 #include "binbcast.cuh"
+#include <cstdint>

 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;

@@ -90,6 +91,30 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
     dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
 }

+template <typename T>
+static __global__ void k_repeat_back(
+    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2) {
+
+    const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+    const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
+    const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
+
+    if (tid0 >= ne0) {
+        return;
+    }
+
+    T sum = 0;
+    for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+        for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+            for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+                sum += src[i2*ne01*ne00 + i1*ne00 + i0];
+            }
+        }
+    }
+    dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+}
+
 template<float (*bin_op)(const float, const float)>
 struct bin_bcast_cuda {
     template<typename src0_t, typename src1_t, typename dst_t>

@@ -247,6 +272,16 @@ struct bin_bcast_cuda {
     }
 };

+template <typename T>
+static void repeat_back_cuda(
+    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
+
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
+    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
+}
+
 template<class op>
 static void ggml_cuda_op_bin_bcast(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,

@@ -286,3 +321,35 @@ void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
+
+void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_can_repeat(dst, src0));
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    GGML_ASSERT(src0->ne[3] == 1);
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    GGML_ASSERT(dst->ne[3] == 1);
+
+    switch (dst->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0->data;
+            float * dst_d = (float *) dst->data;
+            repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+        } break;
+        default: {
+            GGML_ASSERT(false);
+        } break;
+    }
+}
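For reference, k_repeat_back computes the adjoint of GGML_OP_REPEAT: each destination element accumulates every source element that tiling dst up to the shape of src would have copied from it. With ne[3] fixed to 1 by the asserts above, the three strided loops in the kernel implement (a sketch, in LaTeX; ggml_can_repeat guarantees the divisions are exact):

\mathrm{dst}[i_2, i_1, i_0] \;=\; \sum_{j_2=0}^{\mathrm{ne02}/\mathrm{ne2}-1}\ \sum_{j_1=0}^{\mathrm{ne01}/\mathrm{ne1}-1}\ \sum_{j_0=0}^{\mathrm{ne00}/\mathrm{ne0}-1} \mathrm{src}\bigl[\, i_2 + j_2\,\mathrm{ne2},\; i_1 + j_1\,\mathrm{ne1},\; i_0 + j_0\,\mathrm{ne0} \,\bigr]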
ggml/src/ggml-cuda/binbcast.cuh

@@ -5,3 +5,5 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/cross-entropy-loss.cu

@@ -71,6 +71,32 @@ static __global__ void cross_entropy_loss_f32(const float * logits, const float
     dst[blockIdx.x] = loss;
 }

+static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) {
+    extern __shared__ float tmp[];
+
+    float maxval = -INFINITY;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float val = logits[blockIdx.x*nclasses + i];
+        maxval = fmaxf(maxval, val);
+        tmp[i] = val;
+    }
+    maxval = warp_reduce_max(maxval);
+
+    float sum = 0.0f;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float val = expf(tmp[i] - maxval);
+        sum += val;
+        tmp[i] = val;
+    }
+    sum = warp_reduce_sum(sum);
+    const float sm_scale = 1.0f/sum;
+
+    const float d_by_nrows = *loss/gridDim.x;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows;
+    }
+}
+
 void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];

@@ -104,3 +130,37 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
     // Combine results from individual blocks:
     sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
 }
+
+void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * opt0 = dst->src[2];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(opt0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int64_t ne00  = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    const float * opt0_d = (const float *) opt0->data;
+    float * dst_d = (float *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const dim3 blocks_num(nrows, 1, 1);
+    const int shmem = ne00*sizeof(float);
+
+    cross_entropy_loss_back_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, opt0_d, dst_d, ne00);
+}
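The backward kernel matches the CPU-side comment that appears later in this commit ("grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr"): with logits x, labels y, upstream gradient d = loss[0] and nrows = gridDim.x rows, each row r receives

\frac{\partial L}{\partial x_{r,i}} \;=\; \bigl(\operatorname{softmax}(x_r)_i - y_{r,i}\bigr)\,\frac{d}{n_{\mathrm{rows}}}

which is exactly the (tmp[i]*sm_scale - labels[...])*d_by_nrows expression above.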
ggml/src/ggml-cuda/cross-entropy-loss.cuh

@@ -3,3 +3,5 @@
 #define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256

 void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/opt-step-adamw.cu (new file)

@@ -0,0 +1,80 @@
+#include "opt-step-adamw.cuh"
+
+#include <cstdint>
+
+static __global__ void opt_step_adamw_f32(
+    float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v, const int64_t k,
+    const float alpha, const float beta1, const float beta2, const float eps, const float wd,
+    const float beta1h, const float beta2h) {
+
+    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    const float gi  = g[i];
+    const float gmi = g_m[i]*beta1 +    gi*(1.0f - beta1);
+    const float gvi = g_v[i]*beta2 + gi*gi*(1.0f - beta2);
+
+    g_m[i] = gmi;
+    g_v[i] = gvi;
+
+    const float mh =       gmi*beta1h;
+    const float vh = sqrtf(gvi*beta2h) + eps;
+
+    x[i] = x[i]*(1.0f - alpha*wd) - mh/vh;
+}
+
+static void opt_step_adamw_f32_cuda(
+    float * x, const float * g, float * g_m, float * g_v, const int64_t k,
+    const float alpha, const float beta1, const float beta2, const float eps, const float wd,
+    const float beta1h, const float beta2h, cudaStream_t stream) {
+
+    const dim3 block_dims(CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
+    const dim3 block_nums((k + CUDA_OPT_STEP_ADAMW_BLOCK_SIZE - 1) / CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
+    opt_step_adamw_f32<<<block_nums, block_dims, 0, stream>>>(x, g, g_m, g_v, k, alpha, beta1, beta2, eps, wd, beta1h, beta2h);
+}
+
+void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0        = dst->src[0];
+    const ggml_tensor * src0_grad   = dst->src[1];
+    const ggml_tensor * src0_grad_m = dst->src[2];
+    const ggml_tensor * src0_grad_v = dst->src[3];
+
+    GGML_ASSERT(src0->type        == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad->type   == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad_m->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad_v->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad_m));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad_v));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
+
+    float       * src0_d        = (float       *) src0->data;
+    const float * src0_grad_d   = (const float *) src0_grad->data;
+    float       * src0_grad_m_d = (float       *) src0_grad_m->data;
+    float       * src0_grad_v_d = (float       *) src0_grad_v->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne = ggml_nelements(src0);
+
+    int64_t iter; memcpy(&iter,  &dst->op_params[0], sizeof(int64_t));
+    float alpha;  memcpy(&alpha, &dst->op_params[2], sizeof(float));
+    float beta1;  memcpy(&beta1, &dst->op_params[3], sizeof(float));
+    float beta2;  memcpy(&beta2, &dst->op_params[4], sizeof(float));
+    float eps;    memcpy(&eps,   &dst->op_params[5], sizeof(float));
+    float wd;     memcpy(&wd,    &dst->op_params[6], sizeof(float));
+
+    const float beta1h = alpha/(1.0f - powf(beta1, iter));
+    const float beta2h =  1.0f/(1.0f - powf(beta2, iter));
+
+    opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, ne, alpha, beta1, beta2, eps, wd, beta1h, beta2h, stream);
+
+    iter++;
+    memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
+}
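For reference, the kernel folds the AdamW bias corrections into the host-side scalars beta1h = alpha/(1 − beta1^iter) and beta2h = 1/(1 − beta2^iter), so the per-element update above is equivalent to the textbook form of the referenced paper:

m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2,
\hat m_t = \frac{m_t}{1-\beta_1^{\,t}}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^{\,t}},
x_t = x_{t-1}\,(1 - \alpha\,wd) - \frac{\alpha\,\hat m_t}{\sqrt{\hat v_t} + \epsilon}.

The decoupled weight-decay factor (1 − α·wd) is applied to the parameter directly, not mixed into the gradient, and the iteration counter stored in op_params[0] is advanced after every launch.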
ggml/src/ggml-cuda/opt-step-adamw.cuh (new file)

@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_OPT_STEP_ADAMW_BLOCK_SIZE 256
+
+void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/out-prod.cu (new file)

@@ -0,0 +1,52 @@
+#include "out-prod.cuh"
+#include "vendors/cuda.h"
+
+#include <cstdint>
+
+void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_ASSERT(ne01 == ne11);
+    GGML_ASSERT(ne0  == ne00);
+    GGML_ASSERT(ne1  == ne10);
+
+    GGML_ASSERT(ne2 == src0->ne[2]);
+    GGML_ASSERT(ne2 == src1->ne[2]);
+    GGML_ASSERT(ne3 == src0->ne[3]);
+    GGML_ASSERT(ne3 == src1->ne[3]);
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float * dst_d = (float *) dst->data;
+
+    cudaStream_t   stream = ctx.stream();
+    cublasHandle_t handle = ctx.cublas_handle();
+
+    const float alpha = 1.0f;
+    const float beta  = 0.0f;
+
+    GGML_ASSERT(ne2 == 1);
+    GGML_ASSERT(ne3 == 1);
+    CUBLAS_CHECK(cublasSetStream(handle, stream));
+
+    const bool src1_T = ggml_is_transposed(src1);
+    const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
+    const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
+    GGML_ASSERT(      (src1_T ? nb11 : nb10) == sizeof(float));
+
+    CUBLAS_CHECK(
+        cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
+                ne0, ne1, ne01,
+                &alpha, src0_d, ne00,
+                        src1_d, ldb,
+                &beta,  dst_d,  ne0));
+}
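Reading the cublasSgemm arguments back into ggml's [ne0, ne1] index convention (and taking the non-transposed src1 path, where CUBLAS_OP_T is used), the call accumulates ne01 column outer products, which is the shape that appears in the backward pass of ggml's matrix multiplication:

\mathrm{dst}[i_0, i_1] \;=\; \sum_{k=0}^{\mathrm{ne01}-1} \mathrm{src0}[i_0, k]\;\mathrm{src1}[i_1, k], \qquad 0 \le i_0 < \mathrm{ne00},\ 0 \le i_1 < \mathrm{ne10}.

The ne2 == 1 and ne3 == 1 asserts make explicit that only unbatched 2D F32 tensors are handled here, matching the GGML_OP_OUT_PROD condition added to ggml_backend_cuda_supports_op.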
ggml/src/ggml-cuda/out-prod.cuh (new file)

@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/unary.cu

@@ -10,6 +10,16 @@ static __global__ void neg_f32(const float * x, float * dst, const int k) {
     dst[i] = -x[i];
 }

+static __global__ void step_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = x[i] > 0.0f;
+}
+
 static __global__ void gelu_f32(const float * x, float * dst, const int k) {
     const float GELU_COEF_A = 0.044715f;
     const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

@@ -134,6 +144,11 @@ static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
     neg_f32<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }

+static void step_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_STEP_BLOCK_SIZE - 1) / CUDA_STEP_BLOCK_SIZE;
+    step_f32<<<num_blocks, CUDA_STEP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
     gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);

@@ -213,6 +228,20 @@ void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }

+void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    step_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
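step_f32 is the elementwise Heaviside step,

\mathrm{step}(x) = \begin{cases} 1 & x > 0 \\ 0 & x \le 0, \end{cases}

which backward passes of piecewise-linear activations such as ReLU rely on; exposing it as GGML_UNARY_OP_STEP on the CUDA backend is presumably what lets those backward graphs stay on the GPU instead of falling back to the CPU.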
ggml/src/ggml-cuda/unary.cuh

@@ -1,6 +1,7 @@
 #include "common.cuh"

 #define CUDA_NEG_BLOCK_SIZE 256
+#define CUDA_STEP_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_TANH_BLOCK_SIZE 256

@@ -15,6 +16,8 @@

 void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

+void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

 void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-kompute.cpp

@@ -1872,6 +1872,7 @@ static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
     /* .free_buffer   = */ ggml_backend_kompute_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_kompute_buffer_get_base,
     /* .init_tensor   = */ NULL,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ ggml_backend_kompute_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_kompute_buffer_get_tensor,
     /* .cpy_tensor    = */ NULL,
ggml/src/ggml-metal.m

@@ -3165,6 +3165,7 @@ static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
     /* .free_buffer   = */ ggml_backend_metal_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_metal_buffer_get_base,
     /* .init_tensor   = */ NULL,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ ggml_backend_metal_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_metal_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_metal_buffer_cpy_tensor,
ggml/src/ggml-rpc.cpp

@@ -469,6 +469,7 @@ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
     /* .free_buffer   = */ ggml_backend_rpc_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_rpc_buffer_get_base,
     /* .init_tensor   = */ ggml_backend_rpc_buffer_init_tensor,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ ggml_backend_rpc_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_rpc_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_rpc_buffer_cpy_tensor,
ggml/src/ggml-sycl.cpp

@@ -4318,6 +4318,7 @@ static struct ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
     /* .free_buffer   = */ ggml_backend_sycl_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_sycl_buffer_get_base,
     /* .init_tensor   = */ ggml_backend_sycl_buffer_init_tensor,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ ggml_backend_sycl_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_sycl_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_sycl_buffer_cpy_tensor,
ggml/src/ggml-vulkan.cpp

@@ -6221,6 +6221,7 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
     /* .free_buffer   = */ ggml_backend_vk_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_vk_buffer_get_base,
     /* .init_tensor   = */ ggml_backend_vk_buffer_init_tensor,
+    /* .memset_tensor = */ NULL,
     /* .set_tensor    = */ ggml_backend_vk_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_vk_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_vk_buffer_cpy_tensor,
@@ -1,6 +1,7 @@
|
|
| 1 |
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
|
| 2 |
#define _USE_MATH_DEFINES // For M_PI on MSVC
|
| 3 |
|
|
|
|
| 4 |
#include "ggml-impl.h"
|
| 5 |
#include "ggml-quants.h"
|
| 6 |
#include "ggml.h"
|
|
@@ -2977,9 +2978,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
| 2977 |
|
| 2978 |
"CROSS_ENTROPY_LOSS",
|
| 2979 |
"CROSS_ENTROPY_LOSS_BACK",
|
|
|
|
| 2980 |
};
|
| 2981 |
|
| 2982 |
-
static_assert(GGML_OP_COUNT ==
|
| 2983 |
|
| 2984 |
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
| 2985 |
"none",
|
|
@@ -3070,9 +3072,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
| 3070 |
|
| 3071 |
"cross_entropy_loss(x,y)",
|
| 3072 |
"cross_entropy_loss_back(x,y)",
|
|
|
|
| 3073 |
};
|
| 3074 |
|
| 3075 |
-
static_assert(GGML_OP_COUNT ==
|
| 3076 |
|
| 3077 |
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
| 3078 |
|
|
@@ -4079,7 +4082,11 @@ static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, floa
|
|
| 4079 |
}
|
| 4080 |
|
| 4081 |
struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
|
| 4082 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4083 |
return tensor;
|
| 4084 |
}
|
| 4085 |
|
|
@@ -8305,11 +8312,46 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
|
|
| 8305 |
return result;
|
| 8306 |
}
|
| 8307 |
|
| 8308 |
-
|
| 8309 |
|
| 8310 |
-
|
| 8311 |
struct ggml_context * ctx,
|
| 8312 |
-
struct ggml_tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8313 |
tensor->flags |= GGML_TENSOR_FLAG_PARAM;
|
| 8314 |
|
| 8315 |
GGML_ASSERT(tensor->grad == NULL);
|
|
@@ -8317,6 +8359,13 @@ void ggml_set_param(
|
|
| 8317 |
ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
|
| 8318 |
}
|
| 8319 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8320 |
// ggml_compute_forward_dup
|
| 8321 |
|
| 8322 |
static void ggml_compute_forward_dup_same_cont(
|
|
@@ -17391,7 +17440,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
     const int64_t ir0 = dr*ith;
     const int64_t ir1 = MIN(ir0 + dr, nr);

-    float ...

     for (int64_t i1 = ir0; i1 < ir1; i1++) {
         float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);

@@ -17415,7 +17464,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
     // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
     ggml_vec_sub_f32(nc, ds0, ds0, s1);
-    ggml_vec_scale_f32(nc, ds0, ...

 #ifndef NDEBUG
     for (int i = 0; i < nc; ++i) {

@@ -17444,6 +17493,94 @@ static void ggml_compute_forward_cross_entropy_loss_back(
     }
 }

 /////////////////////////////////

 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -17789,6 +17926,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_cross_entropy_loss_back(params, tensor);
             }
             break;
         case GGML_OP_NONE:
             {
                 // nop

@@ -17943,7 +18085,7 @@ void ggml_build_backward_gradient_checkpointing(
         struct ggml_tensor  * * checkpoints,
         int n_checkpoints) {
     ggml_graph_cpy(gf, gb_tmp);
-    ggml_build_backward_expand(ctx, gf, gb_tmp, true);

     if (n_checkpoints <= 0) {
         ggml_graph_cpy(gb_tmp, gb);
@@ -17981,42 +18123,93 @@ void ggml_build_backward_gradient_checkpointing(
     ggml_hash_map_free(replacements);
 }

-// functions to change gradients
-...
-...
     if (ggml_hash_contains(zero_table, a)) {
         return b;
-    } else {
-        return ggml_add_impl(ctx, a, b, false);
     }
 }

-static struct ggml_tensor * ggml_acc_or_set(
     if (ggml_hash_contains(zero_table, a)) {
-        struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f);
         return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
-    } else {
-        return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
     }
 }

-static struct ggml_tensor * ggml_add1_or_set(
     if (ggml_hash_contains(zero_table, a)) {
         return ggml_repeat(ctx, b, a);
-    } else {
-        return ggml_add1_impl(ctx, a, b, false);
     }
 }

-static struct ggml_tensor * ggml_sub_or_set(
     if (ggml_hash_contains(zero_table, a)) {
         return ggml_neg(ctx, b);
-    } else {
-        return ggml_sub_impl(ctx, a, b, false);
     }
 }

-static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table) {
     struct ggml_tensor * src0 = tensor->src[0];
     struct ggml_tensor * src1 = tensor->src[1];
     struct ggml_tensor * src2 = tensor->src[2];
@@ -18025,38 +18218,38 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_DUP:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
             } break;
         case GGML_OP_ADD:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
                 if (src1->grad) {
                     if (ggml_are_same_shape(src0, src1)) {
-                        src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
                     } else {
-                        src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table);
                     }
                 }
             } break;
         case GGML_OP_ADD1:

(The remaining hunks of this block, @@ -18078 through @@ -18848, remove the same kind of line from every other op case of ggml_compute_backward: each deleted line is a gradient update ending in `zero_table);` that the corresponding hunk in the new version further below rewrites to pass `zero_table, acc_table);`. The surrounding gradient expressions are unchanged.)
@@ -18878,13 +19071,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             src0,
                             src1,
                             tensor->grad),
-                        zero_table);
                 }
             } break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             {
                 GGML_ABORT("fatal error"); // not supported
             }
         case GGML_OP_NONE:
             {
                 // nop
@@ -18974,7 +19171,7 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
     ggml_build_forward_impl(cgraph, tensor, true);
 }

-void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) {
     GGML_ASSERT(gf->n_nodes > 0);
     GGML_ASSERT(gf->grads);

@@ -18990,21 +19187,35 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
         }
     }

-    // ...
     struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
     for (int i = 0; i < gf->n_nodes; i++) {
-        ...
-        ...
         }
     }

     for (int i = gf->n_nodes - 1; i >= 0; i--) {
         struct ggml_tensor * node = gf->nodes[i];

-        // inplace operations to add gradients are not created by ggml_compute_backward
         // use allocator to automatically make inplace operations
         if (node->grad) {
-            ggml_compute_backward(ctx, node, &zero_table);
         }
     }
@@ -19018,8 +19229,30 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
     }

     ggml_hash_set_free(&zero_table);
 }

 static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
     void * ptr = *p;
     ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
@@ -19147,10 +19380,28 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
     GGML_ASSERT(cgraph->grads != NULL);

     for (int i = 0; i < cgraph->n_nodes; i++) {
-        struct ggml_tensor * ...
-        ...
-        ...
         }
     }
 }
@@ -19415,6 +19666,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             {
                 n_tasks = n_threads;
             } break;
@@ -21777,7 +22029,7 @@ enum ggml_opt_result ggml_opt_resume(
     ggml_build_forward_expand(gf, f);

     struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
-    ggml_build_backward_expand(ctx, gf, gb, true);

     return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
 }
 #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
 #define _USE_MATH_DEFINES // For M_PI on MSVC

+#include "ggml-backend.h"
 #include "ggml-impl.h"
 #include "ggml-quants.h"
 #include "ggml.h"
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
+    "OPT_STEP_ADAMW",
 };

+static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",

     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
+    "adamw(x)",
 };

+static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 }

 struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
+    if (tensor->buffer) {
+        ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor));
+    } else {
+        memset(tensor->data, 0, ggml_nbytes(tensor));
+    }
     return tensor;
 }
|
| 8312 |
return result;
|
| 8313 |
}
|
| 8314 |
|
| 8315 |
+
// opt_step_adamw
|
| 8316 |
|
| 8317 |
+
struct ggml_tensor * ggml_opt_step_adamw(
|
| 8318 |
struct ggml_context * ctx,
|
| 8319 |
+
struct ggml_tensor * a,
|
| 8320 |
+
float alpha,
|
| 8321 |
+
float beta1,
|
| 8322 |
+
float beta2,
|
| 8323 |
+
float eps,
|
| 8324 |
+
float wd) {
|
| 8325 |
+
GGML_ASSERT(a->grad);
|
| 8326 |
+
GGML_ASSERT(alpha > 0.0f);
|
| 8327 |
+
GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
|
| 8328 |
+
GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
|
| 8329 |
+
GGML_ASSERT(eps >= 0.0f);
|
| 8330 |
+
GGML_ASSERT(wd >= 0.0f && wd <= 1.0f);
|
| 8331 |
+
|
| 8332 |
+
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
| 8333 |
+
|
| 8334 |
+
result->op = GGML_OP_OPT_STEP_ADAMW;
|
| 8335 |
+
result->grad = NULL;
|
| 8336 |
+
result->src[0] = a;
|
| 8337 |
+
result->src[1] = a->grad;
|
| 8338 |
+
result->src[2] = ggml_dup_tensor(ctx, a->grad);
|
| 8339 |
+
result->src[3] = ggml_dup_tensor(ctx, a->grad);
|
| 8340 |
+
|
| 8341 |
+
const int64_t iter = 1;
|
| 8342 |
+
memcpy(&result->op_params[0], &iter, sizeof(int64_t));
|
| 8343 |
+
ggml_set_op_params_f32(result, 2, alpha);
|
| 8344 |
+
ggml_set_op_params_f32(result, 3, beta1);
|
| 8345 |
+
ggml_set_op_params_f32(result, 4, beta2);
|
| 8346 |
+
ggml_set_op_params_f32(result, 5, eps);
|
| 8347 |
+
ggml_set_op_params_f32(result, 6, wd);
|
| 8348 |
+
|
| 8349 |
+
return result;
|
| 8350 |
+
}
|
| 8351 |
+
|
| 8352 |
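For reference, this op packs one full AdamW update into a single graph node: src[0] is the parameter, src[1] its gradient, and src[2]/src[3] the first and second moment buffers. In standard notation (mine, not part of the patch), the step that the CPU kernel further below performs at iteration t is:

m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t
v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2
\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}
w_t = (1 - \alpha\,\mathrm{wd})\, w_{t-1} - \alpha\,\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}

Here wd is the decoupled weight decay of the AdamW paper cited in the kernel comment, applied directly to the weight rather than folded into the gradient.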
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
     tensor->flags |= GGML_TENSOR_FLAG_PARAM;

     GGML_ASSERT(tensor->grad == NULL);

     ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
 }

+void ggml_set_loss(struct ggml_tensor * tensor) {
+    GGML_ASSERT(ggml_is_scalar(tensor));
+    GGML_ASSERT(tensor->type == GGML_TYPE_F32);
+    GGML_ASSERT(tensor->grad);
+    tensor->flags |= GGML_TENSOR_FLAG_LOSS;
+}
+
 // ggml_compute_forward_dup

 static void ggml_compute_forward_dup_same_cont(
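A minimal sketch of how these two markers are meant to be used when assembling a training graph. The tensor names, the sizes, and the assumption that `logits` was built from parameters marked earlier are illustrative, not taken from this commit:

// mark a weight as trainable; ggml_set_param attaches a gradient tensor to it
struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_input, n_output);
ggml_set_param(ctx, w);

// ... build the model so that `logits` depends on w ...

// the loss is a scalar f32; because it depends on a parameter it already has a
// gradient tensor, which is what ggml_set_loss asserts before setting the flag
struct ggml_tensor * loss = ggml_cross_entropy_loss(ctx, logits, labels);
ggml_set_loss(loss);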
     const int64_t ir0 = dr*ith;
     const int64_t ir1 = MIN(ir0 + dr, nr);

+    const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;

     for (int64_t i1 = ir0; i1 < ir1; i1++) {
         float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);

     // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
     ggml_vec_sub_f32(nc, ds0, ds0, s1);
+    ggml_vec_scale_f32(nc, ds0, d_by_nr);

 #ifndef NDEBUG
     for (int i = 0; i < nc; ++i) {
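Spelled out, the subtraction and scaling above implement the usual softmax cross-entropy gradient, averaged over the rows and scaled by the incoming gradient of the scalar loss (notation mine; d is the value read from opt0, i.e. the gradient of the cross-entropy loss itself):

\frac{\partial L}{\partial x_{ij}} = \big(\mathrm{softmax}(x_i)_j - y_{ij}\big)\,\frac{d}{n_\mathrm{rows}}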
     }
 }

+static void ggml_compute_forward_opt_step_adamw_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0        = dst->src[0];
+    const struct ggml_tensor * src0_grad   = dst->src[1];
+    const struct ggml_tensor * src0_grad_m = dst->src[2];
+    const struct ggml_tensor * src0_grad_v = dst->src[3];
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    /* const float gnorm = 1.0f; */
+    int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t));
+    const float alpha = ggml_get_op_params_f32(dst, 2);
+    const float beta1 = ggml_get_op_params_f32(dst, 3);
+    const float beta2 = ggml_get_op_params_f32(dst, 4);
+    const float eps   = ggml_get_op_params_f32(dst, 5);
+    const float wd    = ggml_get_op_params_f32(dst, 6);
+
+    const float beta1h = alpha/(1.0f - powf(beta1, iter));
+    const float beta2h =  1.0f/(1.0f - powf(beta2, iter));
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
+
+        float       * w = (float       *) ((char       *) src0->data        + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data   + offset); // grad
+        float       * m = (float       *) ((char       *) src0_grad_m->data + offset);
+        float       * v = (float       *) ((char       *) src0_grad_v->data + offset);
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            m[i00] = m[i00]*beta1 +        g[i00]*(1.0f - beta1);
+            v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
+
+            const float mh =       m[i00]*beta1h;
+            const float vh = sqrtf(v[i00]*beta2h) + eps;
+
+            // The weight decay is applied independently of the Adam momenta m and v.
+            // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
+            // See: https://arxiv.org/pdf/1711.05101v3.pdf
+            w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh;
+        }
+    }
+
+    ggml_barrier(params->threadpool);
+    if (ith != 0) {
+        return;
+    }
+
+    iter++;
+    memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
+}
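As a quick sanity check of the bias-corrected step above, with made-up numbers that are not part of the patch: on the first iteration (iter = 1), with beta1 = 0.9, beta2 = 0.999, wd = 0 and a fresh gradient g, the momenta start from zero, so:

m_1 = 0.1\,g, \qquad v_1 = 0.001\,g^2
\hat{m}_1 = \frac{m_1}{1-0.9} = g, \qquad \hat{v}_1 = \frac{v_1}{1-0.999} = g^2
\Delta w = -\alpha\,\frac{\hat{m}_1}{\sqrt{\hat{v}_1}+\epsilon} = -\alpha\,\frac{g}{|g|+\epsilon} \approx -\alpha\,\operatorname{sign}(g)

The very first update therefore has magnitude close to alpha regardless of the gradient's scale, which is the expected Adam behaviour.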
+
+static void ggml_compute_forward_opt_step_adamw(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_adamw_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
 /////////////////////////////////

 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
                 ggml_compute_forward_cross_entropy_loss_back(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_ADAMW:
+            {
+                ggml_compute_forward_opt_step_adamw(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop

         struct ggml_tensor  * * checkpoints,
         int n_checkpoints) {
     ggml_graph_cpy(gf, gb_tmp);
+    ggml_build_backward_expand(ctx, gf, gb_tmp, false, true);

     if (n_checkpoints <= 0) {
         ggml_graph_cpy(gb_tmp, gb);
     ggml_hash_map_free(replacements);
 }

+// utility functions to change gradients
+// if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
+// else if a is in zero_table, replace a
+// else, just add/subtract/etc. the gradients
+
+static struct ggml_tensor * ggml_add_or_set(
+        struct ggml_context  * ctx,
+        struct ggml_tensor   * a,
+        struct ggml_tensor   * b,
+        struct ggml_hash_set * zero_table,
+        struct ggml_hash_set * acc_table) {
+    if (ggml_hash_contains(acc_table, a)) {
+        struct ggml_tensor * ret = ggml_add_impl(ctx, a, b, true);
+        const size_t insert_result = ggml_hash_insert(acc_table, ret);
+        GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+        GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+        return ret;
+    }
     if (ggml_hash_contains(zero_table, a)) {
         return b;
     }
+    return ggml_add_impl(ctx, a, b, false);
 }

+static struct ggml_tensor * ggml_acc_or_set(
+        struct ggml_context  * ctx,
+        struct ggml_tensor   * a,
+        struct ggml_tensor   * b,
+        const size_t           nb1,
+        const size_t           nb2,
+        const size_t           nb3,
+        const size_t           offset,
+        struct ggml_hash_set * zero_table,
+        struct ggml_hash_set * acc_table) {
+    if (ggml_hash_contains(acc_table, a)) {
+        struct ggml_tensor * ret = ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
+        const size_t insert_result = ggml_hash_insert(acc_table, ret);
+        GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+        GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+        return ret;
+    }
     if (ggml_hash_contains(zero_table, a)) {
+        struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
         return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
     }
+    return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
 }

+static struct ggml_tensor * ggml_add1_or_set(
+        struct ggml_context  * ctx,
+        struct ggml_tensor   * a,
+        struct ggml_tensor   * b,
+        struct ggml_hash_set * zero_table,
+        struct ggml_hash_set * acc_table) {
+    if (ggml_hash_contains(acc_table, a)) {
+        struct ggml_tensor * ret = ggml_add1_impl(ctx, a, b, true);
+        const size_t insert_result = ggml_hash_insert(acc_table, ret);
+        GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+        GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+        return ret;
+    }
     if (ggml_hash_contains(zero_table, a)) {
         return ggml_repeat(ctx, b, a);
     }
+    return ggml_add1_impl(ctx, a, b, false);
 }

+static struct ggml_tensor * ggml_sub_or_set(
+        struct ggml_context  * ctx,
+        struct ggml_tensor   * a,
+        struct ggml_tensor   * b,
+        struct ggml_hash_set * zero_table,
+        struct ggml_hash_set * acc_table) {
+    if (ggml_hash_contains(acc_table, a)) {
+        struct ggml_tensor * ret = ggml_sub_impl(ctx, a, b, true);
+        const size_t insert_result = ggml_hash_insert(acc_table, ret);
+        GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+        GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+        return ret;
+    }
     if (ggml_hash_contains(zero_table, a)) {
         return ggml_neg(ctx, b);
     }
+    return ggml_sub_impl(ctx, a, b, false);
 }

+static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table, struct ggml_hash_set * acc_table) {
     struct ggml_tensor * src0 = tensor->src[0];
     struct ggml_tensor * src1 = tensor->src[1];
     struct ggml_tensor * src2 = tensor->src[2];
         case GGML_OP_DUP:
             {
                 if (src0->grad) {
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                 }
             } break;
         case GGML_OP_ADD:
             {
                 if (src0->grad) {
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                 }
                 if (src1->grad) {
                     if (ggml_are_same_shape(src0, src1)) {
+                        src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
                     } else {
+                        src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table);
                     }
                 }
             } break;

(The same mechanical change repeats through the rest of ggml_compute_backward: in every remaining op case, from GGML_OP_ADD1 up to GGML_OP_CROSS_ENTROPY_LOSS and the GGML_UNARY_OP_* cases, each ggml_add_or_set / ggml_add1_or_set / ggml_acc_or_set / ggml_sub_or_set call now ends in `zero_table, acc_table);` instead of `zero_table);`, while the surrounding gradient expressions are unchanged.)
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             {
                 GGML_ABORT("fatal error"); // not supported
             }
+        case GGML_OP_OPT_STEP_ADAMW:
+            {
+                GGML_ABORT("fatal error"); // not supported
+            }
         case GGML_OP_NONE:
             {
                 // nop
     ggml_build_forward_impl(cgraph, tensor, true);
 }

+void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep) {
     GGML_ASSERT(gf->n_nodes > 0);
     GGML_ASSERT(gf->grads);

         }
     }

+    // keep tables of original gradients for replacement/accumulation logic
     struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
+    struct ggml_hash_set acc_table  = ggml_hash_set_new(gf->size);
     for (int i = 0; i < gf->n_nodes; i++) {
+        struct ggml_tensor * node = gf->nodes[i];
+
+        if (node->grad) {
+            {
+                const size_t insert_result = ggml_hash_insert(&zero_table, node->grad);
+                GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+                GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            }
+
+            // only gradients of trainable parameters should be accumulated
+            if (accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
+                const size_t insert_result = ggml_hash_insert(&acc_table, node->grad);
+                GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+                GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            }
         }
     }

     for (int i = gf->n_nodes - 1; i >= 0; i--) {
         struct ggml_tensor * node = gf->nodes[i];

+        // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
         // use allocator to automatically make inplace operations
         if (node->grad) {
+            ggml_compute_backward(ctx, node, &zero_table, &acc_table);
         }
     }

     }

     ggml_hash_set_free(&zero_table);
+    ggml_hash_set_free(&acc_table);
+}
+
+void ggml_build_opt_adamw(
+        struct ggml_context * ctx,
+        struct ggml_cgraph  * gf,
+        struct ggml_cgraph  * gb,
+        float                 alpha,
+        float                 beta1,
+        float                 beta2,
+        float                 eps,
+        float                 wd) {
+    for (int i = 0; i < gf->n_nodes; i++) {
+        struct ggml_tensor * node = gf->nodes[i];
+
+        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
+            GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
+            struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd);
+            ggml_build_forward_expand(gb, opt_step);
+        }
+    }
 }
+
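Putting the new entry points together, a rough single-step training sketch. Only the functions shown in this diff are taken from the commit; the graph size, the hyperparameters, the `backend` handle, and the omitted allocation step are assumptions for illustration (marking `loss` and the parameters is shown in the snippet after ggml_set_loss above):

// assumes: parameters marked with ggml_set_param(), scalar f32 `loss` marked with ggml_set_loss()
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
ggml_build_forward_expand(gf, loss);

struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
ggml_build_backward_expand(ctx, gf, gb, /*accumulate =*/ false, /*keep =*/ true);

// append one GGML_OP_OPT_STEP_ADAMW node per trainable parameter
ggml_build_opt_adamw(ctx, gf, gb, /*alpha =*/ 1e-3f, /*beta1 =*/ 0.9f, /*beta2 =*/ 0.999f, /*eps =*/ 1e-8f, /*wd =*/ 0.0f);

// (allocating the graph tensors in a backend buffer, e.g. via ggml-alloc, is omitted here)
ggml_graph_reset(gb);                     // zero the gradients, seed the loss gradient with 1, reset the AdamW state
ggml_backend_graph_compute(backend, gb);  // forward + backward + AdamW update in one pass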
 static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
     void * ptr = *p;
     ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
     GGML_ASSERT(cgraph->grads != NULL);

     for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        // initial gradients of loss should be 1, 0 otherwise
+        if (node->grad) {
+            if (node->flags & GGML_TENSOR_FLAG_LOSS) {
+                GGML_ASSERT(node->grad->buffer);
+                GGML_ASSERT(node->type == GGML_TYPE_F32);
+                GGML_ASSERT(ggml_is_scalar(node));
+
+                const float onef = 1.0f;
+                ggml_backend_tensor_set(node->grad, &onef, 0, ggml_nbytes(node->grad));
+            } else {
+                ggml_set_zero(node->grad);
+            }
+        }
+
+        GGML_ASSERT(node);
+        if (node->op == GGML_OP_OPT_STEP_ADAMW) {
+            // set iteration to 1 and clear momenta
+            ggml_set_op_params_i32(node, 0, 1);
+            ggml_set_zero(node->src[2]);
+            ggml_set_zero(node->src[3]);
         }
     }
 }
             } break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+        case GGML_OP_OPT_STEP_ADAMW:
             {
                 n_tasks = n_threads;
             } break;
     ggml_build_forward_expand(gf, f);

     struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
+    ggml_build_backward_expand(ctx, gf, gb, false, true);

     return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
 }