Sigbjørn Skjæret committed
Commit f798922 · 1 parent: 4434043

ggml : implement GEGLU_ERF and GEGLU_QUICK ops (llama/14445)
Files changed:
- ggml/include/ggml.h +28 -0
- ggml/src/ggml-cpu/ggml-cpu.c +2 -0
- ggml/src/ggml-cpu/ops.cpp +294 -0
- ggml/src/ggml-cpu/vec.h +40 -0
- ggml/src/ggml-cuda/ggml-cuda.cu +8 -0
- ggml/src/ggml-cuda/unary.cu +8 -0
- ggml/src/ggml-cuda/unary.cuh +4 -0
- ggml/src/ggml-metal/ggml-metal.m +12 -0
- ggml/src/ggml-metal/ggml-metal.metal +44 -0
- ggml/src/ggml-opencl/ggml-opencl.cpp +28 -8
- ggml/src/ggml-opencl/kernels/glu.cl +136 -0
- ggml/src/ggml-sycl/element_wise.cpp +50 -0
- ggml/src/ggml-sycl/element_wise.hpp +2 -0
- ggml/src/ggml-sycl/ggml-sycl.cpp +8 -0
- ggml/src/ggml-vulkan/ggml-vulkan.cpp +16 -0
- ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +4 -0
- ggml/src/ggml.c +45 -1
ggml/include/ggml.h
CHANGED

@@ -557,6 +557,8 @@ extern "C" {
         GGML_GLU_OP_REGLU,
         GGML_GLU_OP_GEGLU,
         GGML_GLU_OP_SWIGLU,
+        GGML_GLU_OP_GEGLU_ERF,
+        GGML_GLU_OP_GEGLU_QUICK,
 
         GGML_GLU_OP_COUNT,
     };

@@ -1144,6 +1146,22 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_geglu_erf(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_erf_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // A: n columns, r rows,
     // B: n columns, r rows,
     GGML_API struct ggml_tensor * ggml_glu_split(

@@ -1167,6 +1185,16 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_geglu_erf_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
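Note on the new API: the single-tensor variants split `a` in half along the first dimension (the gate lives in the second half, or the first half for the `_swapped` forms), while the `_split` variants take the gate as a separate tensor `b` — mirroring the existing reglu/geglu/swiglu helpers. A minimal usage sketch in C (not part of the commit; the context setup and shapes here are illustrative assumptions):

    // assumes an initialized ggml_context * ctx
    // x: 64 columns, 4 rows -> ggml_geglu_erf halves dim 0: out has 32 columns
    struct ggml_tensor * x   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 4);
    struct ggml_tensor * out = ggml_geglu_erf(ctx, x);

    // equivalent formulation with an explicit gate tensor of matching shape
    struct ggml_tensor * a    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 32, 4);
    struct ggml_tensor * g    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 32, 4);
    struct ggml_tensor * out2 = ggml_geglu_quick_split(ctx, a, g);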
ggml/src/ggml-cpu/ggml-cpu.c
CHANGED

@@ -2172,6 +2172,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     {
                         n_tasks = n_threads;
                     } break;
ggml/src/ggml-cpu/ops.cpp
CHANGED

@@ -3614,6 +3614,292 @@ static void ggml_compute_forward_swiglu(
     }
 }
 
+// ggml_compute_forward_geglu_erf
+
+static void ggml_compute_forward_geglu_erf_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_erf_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_erf(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_erf_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_erf_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_geglu_quick
+
+static void ggml_compute_forward_geglu_quick_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_quick_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_quick(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_quick_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_quick_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_norm
 
 static void ggml_compute_forward_norm_f32(

@@ -8779,6 +9065,14 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
+        case GGML_GLU_OP_GEGLU_ERF:
+            {
+                ggml_compute_forward_geglu_erf(params, dst);
+            } break;
+        case GGML_GLU_OP_GEGLU_QUICK:
+            {
+                ggml_compute_forward_geglu_quick(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
ggml/src/ggml-cpu/vec.h
CHANGED

@@ -959,6 +959,46 @@ inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_
     }
 }
 
+inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
+    for (int i = 0; i < n; ++i) {
+        float xi = x[i];
+        y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
+    }
+}
+
+inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
+    for (int i = 0; i < n; ++i) {
+        float xi = GGML_CPU_FP16_TO_FP32(x[i]);
+        float gi = GGML_CPU_FP16_TO_FP32(g[i]);
+        y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);
+    }
+}
+
+#ifdef GGML_GELU_QUICK_FP16
+inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
+        memcpy(&t, &fp16, sizeof(uint16_t));
+        y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]) * g[i];
+    }
+}
+#else
+inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_quick_f32(x[i]) * g[i];
+    }
+}
+#endif
+
+inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_CPU_FP16_TO_FP32(g[i]);
+        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v);
+    }
+}
+
 inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
     ggml_float sum = 0.0;
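For reference, the scalar math these vector helpers implement, as a self-contained C sketch (plain-float reference only; the in-tree code above also covers F16 and the optional GGML_GELU_QUICK_FP16 lookup-table path):

    #include <math.h>

    #define SQRT_2_INV      0.70710678118654752440f   /* 1/sqrt(2) */
    #define GELU_QUICK_COEF -1.702f

    /* GEGLU_ERF: exact (erf-based) GELU of x, scaled by the gate g */
    static float geglu_erf_ref(float x, float g) {
        return 0.5f * x * (1.0f + erff(x * SQRT_2_INV)) * g;
    }

    /* GEGLU_QUICK: sigmoid-based GELU approximation, scaled by the gate g */
    static float geglu_quick_ref(float x, float g) {
        return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x))) * g;
    }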
ggml/src/ggml-cuda/ggml-cuda.cu
CHANGED

@@ -2314,6 +2314,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_GLU_OP_SWIGLU:
                     ggml_cuda_op_swiglu(ctx, dst);
                     break;
+                case GGML_GLU_OP_GEGLU_ERF:
+                    ggml_cuda_op_geglu_erf(ctx, dst);
+                    break;
+                case GGML_GLU_OP_GEGLU_QUICK:
+                    ggml_cuda_op_geglu_quick(ctx, dst);
+                    break;
                 default:
                     return false;
             }

@@ -3116,6 +3122,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]);
                 default:
                     return false;
ggml/src/ggml-cuda/unary.cu
CHANGED

@@ -285,6 +285,14 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
 }
 
+void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_gelu_erf>(ctx, dst);
+}
+
+void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
+}
+
 /* silu_back */
 
 static __device__ __forceinline__ float op_silu_back(float grad, float x) {
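Both wrappers are one-liners because the CUDA backend already has a gated-unary template: ggml_cuda_op_unary_gated applies a scalar functor (here the pre-existing op_gelu_erf / op_gelu_quick used by the plain GELU ops) to the activation half and multiplies by the gate half. Conceptually it reduces to something like the following simplified CUDA sketch (an illustration only, not the in-tree kernel, which also handles strides, the swapped flag, and the split two-tensor layout):

    template <float (*Op)(float)>
    static __global__ void unary_gated_sketch(const float * x, const float * g, float * dst, const int k) {
        const int i = blockDim.x*blockIdx.x + threadIdx.x;
        if (i >= k) {
            return;
        }
        dst[i] = Op(x[i]) * g[i];   // activation(x) scaled by the gate
    }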
ggml/src/ggml-cuda/unary.cuh
CHANGED

@@ -64,3 +64,7 @@ void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-metal/ggml-metal.m
CHANGED

@@ -530,6 +530,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_REGLU,
     GGML_METAL_KERNEL_TYPE_GEGLU,
     GGML_METAL_KERNEL_TYPE_SWIGLU,
+    GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
+    GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,

@@ -1510,6 +1512,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU,       reglu,       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU,       geglu,       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU,      swiglu,      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF,   geglu_erf,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,    sum_rows,    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN,        mean,        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,      argmax,      true);

@@ -1693,6 +1697,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
                 default:
                     return false;

@@ -2456,6 +2462,12 @@ static bool ggml_metal_encode_node(
                 case GGML_GLU_OP_SWIGLU:
                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
                     break;
+                case GGML_GLU_OP_GEGLU_ERF:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
+                    break;
+                case GGML_GLU_OP_GEGLU_QUICK:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline;
+                    break;
                 default:
                     GGML_ABORT("fatal error");
             }
ggml/src/ggml-metal/ggml-metal.metal
CHANGED

@@ -1258,6 +1258,50 @@ kernel void kernel_swiglu(
     }
 }
 
+kernel void kernel_geglu_erf(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
+
+        dst_row[i0] = gelu_erf*x1;
+    }
+}
+
+kernel void kernel_geglu_quick(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
+
+        dst_row[i0] = gelu_quick*x1;
+    }
+}
+
 template <bool norm>
 kernel void kernel_sum_rows(
         constant ggml_metal_kargs_sum_rows & args,
ggml/src/ggml-opencl/ggml-opencl.cpp
CHANGED

@@ -402,8 +402,8 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_relu;
     cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
     cl_kernel kernel_clamp;
-    cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu,
-              kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16;
+    cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
+              kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
     cl_kernel kernel_norm;
     cl_kernel kernel_rms_norm;
     cl_kernel kernel_group_norm;

@@ -753,12 +753,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_glu =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_geglu           = clCreateKernel(backend_ctx->program_glu, "kernel_geglu",           &err), err));
-        CL_CHECK((backend_ctx->kernel_reglu           = clCreateKernel(backend_ctx->program_glu, "kernel_reglu",           &err), err));
-        CL_CHECK((backend_ctx->kernel_swiglu          = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu",          &err), err));
-        CL_CHECK((backend_ctx->kernel_geglu_f16       = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16",       &err), err));
-        CL_CHECK((backend_ctx->kernel_reglu_f16       = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16",       &err), err));
-        CL_CHECK((backend_ctx->kernel_swiglu_f16      = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16",      &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu           = clCreateKernel(backend_ctx->program_glu, "kernel_geglu",           &err), err));
+        CL_CHECK((backend_ctx->kernel_reglu           = clCreateKernel(backend_ctx->program_glu, "kernel_reglu",           &err), err));
+        CL_CHECK((backend_ctx->kernel_swiglu          = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu",          &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_erf       = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf",       &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_quick     = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick",     &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_f16       = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16",       &err), err));
+        CL_CHECK((backend_ctx->kernel_reglu_f16       = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16",       &err), err));
+        CL_CHECK((backend_ctx->kernel_swiglu_f16      = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16",      &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_erf_f16   = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf_f16",   &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_quick_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick_f16", &err), err));
         GGML_LOG_CONT(".");
     }
 

@@ -2277,6 +2281,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
                 default:
                     return false;

@@ -6254,6 +6260,20 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
                 kernel = backend_ctx->kernel_swiglu_f16;
             }
             break;
+        case GGML_GLU_OP_GEGLU_ERF:
+            if (dst->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_geglu_erf;
+            } else {
+                kernel = backend_ctx->kernel_geglu_erf_f16;
+            }
+            break;
+        case GGML_GLU_OP_GEGLU_QUICK:
+            if (dst->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_geglu_quick;
+            } else {
+                kernel = backend_ctx->kernel_geglu_quick_f16;
+            }
+            break;
         default:
             GGML_ABORT("Unsupported glu op");
     }
ggml/src/ggml-opencl/kernels/glu.cl
CHANGED

@@ -1,7 +1,9 @@
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 
 #define GELU_COEF_A     0.044715f
+#define GELU_QUICK_COEF -1.702f
 #define SQRT_2_OVER_PI  0.79788456080286535587989211986876f
+#define SQRT_2_INV      0.70710678118654752440084436210484f
 
 //------------------------------------------------------------------------------
 // geglu

@@ -199,3 +201,137 @@ kernel void kernel_swiglu_f16(
         dst_row[i0] = silu*x1;
     }
 }
+
+//------------------------------------------------------------------------------
+// geglu_erf
+//------------------------------------------------------------------------------
+kernel void kernel_geglu_erf(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        ulong nb01,
+        ulong nb11,
+        int ne0,
+        ulong nb1,
+        int ne00_off,
+        int ne10_off
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst  = (global char*)((global char*)dst  + offsetd);
+
+    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+    global float * dst_row  = (global float *) ((global char *) dst  + get_group_id(0)*nb1);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
+
+        dst_row[i0] = gelu_erf*x1;
+    }
+}
+
+kernel void kernel_geglu_erf_f16(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        ulong nb01,
+        ulong nb11,
+        int ne0,
+        ulong nb1,
+        int ne00_off,
+        int ne10_off
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst  = (global char*)((global char*)dst  + offsetd);
+
+    global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+    global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+    global half * dst_row  = (global half *) ((global char *) dst  + get_group_id(0)*nb1);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const half x0 = src0_row[i0];
+        const half x1 = src1_row[i0];
+
+        const half gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
+
+        dst_row[i0] = gelu_erf*x1;
+    }
+}
+
+//------------------------------------------------------------------------------
+// geglu_quick
+//------------------------------------------------------------------------------
+kernel void kernel_geglu_quick(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        ulong nb01,
+        ulong nb11,
+        int ne0,
+        ulong nb1,
+        int ne00_off,
+        int ne10_off
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst  = (global char*)((global char*)dst  + offsetd);
+
+    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+    global float * dst_row  = (global float *) ((global char *) dst  + get_group_id(0)*nb1);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
+
+        dst_row[i0] = gelu_quick*x1;
+    }
+}
+
+kernel void kernel_geglu_quick_f16(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        ulong nb01,
+        ulong nb11,
+        int ne0,
+        ulong nb1,
+        int ne00_off,
+        int ne10_off
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst  = (global char*)((global char*)dst  + offsetd);
+
+    global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+    global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+    global half * dst_row  = (global half *) ((global char *) dst  + get_group_id(0)*nb1);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const half x0 = src0_row[i0];
+        const half x1 = src1_row[i0];
+
+        const half gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
+
+        dst_row[i0] = gelu_quick*x1;
+    }
+}
ggml/src/ggml-sycl/element_wise.cpp
CHANGED

@@ -383,6 +383,24 @@ static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint6
     }
 }
 
+template<typename T>
+static void gated_op_fused_geglu_erf(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        const int64_t j0 = (i / n) * o0 + (i % n);
+        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+        dst[i] = op_gelu_erf(x[j0]) * g[j1];
+    }
+}
+
+template<typename T>
+static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        const int64_t j0 = (i / n) * o0 + (i % n);
+        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+        dst[i] = op_gelu_quick(x[j0]) * g[j1];
+    }
+}
+
 namespace ggml_sycl_detail {
 static void acc_f32_sycl(const float *x, const float *y, float *dst,
                          const int n_elements, const int ne10, const int ne11,

@@ -978,6 +996,28 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
         });
 }
 
+static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+            sycl_parallel_for(main_stream,
+                sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                    gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+            sycl_parallel_for(main_stream,
+                sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                    gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+                });
+        });
+}
+
 void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);

@@ -1118,3 +1158,13 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_op_swiglu(ctx, dst);
 }
+
+void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_geglu_erf(ctx, dst);
+}
+
+void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_geglu_quick(ctx, dst);
+}
ggml/src/ggml-sycl/element_wise.hpp
CHANGED

@@ -80,5 +80,7 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 #endif // GGML_SYCL_ELEMENTWISE_HPP
ggml/src/ggml-sycl/ggml-sycl.cpp
CHANGED

@@ -3687,6 +3687,12 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
                 case GGML_GLU_OP_SWIGLU:
                     ggml_sycl_swiglu(ctx, dst);
                     break;
+                case GGML_GLU_OP_GEGLU_ERF:
+                    ggml_sycl_geglu_erf(ctx, dst);
+                    break;
+                case GGML_GLU_OP_GEGLU_QUICK:
+                    ggml_sycl_geglu_quick(ctx, dst);
+                    break;
                 default:
                     return false;
             }

@@ -4232,6 +4238,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]);
                 default:
                     return false;
ggml/src/ggml-vulkan/ggml-vulkan.cpp
CHANGED

@@ -456,6 +456,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_geglu[2];
     vk_pipeline pipeline_reglu[2];
    vk_pipeline pipeline_swiglu[2];
+    vk_pipeline pipeline_geglu_erf[2];
+    vk_pipeline pipeline_geglu_quick[2];
 
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_silu_back_f32;

@@ -2821,6 +2823,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_GLU(geglu)
     CREATE_GLU(reglu)
     CREATE_GLU(swiglu)
+    CREATE_GLU(geglu_erf)
+    CREATE_GLU(geglu_quick)
 #undef CREATE_GLU
 
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

@@ -6575,6 +6579,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
         case GGML_GLU_OP_SWIGLU:
             return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
+        case GGML_GLU_OP_GEGLU_ERF:
+            return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
+        case GGML_GLU_OP_GEGLU_QUICK:
+            return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
         default:
             break;
         }

@@ -8919,6 +8927,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_GEGLU_ERF:
+        case GGML_GLU_OP_GEGLU_QUICK:
             break;
         default:
             return false;

@@ -9166,6 +9176,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_GEGLU_ERF:
+        case GGML_GLU_OP_GEGLU_QUICK:
             ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
             break;
         default:

@@ -9384,6 +9396,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_GEGLU_ERF:
+        case GGML_GLU_OP_GEGLU_QUICK:
             buf = tensor->buffer;
             break;
         default:

@@ -10194,6 +10208,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_GEGLU_ERF:
+        case GGML_GLU_OP_GEGLU_QUICK:
             return ggml_is_contiguous(op->src[0]) &&
                    (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                    (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp
ADDED

@@ -0,0 +1,27 @@
+#version 450
+
+#include "glu_head.comp"
+
+// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
+// ref: https://www.johndcook.com/blog/python_erf/
+const float p_erf  = 0.3275911f;
+const float a1_erf = 0.254829592f;
+const float a2_erf = -0.284496736f;
+const float a3_erf = 1.421413741f;
+const float a4_erf = -1.453152027f;
+const float a5_erf = 1.061405429f;
+
+const float SQRT_2_INV = 0.70710678118654752440084436210484f;
+
+float op(float a, float b) {
+    const float a_div_sqr2 = a * SQRT_2_INV;
+    const float sign_x = sign(a_div_sqr2);
+    const float x = abs(a_div_sqr2);
+    const float t = 1.0f / (1.0f + p_erf * x);
+    const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
+    const float erf_approx = sign_x * y;
+
+    return 0.5f * a * (1.0f + erf_approx) * b;
+}
+
+#include "glu_main.comp"
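GLSL has no built-in erf, so the shader inlines the Abramowitz & Stegun 7.1.26 rational approximation (maximum absolute error around 1.5e-7). A host-side sanity check of the same polynomial against libm's erff (a standalone sketch, not part of the shader build):

    #include <math.h>
    #include <stdio.h>

    static float erf_approx(float x) {
        const float p  = 0.3275911f;
        const float a1 = 0.254829592f, a2 = -0.284496736f, a3 = 1.421413741f;
        const float a4 = -1.453152027f, a5 = 1.061405429f;
        const float s  = x < 0.0f ? -1.0f : 1.0f;   // sign(x)
        const float ax = fabsf(x);
        const float t  = 1.0f / (1.0f + p * ax);
        const float y  = 1.0f - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*expf(-ax*ax);
        return s * y;
    }

    int main(void) {
        for (float x = -3.0f; x <= 3.0f; x += 0.75f) {
            printf("x = % .2f   approx = % .7f   erff = % .7f\n", x, erf_approx(x), erff(x));
        }
        return 0;
    }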
ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp
ADDED

@@ -0,0 +1,11 @@
+#version 450
+
+#include "glu_head.comp"
+
+const float GELU_QUICK_COEF = -1.702f;
+
+float op(float a, float b) {
+    return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b;
+}
+
+#include "glu_main.comp"
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
CHANGED

@@ -593,6 +593,10 @@ void process_shaders() {
     string_to_spv("reglu_f32",  "reglu.comp",  {{"A_TYPE", "float"},     {"D_TYPE", "float"}});
     string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
     string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"},     {"D_TYPE", "float"}});
+    string_to_spv("geglu_erf_f16",  "geglu_erf.comp",   {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("geglu_erf_f32",  "geglu_erf.comp",   {{"A_TYPE", "float"},     {"D_TYPE", "float"}});
+    string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"},     {"D_TYPE", "float"}});
 
     string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("silu_back_f32",  "silu_back.comp",  {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
ggml/src/ggml.c
CHANGED

@@ -1132,9 +1132,11 @@ static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
     "REGLU",
     "GEGLU",
     "SWIGLU",
+    "GEGLU_ERF",
+    "GEGLU_QUICK",
 };
 
-static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3");
+static_assert(GGML_GLU_OP_COUNT == 5, "GGML_GLU_OP_COUNT != 5");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");

@@ -2760,6 +2762,48 @@ struct ggml_tensor * ggml_swiglu_split(
     return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
 }
 
+// ggml_geglu_erf
+
+struct ggml_tensor * ggml_geglu_erf(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, false);
+}
+
+struct ggml_tensor * ggml_geglu_erf_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, true);
+}
+
+struct ggml_tensor * ggml_geglu_erf_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_ERF, false);
+}
+
+// ggml_geglu_quick
+
+struct ggml_tensor * ggml_geglu_quick(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, false);
+}
+
+struct ggml_tensor * ggml_geglu_quick_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, true);
+}
+
+struct ggml_tensor * ggml_geglu_quick_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);
+}
+
 // ggml_norm
 
 static struct ggml_tensor * ggml_norm_impl(
|