lhez committed
Commit 3261fcd · 1 Parent(s): 1830e27

opencl: add multi and vision rope, `gelu_quick` and `im2col` (llama/12600)

* opencl: add `im2col`

* opencl: add `gelu_quick`

* opencl: add mrope

* opencl: add vision rope
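
For reference, the new `gelu_quick` kernels compute the sigmoid-based GELU approximation x * sigmoid(1.702 * x). A minimal host-side sketch of the same per-element formula (illustration only, not part of this commit; the helper name is ours):

```c
#include <math.h>

// Mirrors kernel_gelu_quick: dst[i] = x * (1 / (1 + exp(-1.702 * x))),
// i.e. GELU_QUICK_COEF = -1.702f applied inside the exponential.
static inline float gelu_quick_ref(float x) {
    return x * (1.0f / (1.0f + expf(-1.702f * x)));
}
```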

ggml/src/ggml-opencl/CMakeLists.txt CHANGED
@@ -63,6 +63,7 @@ set(GGML_OPENCL_KERNELS
     ggml-opencl_transpose_16
     ggml-opencl_transpose_32
     ggml-opencl_transpose_32_16
+    ggml-opencl_im2col
 )
 
 foreach (K ${GGML_OPENCL_KERNELS})
ggml/src/ggml-opencl/ggml-opencl.cpp CHANGED
@@ -224,12 +224,14 @@ struct ggml_backend_opencl_context {
     cl_program program;
     cl_program program_1;
     cl_program program_2;
+    cl_program program_im2col;
 
     cl_kernel kernel_add, kernel_add_row;
     cl_kernel kernel_mul, kernel_mul_row;
     cl_kernel kernel_scale;
     cl_kernel kernel_silu, kernel_silu_4;
     cl_kernel kernel_gelu, kernel_gelu_4;
+    cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
     cl_kernel kernel_relu;
     cl_kernel kernel_clamp;
     cl_kernel kernel_norm;
@@ -239,6 +241,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
     cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
     cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
+    cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16;
     cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32;
     cl_kernel kernel_mul_mat_f32_f32;
     cl_kernel kernel_mul_mat_f16_f16;
@@ -252,6 +255,7 @@ struct ggml_backend_opencl_context {
               kernel_mul_mat_q4_0_f32_flat_img_v0;
     cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
     cl_kernel kernel_mul_mv_q6_K_f32;
+    cl_kernel kernel_im2col_f32, kernel_im2col_f16;
 
 #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
     // Transpose kernels
@@ -708,6 +712,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
     CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program, "kernel_silu_4", &err), err));
     CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program, "kernel_gelu", &err), err));
     CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_4", &err), err));
+    CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program, "kernel_gelu_quick", &err), err));
+    CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_quick_4", &err), err));
     CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program, "kernel_relu", &err), err));
     CL_CHECK((backend_ctx->kernel_clamp = clCreateKernel(backend_ctx->program, "kernel_clamp", &err), err));
     CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program, "kernel_norm", &err), err));
@@ -722,6 +728,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
     CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err));
     CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err));
     CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f16", &err), err));
+    CL_CHECK((backend_ctx->kernel_rope_multi_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_rope_multi_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f16", &err), err));
+    CL_CHECK((backend_ctx->kernel_rope_vision_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_rope_vision_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f16", &err), err));
     CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f16", &err), err));
     CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f32", &err), err));
     CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f16", &err), err));
@@ -769,6 +779,19 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
 
     CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_2, "kernel_convert_block_q4_0_noshuffle", &err), err));
 
+    // im2col kernels
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string kernel_src_im2col {
+        #include "ggml-opencl_im2col.cl.h"
+    };
+#else
+    const std::string kernel_src_im2col = read_file("ggml-opencl_im2col.cl");
+#endif
+    backend_ctx->program_im2col = build_program_from_source(context, device, kernel_src_im2col.c_str(), compile_opts);
+
+    CL_CHECK((backend_ctx->kernel_im2col_f32 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_im2col_f16 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f16", &err), err));
+
     // Kernels for Adreno
 #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1187,6 +1210,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         case GGML_UNARY_OP_GELU:
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_RELU:
+        case GGML_UNARY_OP_GELU_QUICK:
             return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         default:
             return false;
@@ -1216,14 +1240,26 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             return op->ne[3] == 1;
         case GGML_OP_ROPE: {
             const int mode = ((const int32_t *) op->op_params)[2];
-            if (mode & GGML_ROPE_TYPE_MROPE) {
+            const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+            const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+            if (is_mrope && !is_vision) {
+                if (op->src[0]->type == GGML_TYPE_F32 ||
+                    op->src[0]->type == GGML_TYPE_F16) {
+                    return true;
+                }
                 return false;
             }
-            if (mode & GGML_ROPE_TYPE_VISION) {
+            if (is_vision) {
+                if (op->src[0]->type == GGML_TYPE_F32 ||
+                    op->src[0]->type == GGML_TYPE_F16) {
+                    return true;
+                }
                 return false;
             }
             return true;
         }
+        case GGML_OP_IM2COL:
+            return true;
         default:
             return false;
     }
@@ -2582,6 +2618,53 @@ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const
 #endif
 }
 
+static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel;
+
+    int n = ggml_nelements(dst);
+
+    if (n % 4 == 0) {
+        kernel = backend_ctx->kernel_gelu_quick_4;
+        n /= 4;
+    } else {
+        kernel = backend_ctx->kernel_gelu_quick;
+    }
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+
+    size_t global_work_size[] = {(size_t)n, 1, 1};
+    size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt);
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL);
+#endif
+}
+
 static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -3980,6 +4063,7 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
     float attn_factor;
     float beta_fast;
     float beta_slow;
+    int32_t sections[4];
 
     memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
     memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
@@ -3987,29 +4071,62 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
     memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
     memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int32_t)*4);
 
     const bool is_neox = mode & 2;
+    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+    if (is_mrope) {
+        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
+    }
+
+    if (is_vision) {
+        GGML_ASSERT(n_dims == ne00/2);
+    }
 
     cl_kernel kernel;
 
-    if (!is_neox) {
+    if (is_neox) {
         switch (src0->type) {
             case GGML_TYPE_F32:
-                kernel = backend_ctx->kernel_rope_norm_f32;
+                kernel = backend_ctx->kernel_rope_neox_f32;
                 break;
             case GGML_TYPE_F16:
-                kernel = backend_ctx->kernel_rope_norm_f16;
+                kernel = backend_ctx->kernel_rope_neox_f16;
+                break;
+            default:
+                GGML_ASSERT(false);
+        };
+    } else if (is_mrope && !is_vision) {
+        switch (src0->type) {
+            case GGML_TYPE_F32:
+                kernel = backend_ctx->kernel_rope_multi_f32;
+                break;
+            case GGML_TYPE_F16:
+                kernel = backend_ctx->kernel_rope_multi_f16;
                 break;
             default:
                 GGML_ASSERT(false);
         };
+    } else if (is_vision) {
+        switch (src0->type) {
+            case GGML_TYPE_F32:
+                kernel = backend_ctx->kernel_rope_vision_f32;
+                break;
+            case GGML_TYPE_F16:
+                kernel = backend_ctx->kernel_rope_vision_f16;
+                break;
+            default:
+                GGML_ASSERT(false);
+        }
     } else {
         switch (src0->type) {
            case GGML_TYPE_F32:
-                kernel = backend_ctx->kernel_rope_neox_f32;
+                kernel = backend_ctx->kernel_rope_norm_f32;
                break;
            case GGML_TYPE_F16:
-                kernel = backend_ctx->kernel_rope_neox_f16;
+                kernel = backend_ctx->kernel_rope_norm_f16;
                break;
            default:
                GGML_ASSERT(false);
@@ -4049,6 +4166,9 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
     CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor));
     CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast));
     CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow));
+    if (is_mrope || is_vision) {
+        CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, &sections));
+    }
 
     size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
     size_t local_work_size[] = {(size_t)nth, 1, 1};
@@ -4064,6 +4184,98 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
 #endif
 }
 
+static void ggml_cl_im2col(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    // src0 - filter, src1 - input
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+    const cl_long IC = src1->ne[is_2D ? 2 : 1];
+    const cl_long IH = is_2D ? src1->ne[1] : 1;
+    const cl_long IW = src1->ne[0];
+
+    const cl_long KH = is_2D ? src0->ne[1] : 1;
+    const cl_long KW = src0->ne[0];
+
+    const cl_long OH = is_2D ? dst->ne[2] : 1;
+    const cl_long OW = dst->ne[1];
+
+    // nb is byte offset, src is type float32
+    const cl_ulong delta_offset = src1->nb[is_2D ? 2 : 1]/4;
+    const cl_long batch = src1->ne[is_2D ? 3 : 2];
+    const cl_ulong batch_offset = src1->nb[is_2D ? 3 : 2]/4;
+
+    const cl_long pelements = OW*KW*KH;
+    const cl_long CHW = IC*KH*KW;
+
+    cl_kernel kernel;
+
+    if (dst->type == GGML_TYPE_F16) {
+        kernel = backend_ctx->kernel_im2col_f16;
+    } else {
+        kernel = backend_ctx->kernel_im2col_f32;
+    }
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &batch_offset));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &delta_offset));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_long), &IW));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_long), &IH));
+    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_long), &IC));
+    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_long), &OW));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_long), &OH));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_long), &KW));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_long), &KH));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_long), &pelements));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_long), &CHW));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &s0));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &s1));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &p0));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &p1));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &d0));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &d1));
+
+    const int num_blocks = (pelements + 256 - 1) / 256;
+    size_t global_work_size[] = {(size_t)num_blocks*256, (size_t)OH, (size_t)batch*IC};
+    size_t local_work_size[] = {256, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
 //------------------------------------------------------------------------------
 // Op offloading
 //------------------------------------------------------------------------------
@@ -4122,6 +4334,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_gelu;
             break;
+        case GGML_UNARY_OP_GELU_QUICK:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_gelu_quick;
+            break;
         case GGML_UNARY_OP_SILU:
             if (!any_on_device) {
                 return false;
@@ -4194,6 +4412,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
            func = ggml_cl_rope;
            break;
+        case GGML_OP_IM2COL:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_im2col;
+            break;
        default:
            return false;
    }
ggml/src/ggml-opencl/kernels/ggml-opencl.cl CHANGED
@@ -404,6 +404,7 @@ kernel void kernel_scale(
 // gelu
 //------------------------------------------------------------------------------
 #define GELU_COEF_A 0.044715f
+#define GELU_QUICK_COEF -1.702f
 #define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
 
 kernel void kernel_gelu(
@@ -434,6 +435,32 @@ kernel void kernel_gelu_4(
     dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }
 
+kernel void kernel_gelu_quick(
+    global float * src0,
+    ulong offset0,
+    global float * dst,
+    ulong offsetd
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    float x = src0[get_global_id(0)];
+    dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_gelu_quick_4(
+    global float4 * src0,
+    ulong offset0,
+    global float4 * dst,
+    ulong offsetd
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst = (global float4*)((global char*)dst + offsetd);
+
+    float4 x = src0[get_global_id(0)];
+    dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
 //------------------------------------------------------------------------------
 // silu
 //------------------------------------------------------------------------------
@@ -1325,6 +1352,368 @@ kernel void kernel_rope_neox_f16(
     }
 }
 
+kernel void kernel_rope_multi_f32(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * src2,
+        ulong offset2,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
+        int n_past,
+        int n_dims,
+        int n_ctx_orig,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow,
+        int4 sections
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    src2 = (global float*)((global char*)src2 + offset2);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i3 = get_group_id(2);
+    int i2 = get_group_id(1);
+    int i1 = get_group_id(0);
+
+    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
+
+    global int * pos = src1;
+
+    const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;
+    const int sec_w = sections.s1 + sections.s0;
+
+    float inv_ndims = -1.f/n_dims;
+
+    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
+        if (i0 < n_dims) {
+            int ic = i0/2;
+
+            const int sector = (i0 / 2) % sect_dims;
+            float theta_base = 0.0f;
+
+            if (sector < sections.s0) {
+                theta_base = pos[i2];
+            }
+            else if (sector >= sections.s0 && sector < sec_w) {
+                theta_base = pos[i2 + ne2 * 1];
+            }
+            else if (sector >= sec_w && sector < sec_w + sections.s2) {
+                theta_base = pos[i2 + ne2 * 2];
+            }
+            else if (sector >= sec_w + sections.s2) {
+                theta_base = pos[i2 + ne2 * 3];
+            }
+
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
+
+            global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+            global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[n_dims/2];
+
+            dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
+            dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
+        } else {
+            global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+kernel void kernel_rope_multi_f16(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * src2,
+        ulong offset2,
+        global half * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
+        int n_past,
+        int n_dims,
+        int n_ctx_orig,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow,
+        int4 sections
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    src2 = (global float*)((global char*)src2 + offset2);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i3 = get_group_id(2);
+    int i2 = get_group_id(1);
+    int i1 = get_group_id(0);
+
+    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
+
+    global int * pos = src1;
+
+    const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;
+    const int sec_w = sections.s1 + sections.s0;
+
+    float inv_ndims = -1.f/n_dims;
+
+    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
+        if (i0 < n_dims) {
+            int ic = i0/2;
+
+            const int sector = (i0 / 2) % sect_dims;
+            float theta_base = 0.0f;
+
+            if (sector < sections.s0) {
+                theta_base = pos[i2];
+            }
+            else if (sector >= sections.s0 && sector < sec_w) {
+                theta_base = pos[i2 + ne2 * 1];
+            }
+            else if (sector >= sec_w && sector < sec_w + sections.s2) {
+                theta_base = pos[i2 + ne2 * 2];
+            }
+            else if (sector >= sec_w + sections.s2) {
+                theta_base = pos[i2 + ne2 * 3];
+            }
+
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
+
+            global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+            global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[n_dims/2];
+
+            dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
+            dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
+        } else {
+            global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+kernel void kernel_rope_vision_f32(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * src2,
+        ulong offset2,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
+        int n_past,
+        int n_dims,
+        int n_ctx_orig,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow,
+        int4 sections
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    src2 = (global float*)((global char*)src2 + offset2);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i3 = get_group_id(2);
+    int i2 = get_group_id(1);
+    int i1 = get_group_id(0);
+
+    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
+
+    global int * pos = src1;
+
+    const int sect_dims = sections.s0 + sections.s1;
+    const int sec_w = sections.s1 + sections.s0;
+
+    float inv_ndims = -1.f/n_dims;
+
+    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
+        int ic = i0/2;
+
+        const int sector = (i0/2) % sect_dims;
+        float theta_base = 0.0f;
+
+        if (sector < sections.s0) {
+            const int p = sector;
+            theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p);
+        } else if (sector >= sections.s0 && sector < sec_w) {
+            const int p = sector - sections.s0;
+            theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p);
+        }
+
+        const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+        float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
+
+        global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+        global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+        const float x0 = src[0];
+        const float x1 = src[n_dims];
+
+        dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
+        dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
+    }
+}
+
+kernel void kernel_rope_vision_f16(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * src2,
+        ulong offset2,
+        global half * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
+        int n_past,
+        int n_dims,
+        int n_ctx_orig,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow,
+        int4 sections
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    src2 = (global float*)((global char*)src2 + offset2);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i3 = get_group_id(2);
+    int i2 = get_group_id(1);
+    int i1 = get_group_id(0);
+
+    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
+
+    global int * pos = src1;
+
+    const int sect_dims = sections.s0 + sections.s1;
+    const int sec_w = sections.s1 + sections.s0;
+
+    float inv_ndims = -1.f/n_dims;
+
+    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
+        int ic = i0/2;
+
+        const int sector = (i0/2) % sect_dims;
+        float theta_base = 0.0f;
+
+        if (sector < sections.s0) {
+            const int p = sector;
+            theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p);
+        } else if (sector >= sections.s0 && sector < sec_w) {
+            const int p = sector - sections.s0;
+            theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p);
+        }
+
+        const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+        float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
+
+        global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+        global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+        const float x0 = src[0];
+        const float x1 = src[n_dims];
+
+        dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
+        dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
+    }
+}
+
 //------------------------------------------------------------------------------
 // cpy
 //------------------------------------------------------------------------------
ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl ADDED
@@ -0,0 +1,146 @@
+#ifdef cl_khr_fp16
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#elif defined(cl_amd_fp16)
+#pragma OPENCL EXTENSION cl_amd_fp16 : enable
+#else
+#error "Half precision floating point not supportedby OpenCL implementation on your device."
+#endif
+
+#ifdef cl_khr_subgroups
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#elif defined(cl_intel_subgroups)
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#error "Subgroup not supported on your device."
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+// Always use subgroup size of 32 on Intel.
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+// Always use subgroups size of 64 on Adreno.
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+// TODO: do not know how to choose subgroup size on other GPUs.
+#error "Selecting subgroup size is not supported on your device."
+#endif
+
+kernel void kernel_im2col_f32(
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        ulong batch_offset,
+        ulong delta_offset,
+        long IW,
+        long IH,
+        long IC,
+        long OW,
+        long OH,
+        long KW,
+        long KH,
+        long pelements,
+        long CHW,
+        int s0,
+        int s1,
+        int p0,
+        int p1,
+        int d0,
+        int d1
+) {
+    // threadIdx.x + blockIdx.x * blockDim.x
+    long i = get_global_id(0);
+    if (i >= pelements) {
+        return;
+    }
+
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    long ksize = OW * (KH > 1 ? KW : 1);
+    long kx = i / ksize;
+    long kd = kx * ksize;
+    long ky = (i - kd) / OW;
+    long ix = i % OW;
+
+    long oh = get_group_id(1);
+    long batch = get_group_id(2) / IC;
+    long ic = get_group_id(2) % IC;
+
+    long iiw = ix * s0 + kx * d0 - p0;
+    long iih = oh * s1 + ky * d1 - p1;
+
+    long offset_dst =
+        ((batch * OH + oh) * OW + ix) * CHW +
+        (ic * (KW * KH) + ky * KW + kx);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = 0.0f;
+    } else {
+        long offset_src = ic * delta_offset + batch * batch_offset;
+        dst[offset_dst] = src1[offset_src + iih * IW + iiw];
+    }
+}
+
+kernel void kernel_im2col_f16(
+        global float * src1,
+        ulong offset1,
+        global half * dst,
+        ulong offsetd,
+        ulong batch_offset,
+        ulong delta_offset,
+        long IW,
+        long IH,
+        long IC,
+        long OW,
+        long OH,
+        long KW,
+        long KH,
+        long pelements,
+        long CHW,
+        int s0,
+        int s1,
+        int p0,
+        int p1,
+        int d0,
+        int d1
+) {
+    long i = get_global_id(0);
+
+    if (i >= pelements) {
+        return;
+    }
+
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global half*)((global char*)dst + offsetd);
+
+    long ksize = OW * (KH > 1 ? KW : 1);
+    long kx = i / ksize;
+    long kd = kx * ksize;
+    long ky = (i - kd) / OW;
+    long ix = i % OW;
+
+    long oh = get_group_id(1);
+    long batch = get_group_id(2) / IC;
+    long ic = get_group_id(2) % IC;
+
+    long iiw = ix * s0 + kx * d0 - p0;
+    long iih = oh * s1 + ky * d1 - p1;
+
+    long offset_dst =
+        ((batch * OH + oh) * OW + ix) * CHW +
+        (ic * (KW * KH) + ky * KW + kx);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = 0.0f;
+    } else {
+        long offset_src = ic * delta_offset + batch * batch_offset;
+        dst[offset_dst] = src1[offset_src + iih * IW + iiw];
+    }
+}
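
For reference, a minimal host-side C sketch of the 2D im2col indexing these kernels implement (single image, float output). This is an illustration under our own assumptions, not code from this commit, and the helper name is hypothetical:

```c
// CPU mirror of kernel_im2col_f32 for one image: src is laid out as
// [IC, IH, IW], dst as [OH*OW, IC*KH*KW]; out-of-bounds taps produce 0.
static void im2col_2d_ref(const float * src, float * dst,
                          long IC, long IH, long IW,
                          long OH, long OW, long KH, long KW,
                          int s0, int s1, int p0, int p1, int d0, int d1) {
    const long CHW = IC * KH * KW;
    for (long ic = 0; ic < IC; ic++) {
        for (long oh = 0; oh < OH; oh++) {
            for (long ow = 0; ow < OW; ow++) {
                for (long ky = 0; ky < KH; ky++) {
                    for (long kx = 0; kx < KW; kx++) {
                        const long iih = oh * s1 + ky * d1 - p1; // input row
                        const long iiw = ow * s0 + kx * d0 - p0; // input col
                        const long off = (oh * OW + ow) * CHW + ic * (KW * KH) + ky * KW + kx;
                        const int oob = iih < 0 || iih >= IH || iiw < 0 || iiw >= IW;
                        dst[off] = oob ? 0.0f : src[ic * IH * IW + iih * IW + iiw];
                    }
                }
            }
        }
    }
}
```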