Sigbjørn Skjæret committed on

Commit f798922 · 1 Parent(s): 4434043

ggml : implement GEGLU_ERF and GEGLU_QUICK ops (llama/14445)
ggml/include/ggml.h CHANGED
@@ -557,6 +557,8 @@ extern "C" {
         GGML_GLU_OP_REGLU,
         GGML_GLU_OP_GEGLU,
         GGML_GLU_OP_SWIGLU,
+        GGML_GLU_OP_GEGLU_ERF,
+        GGML_GLU_OP_GEGLU_QUICK,

         GGML_GLU_OP_COUNT,
     };
@@ -1144,6 +1146,22 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);

+    GGML_API struct ggml_tensor * ggml_geglu_erf(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_erf_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // A: n columns, r rows,
     // B: n columns, r rows,
     GGML_API struct ggml_tensor * ggml_glu_split(
@@ -1167,6 +1185,16 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);

+    GGML_API struct ggml_tensor * ggml_geglu_erf_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
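
Note: the new entry points mirror the existing REGLU/GEGLU/SWIGLU wrappers. A minimal usage sketch, assuming an existing `ggml_context * ctx` and suitably shaped input tensors (the variable names below are illustrative, not part of the patch):

    // fused layout: each row of `cur` is [activation | gate]; the result has half the row length
    struct ggml_tensor * t0 = ggml_geglu_erf(ctx, cur);            // gate read from the second half of each row
    struct ggml_tensor * t1 = ggml_geglu_quick_swapped(ctx, cur);  // gate read from the first half of each row

    // split layout: activation and gate are separate tensors of identical shape
    struct ggml_tensor * t2 = ggml_geglu_quick_split(ctx, up, gate);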
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -2172,6 +2172,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     {
                         n_tasks = n_threads;
                     } break;
ggml/src/ggml-cpu/ops.cpp CHANGED
@@ -3614,6 +3614,292 @@ static void ggml_compute_forward_swiglu(
     }
 }
 
+// ggml_compute_forward_geglu_erf
+
+static void ggml_compute_forward_geglu_erf_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_erf_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_erf(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_erf_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_erf_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_geglu_quick
+
+static void ggml_compute_forward_geglu_quick_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_quick_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_quick(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_quick_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_quick_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_norm
 
 static void ggml_compute_forward_norm_f32(
@@ -8779,6 +9065,14 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
+        case GGML_GLU_OP_GEGLU_ERF:
+            {
+                ggml_compute_forward_geglu_erf(params, dst);
+            } break;
+        case GGML_GLU_OP_GEGLU_QUICK:
+            {
+                ggml_compute_forward_geglu_quick(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
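
Note on the fused path above: when `src1` is NULL the single input row carries both the activation and the gate, and op param 1 (`swapped`) selects which half is which. A sketch of the addressing, restated outside the kernel for clarity (same assumptions as the patch, names illustrative):

    // row i1 of src0 is [x | g] when swapped == 0, or [g | x] otherwise; each half holds nc = ne[0]/2 values
    const float * row = (const float *) ((const char *) src0->data + i1*src0->nb[1]);
    const float * x   = row + (swapped ? nc : 0);  // activation half
    const float * g   = row + (swapped ? 0  : nc); // gate half
    // dst row i1, element k:  0.5f*x[k]*(1.0f + erff(x[k]*SQRT_2_INV)) * g[k]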
ggml/src/ggml-cpu/vec.h CHANGED
@@ -959,6 +959,46 @@ inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_
     }
 }
 
+inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
+    for (int i = 0; i < n; ++i) {
+        float xi = x[i];
+        y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
+    }
+}
+
+inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
+    for (int i = 0; i < n; ++i) {
+        float xi = GGML_CPU_FP16_TO_FP32(x[i]);
+        float gi = GGML_CPU_FP16_TO_FP32(g[i]);
+        y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);
+    }
+}
+
+#ifdef GGML_GELU_QUICK_FP16
+inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
+        memcpy(&t, &fp16, sizeof(uint16_t));
+        y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]) * g[i];
+    }
+}
+#else
+inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_quick_f32(x[i]) * g[i];
+    }
+}
+#endif
+
+inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_CPU_FP16_TO_FP32(g[i]);
+        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v);
+    }
+}
+
 inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
     ggml_float sum = 0.0;
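
For reference, the scalar functions these vector helpers apply element-wise are the standard exact and "quick" GELU gates (this summary is not text from the patch):

    \mathrm{geglu\_erf}(x, g)   = \tfrac{1}{2}\, x \left(1 + \operatorname{erf}\!\left(x/\sqrt{2}\right)\right) g
    \mathrm{geglu\_quick}(x, g) = x\,\sigma(1.702\,x)\, g = \frac{x}{1 + e^{-1.702\,x}}\, g

When `GGML_GELU_QUICK_FP16` is defined, the f32 quick path reuses the precomputed `ggml_table_gelu_quick_f16` lookup table, trading a little precision for speed; otherwise it falls back to `ggml_gelu_quick_f32`.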
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -2314,6 +2314,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             case GGML_GLU_OP_SWIGLU:
                 ggml_cuda_op_swiglu(ctx, dst);
                 break;
+            case GGML_GLU_OP_GEGLU_ERF:
+                ggml_cuda_op_geglu_erf(ctx, dst);
+                break;
+            case GGML_GLU_OP_GEGLU_QUICK:
+                ggml_cuda_op_geglu_quick(ctx, dst);
+                break;
             default:
                 return false;
         }
@@ -3116,6 +3122,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]);
                 default:
                     return false;
ggml/src/ggml-cuda/unary.cu CHANGED
@@ -285,6 +285,14 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
 }
 
+void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_gelu_erf>(ctx, dst);
+}
+
+void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
+}
+
 /* silu_back */
 
 static __device__ __forceinline__ float op_silu_back(float grad, float x) {
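
The CUDA backend adds no new device code: the gated-unary template is simply instantiated with the `op_gelu_erf` and `op_gelu_quick` device functions already used by the non-gated GELU ops. As a rough sketch (an assumption about those existing helpers, not code from this patch), the per-element gates amount to:

    // gelu_erf(x)   ~ 0.5f * x * (1.0f + erff(x * 0.70710678f))   // exact GELU via erf
    // gelu_quick(x) ~ x / (1.0f + expf(-1.702f * x))              // sigmoid approximation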
ggml/src/ggml-cuda/unary.cuh CHANGED
@@ -64,3 +64,7 @@ void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -530,6 +530,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_REGLU,
     GGML_METAL_KERNEL_TYPE_GEGLU,
     GGML_METAL_KERNEL_TYPE_SWIGLU,
+    GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
+    GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@@ -1510,6 +1512,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU,       reglu,       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU,       geglu,       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU,      swiglu,      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF,   geglu_erf,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,    sum_rows,    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN,        mean,        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,      argmax,      true);
@@ -1693,6 +1697,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             case GGML_GLU_OP_REGLU:
             case GGML_GLU_OP_GEGLU:
            case GGML_GLU_OP_SWIGLU:
+            case GGML_GLU_OP_GEGLU_ERF:
+            case GGML_GLU_OP_GEGLU_QUICK:
                 return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
             default:
                 return false;
@@ -2456,6 +2462,12 @@ static bool ggml_metal_encode_node(
                 case GGML_GLU_OP_SWIGLU:
                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
                     break;
+                case GGML_GLU_OP_GEGLU_ERF:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
+                    break;
+                case GGML_GLU_OP_GEGLU_QUICK:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline;
+                    break;
                 default:
                     GGML_ABORT("fatal error");
             }
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -1258,6 +1258,50 @@ kernel void kernel_swiglu(
     }
 }
 
+kernel void kernel_geglu_erf(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
+
+        dst_row[i0] = gelu_erf*x1;
+    }
+}
+
+kernel void kernel_geglu_quick(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
+
+        dst_row[i0] = gelu_quick*x1;
+    }
+}
+
 template <bool norm>
 kernel void kernel_sum_rows(
         constant ggml_metal_kargs_sum_rows & args,
ggml/src/ggml-opencl/ggml-opencl.cpp CHANGED
@@ -402,8 +402,8 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_relu;
     cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
     cl_kernel kernel_clamp;
-    cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu,
-              kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16;
+    cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
+              kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
     cl_kernel kernel_norm;
     cl_kernel kernel_rms_norm;
     cl_kernel kernel_group_norm;
@@ -753,12 +753,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_glu =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_geglu      = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
-        CL_CHECK((backend_ctx->kernel_reglu      = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
-        CL_CHECK((backend_ctx->kernel_swiglu     = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
-        CL_CHECK((backend_ctx->kernel_geglu_f16  = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
-        CL_CHECK((backend_ctx->kernel_reglu_f16  = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
-        CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu           = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
+        CL_CHECK((backend_ctx->kernel_reglu           = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
+        CL_CHECK((backend_ctx->kernel_swiglu          = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_erf       = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf", &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_quick     = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick", &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_f16       = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_reglu_f16       = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_swiglu_f16      = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_erf_f16   = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_quick_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick_f16", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -2277,6 +2281,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             case GGML_GLU_OP_GEGLU:
             case GGML_GLU_OP_REGLU:
             case GGML_GLU_OP_SWIGLU:
+            case GGML_GLU_OP_GEGLU_ERF:
+            case GGML_GLU_OP_GEGLU_QUICK:
                 return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
             default:
                 return false;
@@ -6254,6 +6260,20 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
                 kernel = backend_ctx->kernel_swiglu_f16;
             }
             break;
+        case GGML_GLU_OP_GEGLU_ERF:
+            if (dst->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_geglu_erf;
+            } else {
+                kernel = backend_ctx->kernel_geglu_erf_f16;
+            }
+            break;
+        case GGML_GLU_OP_GEGLU_QUICK:
+            if (dst->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_geglu_quick;
+            } else {
+                kernel = backend_ctx->kernel_geglu_quick_f16;
+            }
+            break;
         default:
             GGML_ABORT("Unsupported glu op");
     }
ggml/src/ggml-opencl/kernels/glu.cl CHANGED
@@ -1,7 +1,9 @@
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 
 #define GELU_COEF_A     0.044715f
+#define GELU_QUICK_COEF -1.702f
 #define SQRT_2_OVER_PI  0.79788456080286535587989211986876f
+#define SQRT_2_INV      0.70710678118654752440084436210484f
 
 //------------------------------------------------------------------------------
 // geglu
@@ -199,3 +201,137 @@ kernel void kernel_swiglu_f16(
         dst_row[i0] = silu*x1;
     }
 }
+
+//------------------------------------------------------------------------------
+// geglu_erf
+//------------------------------------------------------------------------------
+kernel void kernel_geglu_erf(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        ulong nb01,
+        ulong nb11,
+        int ne0,
+        ulong nb1,
+        int ne00_off,
+        int ne10_off
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst  = (global char*)((global char*)dst  + offsetd);
+
+    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+    global float * dst_row  = (global float *) ((global char *) dst  + get_group_id(0)*nb1);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
+
+        dst_row[i0] = gelu_erf*x1;
+    }
+}
+
+kernel void kernel_geglu_erf_f16(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        ulong nb01,
+        ulong nb11,
+        int ne0,
+        ulong nb1,
+        int ne00_off,
+        int ne10_off
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst  = (global char*)((global char*)dst  + offsetd);
+
+    global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+    global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+    global half * dst_row  = (global half *) ((global char *) dst  + get_group_id(0)*nb1);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const half x0 = src0_row[i0];
+        const half x1 = src1_row[i0];
+
+        const half gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
+
+        dst_row[i0] = gelu_erf*x1;
+    }
+}
+
+//------------------------------------------------------------------------------
+// geglu_quick
+//------------------------------------------------------------------------------
+kernel void kernel_geglu_quick(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        ulong nb01,
+        ulong nb11,
+        int ne0,
+        ulong nb1,
+        int ne00_off,
+        int ne10_off
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst  = (global char*)((global char*)dst  + offsetd);
+
+    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+    global float * dst_row  = (global float *) ((global char *) dst  + get_group_id(0)*nb1);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
+
+        dst_row[i0] = gelu_quick*x1;
+    }
+}
+
+kernel void kernel_geglu_quick_f16(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        ulong nb01,
+        ulong nb11,
+        int ne0,
+        ulong nb1,
+        int ne00_off,
+        int ne10_off
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst  = (global char*)((global char*)dst  + offsetd);
+
+    global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+    global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+    global half * dst_row  = (global half *) ((global char *) dst  + get_group_id(0)*nb1);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const half x0 = src0_row[i0];
+        const half x1 = src1_row[i0];
+
+        const half gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
+
+        dst_row[i0] = gelu_quick*x1;
+    }
+}
ggml/src/ggml-sycl/element_wise.cpp CHANGED
@@ -383,6 +383,24 @@ static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint6
     }
 }
 
+template<typename T>
+static void gated_op_fused_geglu_erf(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        const int64_t j0 = (i / n) * o0 + (i % n);
+        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+        dst[i] = op_gelu_erf(x[j0]) * g[j1];
+    }
+}
+
+template<typename T>
+static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        const int64_t j0 = (i / n) * o0 + (i % n);
+        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+        dst[i] = op_gelu_quick(x[j0]) * g[j1];
+    }
+}
+
 namespace ggml_sycl_detail {
 static void acc_f32_sycl(const float *x, const float *y, float *dst,
                          const int n_elements, const int ne10, const int ne11,
@@ -978,6 +996,28 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
         });
 }
 
+static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+            sycl_parallel_for(main_stream,
+                sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                    gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+            sycl_parallel_for(main_stream,
+                sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                    gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+                });
+        });
+}
+
 
 void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
@@ -1118,3 +1158,13 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_op_swiglu(ctx, dst);
 }
+
+void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_geglu_erf(ctx, dst);
+}
+
+void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_geglu_quick(ctx, dst);
+}
ggml/src/ggml-sycl/element_wise.hpp CHANGED
@@ -80,5 +80,7 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 #endif // GGML_SYCL_ELEMENTWISE_HPP
ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -3687,6 +3687,12 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
             case GGML_GLU_OP_SWIGLU:
                 ggml_sycl_swiglu(ctx, dst);
                 break;
+            case GGML_GLU_OP_GEGLU_ERF:
+                ggml_sycl_geglu_erf(ctx, dst);
+                break;
+            case GGML_GLU_OP_GEGLU_QUICK:
+                ggml_sycl_geglu_quick(ctx, dst);
+                break;
             default:
                 return false;
         }
@@ -4232,6 +4238,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]);
                 default:
                     return false;
ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -456,6 +456,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_geglu[2];
     vk_pipeline pipeline_reglu[2];
     vk_pipeline pipeline_swiglu[2];
+    vk_pipeline pipeline_geglu_erf[2];
+    vk_pipeline pipeline_geglu_quick[2];
 
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_silu_back_f32;
@@ -2821,6 +2823,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_GLU(geglu)
     CREATE_GLU(reglu)
     CREATE_GLU(swiglu)
+    CREATE_GLU(geglu_erf)
+    CREATE_GLU(geglu_quick)
 #undef CREATE_GLU
 
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -6575,6 +6579,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
             case GGML_GLU_OP_SWIGLU:
                 return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
+            case GGML_GLU_OP_GEGLU_ERF:
+                return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
+            case GGML_GLU_OP_GEGLU_QUICK:
+                return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
             default:
                 break;
         }
@@ -8919,6 +8927,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             case GGML_GLU_OP_GEGLU:
            case GGML_GLU_OP_REGLU:
            case GGML_GLU_OP_SWIGLU:
+            case GGML_GLU_OP_GEGLU_ERF:
+            case GGML_GLU_OP_GEGLU_QUICK:
                 break;
             default:
                 return false;
@@ -9166,6 +9176,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             case GGML_GLU_OP_GEGLU:
            case GGML_GLU_OP_REGLU:
            case GGML_GLU_OP_SWIGLU:
+            case GGML_GLU_OP_GEGLU_ERF:
+            case GGML_GLU_OP_GEGLU_QUICK:
                 ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
                 break;
             default:
@@ -9384,6 +9396,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
             case GGML_GLU_OP_GEGLU:
            case GGML_GLU_OP_REGLU:
            case GGML_GLU_OP_SWIGLU:
+            case GGML_GLU_OP_GEGLU_ERF:
+            case GGML_GLU_OP_GEGLU_QUICK:
                 buf = tensor->buffer;
                 break;
             default:
@@ -10194,6 +10208,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             case GGML_GLU_OP_GEGLU:
            case GGML_GLU_OP_REGLU:
            case GGML_GLU_OP_SWIGLU:
+            case GGML_GLU_OP_GEGLU_ERF:
+            case GGML_GLU_OP_GEGLU_QUICK:
                 return ggml_is_contiguous(op->src[0]) &&
                        (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                        (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp ADDED
@@ -0,0 +1,27 @@
+#version 450
+
+#include "glu_head.comp"
+
+// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
+// ref: https://www.johndcook.com/blog/python_erf/
+const float p_erf = 0.3275911f;
+const float a1_erf = 0.254829592f;
+const float a2_erf = -0.284496736f;
+const float a3_erf = 1.421413741f;
+const float a4_erf = -1.453152027f;
+const float a5_erf = 1.061405429f;
+
+const float SQRT_2_INV = 0.70710678118654752440084436210484f;
+
+float op(float a, float b) {
+    const float a_div_sqr2 = a * SQRT_2_INV;
+    const float sign_x = sign(a_div_sqr2);
+    const float x = abs(a_div_sqr2);
+    const float t = 1.0f / (1.0f + p_erf * x);
+    const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
+    const float erf_approx = sign_x * y;
+
+    return 0.5f * a * (1.0f + erf_approx) * b;
+}
+
+#include "glu_main.comp"
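
The polynomial in `op()` above is the Horner form of the Abramowitz and Stegun 7.1.26 approximation (maximum absolute error around 1.5e-7), written out here for reference:

    \operatorname{erf}(x) \approx 1 - \left(a_1 t + a_2 t^2 + a_3 t^3 + a_4 t^4 + a_5 t^5\right) e^{-x^2}, \qquad t = \frac{1}{1 + p\,x}, \quad x \ge 0

Negative arguments are handled through the odd symmetry erf(-x) = -erf(x), which is what the `sign_x` factor implements; GLSL has no built-in erf, hence the polynomial.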
ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp ADDED
@@ -0,0 +1,11 @@
+#version 450
+
+#include "glu_head.comp"
+
+const float GELU_QUICK_COEF = -1.702f;
+
+float op(float a, float b) {
+    return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b;
+}
+
+#include "glu_main.comp"
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -593,6 +593,10 @@ void process_shaders() {
     string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
     string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("geglu_quick_f16", "geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("geglu_quick_f32", "geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 
     string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
ggml/src/ggml.c CHANGED
@@ -1132,9 +1132,11 @@ static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
     "REGLU",
     "GEGLU",
     "SWIGLU",
+    "GEGLU_ERF",
+    "GEGLU_QUICK",
 };
 
-static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3");
+static_assert(GGML_GLU_OP_COUNT == 5, "GGML_GLU_OP_COUNT != 5");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -2760,6 +2762,48 @@ struct ggml_tensor * ggml_swiglu_split(
     return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
 }
 
+// ggml_geglu_erf
+
+struct ggml_tensor * ggml_geglu_erf(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, false);
+}
+
+struct ggml_tensor * ggml_geglu_erf_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, true);
+}
+
+struct ggml_tensor * ggml_geglu_erf_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_ERF, false);
+}
+
+// ggml_geglu_quick
+
+struct ggml_tensor * ggml_geglu_quick(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, false);
+}
+
+struct ggml_tensor * ggml_geglu_quick_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, true);
+}
+
+struct ggml_tensor * ggml_geglu_quick_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);
+}
+
 // ggml_norm
 
 static struct ggml_tensor * ggml_norm_impl(
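
Note: like the existing GLU wrappers, the constructors above are thin shims over `ggml_glu_impl`. Assuming the generic `ggml_glu`/`ggml_glu_split` entry points declared in ggml.h keep their current signatures, the following calls build equivalent graph nodes (a sketch, not part of the patch):

    ggml_geglu_quick(ctx, a);          // == ggml_glu(ctx, a, GGML_GLU_OP_GEGLU_QUICK, /*swapped=*/false)
    ggml_geglu_erf_swapped(ctx, a);    // == ggml_glu(ctx, a, GGML_GLU_OP_GEGLU_ERF,   /*swapped=*/true)
    ggml_geglu_erf_split(ctx, a, b);   // == ggml_glu_split(ctx, a, b, GGML_GLU_OP_GEGLU_ERF)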