mrfatso committed
Commit d579f20 · Parent: 931edc1

opencl: add conv2d kernel (llama/14403)

* add conv2d kernel

* fix trailing whitespace

* whitespace fix

* handle f16 input and f16 kernel, more opt

* resolve conflicts

* use enqueue_ndrange_kernel
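
For background on the shapes involved: the new kernels treat conv2d as an implicit GEMM over flattened filter and im2col coordinates. Below is a minimal C++ sketch of the dimension bookkeeping; the names K, CRS, and NPQ come from the kernels in this commit, while the output-size formula is the standard strided/dilated convolution formula, stated here as an assumption rather than quoted from ggml.

#include <cstdint>

// Output extent of a strided, padded, dilated convolution along one axis
// (assumed; this is the conventional formula).
static int64_t conv_out_size(int64_t in, int64_t k, int64_t s, int64_t p, int64_t d) {
    return (in + 2 * p - d * (k - 1) - 1) / s + 1;
}

// GEMM view used by kernel_conv_2d:
//   A is K x CRS   (flattened filters),  K   = Cout
//   B is CRS x NPQ (implicit im2col),    CRS = Cin * KH * KW
//   C is K x NPQ   (output),             NPQ = N * OH * OW
struct ConvGemmDims { int64_t K, CRS, NPQ; };

static ConvGemmDims conv_gemm_dims(int64_t Cout, int64_t Cin, int64_t N,
                                   int64_t KW, int64_t KH, int64_t W, int64_t H,
                                   int64_t s0, int64_t s1, int64_t p0, int64_t p1,
                                   int64_t d0, int64_t d1) {
    const int64_t OW = conv_out_size(W, KW, s0, p0, d0);
    const int64_t OH = conv_out_size(H, KH, s1, p1, d1);
    return { Cout, Cin * KH * KW, N * OH * OW };
}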

ggml/src/ggml-opencl/CMakeLists.txt CHANGED
@@ -105,6 +105,8 @@ set(GGML_OPENCL_KERNELS
     pad
     repeat
     mul_mat_f16_f32
+    conv2d
+    conv2d_f16_f32
 )
 
 foreach (K ${GGML_OPENCL_KERNELS})
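
Adding the two entries to GGML_OPENCL_KERNELS is what lets the loader below find conv2d.cl / conv2d_f16_f32.cl (or their embedded .cl.h counterparts). As a hedged illustration of the embed path only, this self-contained C++ sketch shows the splice pattern that ggml-opencl.cpp relies on, with the generated header stood in by an inline raw string (the real conv2d.cl.h is produced by the build):

#include <iostream>
#include <string>

int main() {
    // With GGML_OPENCL_EMBED_KERNELS, a generated header expands to one
    // string literal, spliced directly into a std::string initializer:
    //     const std::string kernel_src { #include "conv2d.cl.h" };
    const std::string kernel_src {
        R"(kernel void kernel_conv_2d() { /* ... */ })"  // stand-in literal
    };
    std::cout << kernel_src.size() << " bytes of kernel source\n";
    return 0;
}
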
ggml/src/ggml-opencl/ggml-opencl.cpp CHANGED
@@ -390,6 +390,9 @@ struct ggml_backend_opencl_context {
     cl_program program_tanh;
     cl_program program_upscale;
     cl_program program_concat;
+    cl_program program_conv_2d_f16;
+    cl_program program_conv_2d_f32;
+    cl_program program_conv_2d_f16_f32;
     cl_program program_tsembd;
     cl_program program_mul_mv_id_q4_0_f32_8x_flat;
 
@@ -441,6 +444,9 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_upscale_bilinear;
     cl_kernel kernel_concat_f32_contiguous;
     cl_kernel kernel_concat_f32_non_contiguous;
+    cl_kernel kernel_conv_2d_f16;
+    cl_kernel kernel_conv_2d_f32;
+    cl_kernel kernel_conv_2d_f16_f32;
     cl_kernel kernel_timestep_embedding;
     cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
 
@@ -1478,6 +1484,47 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // conv2d
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "conv2d.cl.h"
+        };
+        const std::string kernel_src_f16_f32 {
+            #include "conv2d_f16_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("conv2d.cl");
+        const std::string kernel_src_f16_f32 = read_file("conv2d_f16_f32.cl");
+#endif
+        if (!kernel_src.empty()) {
+            backend_ctx->program_conv_2d_f16 =
+                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), (std::string(compile_opts) + " -DUSE_FP16=1").c_str());
+            CL_CHECK((backend_ctx->kernel_conv_2d_f16 = clCreateKernel(backend_ctx->program_conv_2d_f16, "kernel_conv_2d", &err), err));
+            GGML_LOG_CONT(".");
+            backend_ctx->program_conv_2d_f32 =
+                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+            CL_CHECK((backend_ctx->kernel_conv_2d_f32 = clCreateKernel(backend_ctx->program_conv_2d_f32, "kernel_conv_2d", &err), err));
+            GGML_LOG_CONT(".");
+        } else {
+            GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n");
+            backend_ctx->program_conv_2d_f16 = nullptr;
+            backend_ctx->kernel_conv_2d_f16 = nullptr;
+            backend_ctx->program_conv_2d_f32 = nullptr;
+            backend_ctx->kernel_conv_2d_f32 = nullptr;
+        }
+        if (!kernel_src_f16_f32.empty()) {
+            backend_ctx->program_conv_2d_f16_f32 =
+                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16_f32.c_str(), compile_opts);
+            CL_CHECK((backend_ctx->kernel_conv_2d_f16_f32 = clCreateKernel(backend_ctx->program_conv_2d_f16_f32, "kernel_conv_2d", &err), err));
+            GGML_LOG_CONT(".");
+        } else {
+            GGML_LOG_WARN("ggml_opencl: conv2d_f16_f32 kernel source not found or empty. This op will not be available.\n");
+            backend_ctx->program_conv_2d_f16_f32 = nullptr;
+            backend_ctx->kernel_conv_2d_f16_f32 = nullptr;
+        }
+    }
+
     // mul_mv_id_q4_0_f32_8x_flat
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2361,6 +2408,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                    op->src[0]->ne[3] == 1 && op->ne[3] == 1;
         case GGML_OP_UPSCALE:
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
+        case GGML_OP_CONV_2D:
+            return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
+                   (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
+                   (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
         case GGML_OP_CONCAT:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_TIMESTEP_EMBEDDING:
@@ -4998,6 +5049,83 @@ static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_ten
     backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
 }
 
+static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_TENSOR_BINARY_OP_LOCALS;
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    const cl_uint Cout = ne03; const cl_uint Cin = ne02; const cl_uint N = ne13;
+    const cl_uint KW = ne00; const cl_uint KH = ne01; const cl_uint W = ne10; const cl_uint H = ne11; const cl_uint OW = ne0; const cl_uint OH = ne1;
+
+    const cl_uint s0 = dst->op_params[0]; const cl_uint s1 = dst->op_params[1];
+    const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3];
+    const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5];
+
+    const cl_uint cl_nb01 = nb01/ggml_type_size(src0->type); const cl_uint cl_nb02 = nb02/ggml_type_size(src0->type); const cl_uint cl_nb03 = nb03/ggml_type_size(src0->type);
+    const cl_uint cl_nb11 = nb11/ggml_type_size(src1->type); const cl_uint cl_nb12 = nb12/ggml_type_size(src1->type); const cl_uint cl_nb13 = nb13/ggml_type_size(src1->type);
+    const cl_uint cl_nb1 = nb1/ggml_type_size(dst->type); const cl_uint cl_nb2 = nb2/ggml_type_size(dst->type); const cl_uint cl_nb3 = nb3/ggml_type_size(dst->type);
+
+    const int64_t NPQ = (int64_t)N * OW * OH;
+
+    const uint32_t BS_K = 64;
+    const uint32_t BS_NPQ = 64;
+    const uint32_t BS_CRS = 16;
+    const uint32_t VEC_SIZE = 4;
+
+    const uint32_t TS_K = 4;
+    const uint32_t TS_NPQ = 8;
+
+    const uint32_t WG_K = BS_K / TS_K;
+    const uint32_t WG_NPQ = BS_NPQ / TS_NPQ;
+
+    auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (block_size + work_size - 1) / block_size; };
+    const uint32_t NB_K = splitWork(Cout, BS_K);
+    const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ);
+
+    cl_kernel kernel;
+    size_t shmem_size;
+
+    if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        kernel = backend_ctx->kernel_conv_2d_f16;
+        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_half4));
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        kernel = backend_ctx->kernel_conv_2d_f32;
+        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_float) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+        kernel = backend_ctx->kernel_conv_2d_f16_f32;
+        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
+    } else {
+        GGML_ASSERT(false && "Unsupported data type combination for conv2d");
+        return;
+    }
+
+    cl_uint idx = 0;
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, idx++, shmem_size, NULL));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cout)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cin)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &N));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KH)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &W)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OH));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p1));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d1));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb01)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb02)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb03));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3));
+
+    size_t global_work_size[] = { (size_t)NB_K * WG_K, (size_t)NB_NPQ * WG_NPQ, 1 };
+    size_t local_work_size[] = { (size_t)WG_K, (size_t)WG_NPQ, 1 };
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
+}
+
 static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -6752,6 +6880,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             ggml_cl_upscale(backend, tensor->src[0], tensor);
            return true;
+        case GGML_OP_CONV_2D:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_conv_2d;
+            break;
         case GGML_OP_CONCAT:
             if (!any_on_device) {
                 return false;
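
The host wrapper above fixes the block and tile sizes (BS_K=64, BS_NPQ=64, BS_CRS=16, TS_K=4, TS_NPQ=8) and derives the NDRange from them: each work-group computes a BS_K x BS_NPQ output tile, each work-item a TS_K x TS_NPQ sub-tile. A small self-contained sketch that reproduces that arithmetic, useful for sanity-checking launch dimensions against a given problem size (the values mirror the code above; the example sizes are arbitrary):

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t BS_K = 64, BS_NPQ = 64, TS_K = 4, TS_NPQ = 8;
    const uint32_t WG_K   = BS_K / TS_K;      // 16 work-items along Cout
    const uint32_t WG_NPQ = BS_NPQ / TS_NPQ;  // 8 work-items along N*OH*OW

    // Example problem: Cout=128, N*OH*OW=10000 (illustrative only).
    const uint32_t Cout = 128;
    const uint64_t NPQ  = 10000;

    const uint32_t NB_K   = (Cout + BS_K - 1) / BS_K;     // blocks along Cout
    const uint64_t NB_NPQ = (NPQ + BS_NPQ - 1) / BS_NPQ;  // blocks along NPQ

    // Matches the global/local sizes passed to enqueue_ndrange_kernel above.
    printf("global = { %u, %llu }, local = { %u, %u }\n",
           NB_K * WG_K, (unsigned long long)(NB_NPQ * WG_NPQ), WG_K, WG_NPQ);
    return 0;
}
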
ggml/src/ggml-opencl/kernels/conv2d.cl ADDED
@@ -0,0 +1,185 @@
+#ifdef USE_FP16
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#define T_FLOAT half
+#define T_FLOAT4 half4
+#define VSTORE_T_FLOAT4(data, offset, p) vstore_half4_rte(data, offset, p)
+#else
+#define T_FLOAT float
+#define T_FLOAT4 float4
+#define VSTORE_T_FLOAT4(data, offset, p) vstore4(data, offset, p)
+#endif
+
+#if defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+#define REQD_SUBGROUP_SIZE_128
+#endif
+
+#define T_ACCUM float4
+#define VEC_SIZE 4
+
+#define BS_K 64
+#define BS_NPQ 64
+#define BS_CRS 16
+
+#define TS_K 4
+#define TS_NPQ 8
+
+#define WG_K (BS_K / TS_K)
+#define WG_NPQ (BS_NPQ / TS_NPQ)
+
+#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
+#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)
+
+static inline uint splitWork(uint work_size, uint block_size){
+    return (work_size + block_size - 1) / block_size;
+}
+
+REQD_SUBGROUP_SIZE_128
+kernel void kernel_conv_2d(
+    global void* p_knl,
+    ulong off_knl,
+    global void* p_src,
+    ulong off_src,
+    global void* p_dst,
+    ulong off_dst,
+    local void* shared,
+    uint Cout, uint Cin, uint N,
+    uint KW, uint KH, uint W, uint H, uint OW, uint OH,
+    uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
+    uint nb01, uint nb02, uint nb03,
+    uint nb11, uint nb12, uint nb13,
+    uint nb1, uint nb2, uint nb3
+) {
+    global T_FLOAT* knl_data = (global T_FLOAT*) ((global char*)p_knl + off_knl);
+    global T_FLOAT* src_data = (global T_FLOAT*) ((global char*)p_src + off_src);
+    global T_FLOAT* dst_data = (global T_FLOAT*) ((global char*)p_dst + off_dst);
+
+    const uint K = Cout;
+    const uint CRS = Cin*KH*KW;
+    const uint NPQ = N*OH*OW;
+
+    const uint lid_k = get_local_id(0);
+    const uint lid_npq = get_local_id(1);
+    const uint tid = lid_npq * WG_K + lid_k;
+
+    const uint B_idx_K = get_group_id(0);
+    const uint B_idx_NPQ = get_group_id(1);
+
+    const uint offset_k = B_idx_K * BS_K;
+    const uint offset_npq = B_idx_NPQ * BS_NPQ;
+
+    local T_FLOAT* Ash = (local T_FLOAT*)shared;
+    local T_FLOAT4* Bsh = (local T_FLOAT4*) &Ash[BS_K * BS_CRS];
+
+    T_ACCUM regC[TS_K][TS_NPQ_VEC];
+    for (int i = 0; i < TS_K; ++i) {
+        for (int j = 0; j < TS_NPQ_VEC; ++j) {
+            regC[i][j] = (T_ACCUM)(0.0f);
+        }
+    }
+
+    const uint NB_CRS = splitWork(CRS, BS_CRS);
+
+    for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {
+        const uint offset_crs = B_idx_CRS * BS_CRS;
+
+        for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {
+            const uint k_l = i / BS_CRS;
+            const uint crs_l = i % BS_CRS;
+            const uint k_g = offset_k + k_l;
+            const uint crs_g = offset_crs + crs_l;
+
+            if (k_g < K && crs_g < CRS) {
+                const uint Cin_idx = crs_g / (KW*KH);
+                const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW;
+                const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW;
+                const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;
+                Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];
+            } else {
+                Ash[k_l * BS_CRS + crs_l] = (T_FLOAT)0.0f;
+            }
+        }
+
+        for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {
+            const uint crs_l = i / BS_NPQ_VEC;
+            const uint npq_l_vec = i % BS_NPQ_VEC;
+            const uint crs_g = offset_crs + crs_l;
+
+            T_FLOAT4 val = (T_FLOAT4)(0.0f);
+            if (crs_g < CRS) {
+                const uint Cin_idx = crs_g / (KW * KH);
+                const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW;
+                const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW;
+                for (int v = 0; v < VEC_SIZE; ++v) {
+                    const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v;
+                    if (npq_g < NPQ) {
+                        const uint N_idx = npq_g / (OH * OW);
+                        const uint pq_idx = npq_g % (OH * OW);
+                        const uint OH_idx = pq_idx / OW;
+                        const uint OW_idx = pq_idx % OW;
+                        const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1);
+                        const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0);
+
+                        if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {
+                            const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;
+                            ((T_FLOAT*)&val)[v] = src_data[src_idx];
+                        }
+                    }
+                }
+            }
+            Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        #pragma unroll
+        for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
+            T_FLOAT regA[TS_K];
+            for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+                regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];
+            }
+
+            for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
+                T_FLOAT4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];
+                for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+                    regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), convert_float4(regB), regC[k_l_reg][npq_l_vec_reg]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+        const uint k_g = offset_k + lid_k * TS_K + k_l_reg;
+        if (k_g >= K) continue;
+
+        for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
+            const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE;
+
+            const uint N_idx = npq_g_base / (OH * OW);
+            const uint pq_idx = npq_g_base % (OH * OW);
+            const uint OH_idx = pq_idx / OW;
+            const uint OW_idx = pq_idx % OW;
+
+            if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {
+                const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;
+                VSTORE_T_FLOAT4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);
+            } else {
+                T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];
+                for (int v = 0; v < VEC_SIZE; ++v) {
+                    const uint npq_g = npq_g_base + v;
+                    if (npq_g < NPQ) {
+                        const uint N_idx_s = npq_g / (OH*OW);
+                        const uint pq_idx_s = npq_g % (OH*OW);
+                        const uint OH_idx_s = pq_idx_s / OW;
+                        const uint OW_idx_s = pq_idx_s % OW;
+                        const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;
+                        dst_data[dst_idx_s] = (T_FLOAT)(((float*)&res)[v]);
+                    }
+                }
+            }
+        }
+    }
+}
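
One practical constraint on these tile sizes is local memory: per work-group the kernel stages a BS_K x BS_CRS tile of filter values (Ash) plus a BS_CRS x (BS_NPQ/VEC_SIZE) tile of 4-wide input vectors (Bsh), which is exactly the shmem_size the host computes. A hedged helper (my naming, not part of the commit) to confirm that budget fits a device:

#include <CL/cl.h>
#include <cstdio>

// Returns true if the Ash + Bsh tiles fit the device's local memory.
// elem_a = size of one staged filter element (cl_half or cl_float);
// elem_b = size of one staged input vector (cl_half4 or cl_float4).
static bool conv2d_tiles_fit(cl_device_id dev, size_t elem_a, size_t elem_b) {
    const size_t BS_K = 64, BS_NPQ = 64, BS_CRS = 16, VEC_SIZE = 4;
    const size_t need = BS_K * BS_CRS * elem_a
                      + BS_CRS * (BS_NPQ / VEC_SIZE) * elem_b;
    cl_ulong local_mem = 0;
    clGetDeviceInfo(dev, CL_DEVICE_LOCAL_MEM_SIZE, sizeof(local_mem), &local_mem, NULL);
    printf("need %zu bytes of local memory, device has %llu\n",
           need, (unsigned long long) local_mem);
    return need <= (size_t) local_mem;
}

For the all-f16 variant this would be called as conv2d_tiles_fit(dev, sizeof(cl_half), sizeof(cl_half4)); for f32 and f16/f32, substitute cl_float and/or cl_float4 accordingly.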
ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl ADDED
@@ -0,0 +1,176 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#if defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+#define REQD_SUBGROUP_SIZE_128
+#endif
+
+#define T_ACCUM float4
+#define VEC_SIZE 4
+
+#define BS_K 64
+#define BS_NPQ 64
+#define BS_CRS 16
+
+#define TS_K 4
+#define TS_NPQ 8
+
+#define WG_K (BS_K / TS_K)
+#define WG_NPQ (BS_NPQ / TS_NPQ)
+
+#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
+#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)
+
+static inline uint splitWork(uint work_size, uint block_size){
+    return (work_size + block_size - 1) / block_size;
+}
+
+REQD_SUBGROUP_SIZE_128
+kernel void kernel_conv_2d(
+    global void* p_knl,
+    ulong off_knl,
+    global void* p_src,
+    ulong off_src,
+    global void* p_dst,
+    ulong off_dst,
+    local void* shared,
+    uint Cout, uint Cin, uint N,
+    uint KW, uint KH, uint W, uint H, uint OW, uint OH,
+    uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
+    uint nb01, uint nb02, uint nb03,
+    uint nb11, uint nb12, uint nb13,
+    uint nb1, uint nb2, uint nb3
+) {
+    global half* knl_data = (global half*) ((global char*)p_knl + off_knl);
+    global float* src_data = (global float*) ((global char*)p_src + off_src);
+    global float* dst_data = (global float*) ((global char*)p_dst + off_dst);
+
+    const uint K = Cout;
+    const uint CRS = Cin*KH*KW;
+    const uint NPQ = N*OH*OW;
+
+    const uint lid_k = get_local_id(0);
+    const uint lid_npq = get_local_id(1);
+    const uint tid = lid_npq * WG_K + lid_k;
+
+    const uint B_idx_K = get_group_id(0);
+    const uint B_idx_NPQ = get_group_id(1);
+
+    const uint offset_k = B_idx_K * BS_K;
+    const uint offset_npq = B_idx_NPQ * BS_NPQ;
+
+    local half* Ash = (local half*)shared;
+    local float4* Bsh = (local float4*) &Ash[BS_K * BS_CRS];
+
+    T_ACCUM regC[TS_K][TS_NPQ_VEC];
+    for (int i = 0; i < TS_K; ++i) {
+        for (int j = 0; j < TS_NPQ_VEC; ++j) {
+            regC[i][j] = (T_ACCUM)(0.0f);
+        }
+    }
+
+    const uint NB_CRS = splitWork(CRS, BS_CRS);
+
+    for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {
+        const uint offset_crs = B_idx_CRS * BS_CRS;
+
+        for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {
+            const uint k_l = i / BS_CRS;
+            const uint crs_l = i % BS_CRS;
+            const uint k_g = offset_k + k_l;
+            const uint crs_g = offset_crs + crs_l;
+
+            if (k_g < K && crs_g < CRS) {
+                const uint Cin_idx = crs_g / (KW*KH);
+                const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW;
+                const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW;
+                const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;
+                Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];
+            } else {
+                Ash[k_l * BS_CRS + crs_l] = (half)0.0f;
+            }
+        }
+
+        for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {
+            const uint crs_l = i / BS_NPQ_VEC;
+            const uint npq_l_vec = i % BS_NPQ_VEC;
+            const uint crs_g = offset_crs + crs_l;
+
+            float4 val = (float4)(0.0f);
+            if (crs_g < CRS) {
+                const uint Cin_idx = crs_g / (KW * KH);
+                const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW;
+                const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW;
+                for (int v = 0; v < VEC_SIZE; ++v) {
+                    const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v;
+                    if (npq_g < NPQ) {
+                        const uint N_idx = npq_g / (OH * OW);
+                        const uint pq_idx = npq_g % (OH * OW);
+                        const uint OH_idx = pq_idx / OW;
+                        const uint OW_idx = pq_idx % OW;
+                        const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1);
+                        const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0);
+
+                        if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {
+                            const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;
+                            ((float*)&val)[v] = src_data[src_idx];
+                        }
+                    }
+                }
+            }
+            Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        #pragma unroll
+        for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
+            half regA[TS_K];
+            for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+                regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];
+            }
+
+            for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
+                float4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];
+                for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+                    regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), regB, regC[k_l_reg][npq_l_vec_reg]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+        const uint k_g = offset_k + lid_k * TS_K + k_l_reg;
+        if (k_g >= K) continue;
+
+        for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
+            const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE;
+
+            const uint N_idx = npq_g_base / (OH * OW);
+            const uint pq_idx = npq_g_base % (OH * OW);
+            const uint OH_idx = pq_idx / OW;
+            const uint OW_idx = pq_idx % OW;
+
+            if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {
+                const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;
+                vstore4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);
+            } else {
+                T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];
+                for (int v = 0; v < VEC_SIZE; ++v) {
+                    const uint npq_g = npq_g_base + v;
+                    if (npq_g < NPQ) {
+                        const uint N_idx_s = npq_g / (OH*OW);
+                        const uint pq_idx_s = npq_g % (OH*OW);
+                        const uint OH_idx_s = pq_idx_s / OW;
+                        const uint OW_idx_s = pq_idx_s % OW;
+                        const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;
+                        dst_data[dst_idx_s] = ((float*)&res)[v];
+                    }
+                }
+            }
+        }
+    }
+}
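
Finally, a hedged sketch of how a graph reaches this code path. It assumes ggml_conv_2d_direct is the constructor that emits a GGML_OP_CONV_2D node (the op matched in ggml_opencl_supports_op above); on revisions without it, ggml_conv_2d lowers to im2col + matmul and bypasses these kernels entirely:

#include "ggml.h"

// kernel: [KW, KH, Cin, Cout]; input: [W, H, Cin, N]. The supported type
// combinations are the three cases listed in ggml_opencl_supports_op:
// f16/f16 -> f16, f32/f32 -> f32, and f16 kernel with f32 input -> f32.
static struct ggml_tensor * conv2d_node(struct ggml_context * ctx,
                                        struct ggml_tensor  * kernel,
                                        struct ggml_tensor  * input) {
    // stride 1, padding 1, dilation 1 ("same" output for a 3x3 kernel)
    return ggml_conv_2d_direct(ctx, kernel, input, 1, 1, 1, 1, 1, 1);
}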