Spaces:
Running
Running
opencl: add conv2d kernel (llama/14403)
Browse files* add conv2d kernel
* fix trailing whitespace
* whitespace fixe
* handle f16 input and f16 kernel, more opt
* resolve conflicts
* use enqueue_ndrange_kernel
ggml/src/ggml-opencl/CMakeLists.txt
CHANGED
|
@@ -105,6 +105,8 @@ set(GGML_OPENCL_KERNELS
|
|
| 105 |
pad
|
| 106 |
repeat
|
| 107 |
mul_mat_f16_f32
|
|
|
|
|
|
|
| 108 |
)
|
| 109 |
|
| 110 |
foreach (K ${GGML_OPENCL_KERNELS})
|
|
|
|
| 105 |
pad
|
| 106 |
repeat
|
| 107 |
mul_mat_f16_f32
|
| 108 |
+
conv2d
|
| 109 |
+
conv2d_f16_f32
|
| 110 |
)
|
| 111 |
|
| 112 |
foreach (K ${GGML_OPENCL_KERNELS})
|
ggml/src/ggml-opencl/ggml-opencl.cpp
CHANGED
|
@@ -390,6 +390,9 @@ struct ggml_backend_opencl_context {
|
|
| 390 |
cl_program program_tanh;
|
| 391 |
cl_program program_upscale;
|
| 392 |
cl_program program_concat;
|
|
|
|
|
|
|
|
|
|
| 393 |
cl_program program_tsembd;
|
| 394 |
cl_program program_mul_mv_id_q4_0_f32_8x_flat;
|
| 395 |
|
|
@@ -441,6 +444,9 @@ struct ggml_backend_opencl_context {
|
|
| 441 |
cl_kernel kernel_upscale_bilinear;
|
| 442 |
cl_kernel kernel_concat_f32_contiguous;
|
| 443 |
cl_kernel kernel_concat_f32_non_contiguous;
|
|
|
|
|
|
|
|
|
|
| 444 |
cl_kernel kernel_timestep_embedding;
|
| 445 |
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
|
| 446 |
|
|
@@ -1478,6 +1484,47 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
|
| 1478 |
GGML_LOG_CONT(".");
|
| 1479 |
}
|
| 1480 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1481 |
// mul_mv_id_q4_0_f32_8x_flat
|
| 1482 |
{
|
| 1483 |
#ifdef GGML_OPENCL_EMBED_KERNELS
|
|
@@ -2361,6 +2408,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
|
| 2361 |
op->src[0]->ne[3] == 1 && op->ne[3] == 1;
|
| 2362 |
case GGML_OP_UPSCALE:
|
| 2363 |
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2364 |
case GGML_OP_CONCAT:
|
| 2365 |
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
| 2366 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
@@ -4998,6 +5049,83 @@ static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_ten
|
|
| 4998 |
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
|
| 4999 |
}
|
| 5000 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5001 |
static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 5002 |
GGML_ASSERT(src0);
|
| 5003 |
GGML_ASSERT(src0->extra);
|
|
@@ -6752,6 +6880,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
|
|
| 6752 |
}
|
| 6753 |
ggml_cl_upscale(backend, tensor->src[0], tensor);
|
| 6754 |
return true;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6755 |
case GGML_OP_CONCAT:
|
| 6756 |
if (!any_on_device) {
|
| 6757 |
return false;
|
|
|
|
| 390 |
cl_program program_tanh;
|
| 391 |
cl_program program_upscale;
|
| 392 |
cl_program program_concat;
|
| 393 |
+
cl_program program_conv_2d_f16;
|
| 394 |
+
cl_program program_conv_2d_f32;
|
| 395 |
+
cl_program program_conv_2d_f16_f32;
|
| 396 |
cl_program program_tsembd;
|
| 397 |
cl_program program_mul_mv_id_q4_0_f32_8x_flat;
|
| 398 |
|
|
|
|
| 444 |
cl_kernel kernel_upscale_bilinear;
|
| 445 |
cl_kernel kernel_concat_f32_contiguous;
|
| 446 |
cl_kernel kernel_concat_f32_non_contiguous;
|
| 447 |
+
cl_kernel kernel_conv_2d_f16;
|
| 448 |
+
cl_kernel kernel_conv_2d_f32;
|
| 449 |
+
cl_kernel kernel_conv_2d_f16_f32;
|
| 450 |
cl_kernel kernel_timestep_embedding;
|
| 451 |
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
|
| 452 |
|
|
|
|
| 1484 |
GGML_LOG_CONT(".");
|
| 1485 |
}
|
| 1486 |
|
| 1487 |
+
// conv2d
|
| 1488 |
+
{
|
| 1489 |
+
#ifdef GGML_OPENCL_EMBED_KERNELS
|
| 1490 |
+
const std::string kernel_src {
|
| 1491 |
+
#include "conv2d.cl.h"
|
| 1492 |
+
};
|
| 1493 |
+
const std::string kernel_src_f16_f32 {
|
| 1494 |
+
#include "conv2d_f16_f32.cl.h"
|
| 1495 |
+
};
|
| 1496 |
+
#else
|
| 1497 |
+
const std::string kernel_src = read_file("conv2d.cl");
|
| 1498 |
+
const std::string kernel_src_f16_f32 = read_file("conv2d_f16_f32.cl");
|
| 1499 |
+
#endif
|
| 1500 |
+
if (!kernel_src.empty()) {
|
| 1501 |
+
backend_ctx->program_conv_2d_f16 =
|
| 1502 |
+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), (std::string(compile_opts) + " -DUSE_FP16=1").c_str());
|
| 1503 |
+
CL_CHECK((backend_ctx->kernel_conv_2d_f16 = clCreateKernel(backend_ctx->program_conv_2d_f16, "kernel_conv_2d", &err), err));
|
| 1504 |
+
GGML_LOG_CONT(".");
|
| 1505 |
+
backend_ctx->program_conv_2d_f32 =
|
| 1506 |
+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
| 1507 |
+
CL_CHECK((backend_ctx->kernel_conv_2d_f32 = clCreateKernel(backend_ctx->program_conv_2d_f32, "kernel_conv_2d", &err), err));
|
| 1508 |
+
GGML_LOG_CONT(".");
|
| 1509 |
+
} else {
|
| 1510 |
+
GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n");
|
| 1511 |
+
backend_ctx->program_conv_2d_f16 = nullptr;
|
| 1512 |
+
backend_ctx->kernel_conv_2d_f16 = nullptr;
|
| 1513 |
+
backend_ctx->program_conv_2d_f32 = nullptr;
|
| 1514 |
+
backend_ctx->kernel_conv_2d_f32 = nullptr;
|
| 1515 |
+
}
|
| 1516 |
+
if (!kernel_src_f16_f32.empty()) {
|
| 1517 |
+
backend_ctx->program_conv_2d_f16_f32 =
|
| 1518 |
+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16_f32.c_str(), compile_opts);
|
| 1519 |
+
CL_CHECK((backend_ctx->kernel_conv_2d_f16_f32 = clCreateKernel(backend_ctx->program_conv_2d_f16_f32, "kernel_conv_2d", &err), err));
|
| 1520 |
+
GGML_LOG_CONT(".");
|
| 1521 |
+
} else {
|
| 1522 |
+
GGML_LOG_WARN("ggml_opencl: conv2d_f16_f32 kernel source not found or empty. This op will not be available.\n");
|
| 1523 |
+
backend_ctx->program_conv_2d_f16_f32 = nullptr;
|
| 1524 |
+
backend_ctx->kernel_conv_2d_f16_f32 = nullptr;
|
| 1525 |
+
}
|
| 1526 |
+
}
|
| 1527 |
+
|
| 1528 |
// mul_mv_id_q4_0_f32_8x_flat
|
| 1529 |
{
|
| 1530 |
#ifdef GGML_OPENCL_EMBED_KERNELS
|
|
|
|
| 2408 |
op->src[0]->ne[3] == 1 && op->ne[3] == 1;
|
| 2409 |
case GGML_OP_UPSCALE:
|
| 2410 |
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
| 2411 |
+
case GGML_OP_CONV_2D:
|
| 2412 |
+
return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
|
| 2413 |
+
(op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
|
| 2414 |
+
(op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
|
| 2415 |
case GGML_OP_CONCAT:
|
| 2416 |
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
| 2417 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
|
|
| 5049 |
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
|
| 5050 |
}
|
| 5051 |
|
| 5052 |
+
static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 5053 |
+
GGML_TENSOR_BINARY_OP_LOCALS;
|
| 5054 |
+
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
| 5055 |
+
|
| 5056 |
+
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
| 5057 |
+
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
|
| 5058 |
+
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
| 5059 |
+
|
| 5060 |
+
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
| 5061 |
+
cl_ulong offset1 = extra1->offset + src1->view_offs;
|
| 5062 |
+
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
| 5063 |
+
|
| 5064 |
+
const cl_uint Cout = ne03; const cl_uint Cin = ne02; const cl_uint N = ne13;
|
| 5065 |
+
const cl_uint KW = ne00; const cl_uint KH = ne01; const cl_uint W = ne10; const cl_uint H = ne11; const cl_uint OW = ne0; const cl_uint OH = ne1;
|
| 5066 |
+
|
| 5067 |
+
const cl_uint s0 = dst->op_params[0]; const cl_uint s1 = dst->op_params[1];
|
| 5068 |
+
const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3];
|
| 5069 |
+
const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5];
|
| 5070 |
+
|
| 5071 |
+
const cl_uint cl_nb01 = nb01/ggml_type_size(src0->type); const cl_uint cl_nb02 = nb02/ggml_type_size(src0->type); const cl_uint cl_nb03 = nb03/ggml_type_size(src0->type);
|
| 5072 |
+
const cl_uint cl_nb11 = nb11/ggml_type_size(src1->type); const cl_uint cl_nb12 = nb12/ggml_type_size(src1->type); const cl_uint cl_nb13 = nb13/ggml_type_size(src1->type);
|
| 5073 |
+
const cl_uint cl_nb1 = nb1/ggml_type_size(dst->type); const cl_uint cl_nb2 = nb2/ggml_type_size(dst->type); const cl_uint cl_nb3 = nb3/ggml_type_size(dst->type);
|
| 5074 |
+
|
| 5075 |
+
const int64_t NPQ = (int64_t)N * OW * OH;
|
| 5076 |
+
|
| 5077 |
+
const uint32_t BS_K = 64;
|
| 5078 |
+
const uint32_t BS_NPQ = 64;
|
| 5079 |
+
const uint32_t BS_CRS = 16;
|
| 5080 |
+
const uint32_t VEC_SIZE = 4;
|
| 5081 |
+
|
| 5082 |
+
const uint32_t TS_K = 4;
|
| 5083 |
+
const uint32_t TS_NPQ = 8;
|
| 5084 |
+
|
| 5085 |
+
const uint32_t WG_K = BS_K / TS_K;
|
| 5086 |
+
const uint32_t WG_NPQ = BS_NPQ / TS_NPQ;
|
| 5087 |
+
|
| 5088 |
+
auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (block_size + work_size - 1) / block_size; };
|
| 5089 |
+
const uint32_t NB_K = splitWork(Cout, BS_K);
|
| 5090 |
+
const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ);
|
| 5091 |
+
|
| 5092 |
+
cl_kernel kernel;
|
| 5093 |
+
size_t shmem_size;
|
| 5094 |
+
|
| 5095 |
+
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
| 5096 |
+
kernel = backend_ctx->kernel_conv_2d_f16;
|
| 5097 |
+
shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_half4));
|
| 5098 |
+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
| 5099 |
+
kernel = backend_ctx->kernel_conv_2d_f32;
|
| 5100 |
+
shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_float) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
|
| 5101 |
+
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
| 5102 |
+
kernel = backend_ctx->kernel_conv_2d_f16_f32;
|
| 5103 |
+
shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
|
| 5104 |
+
} else {
|
| 5105 |
+
GGML_ASSERT(false && "Unsupported data type combination for conv2d");
|
| 5106 |
+
return;
|
| 5107 |
+
}
|
| 5108 |
+
|
| 5109 |
+
cl_uint idx = 0;
|
| 5110 |
+
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0));
|
| 5111 |
+
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1));
|
| 5112 |
+
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offsetd));
|
| 5113 |
+
CL_CHECK(clSetKernelArg(kernel, idx++, shmem_size, NULL));
|
| 5114 |
+
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cout)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cin)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &N));
|
| 5115 |
+
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KH)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &W)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H));
|
| 5116 |
+
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OH));
|
| 5117 |
+
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p1));
|
| 5118 |
+
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d1));
|
| 5119 |
+
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb01)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb02)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb03));
|
| 5120 |
+
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13));
|
| 5121 |
+
CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3));
|
| 5122 |
+
|
| 5123 |
+
size_t global_work_size[] = { (size_t)NB_K * WG_K, (size_t)NB_NPQ * WG_NPQ, 1 };
|
| 5124 |
+
size_t local_work_size[] = { (size_t)WG_K, (size_t)WG_NPQ, 1 };
|
| 5125 |
+
|
| 5126 |
+
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
|
| 5127 |
+
}
|
| 5128 |
+
|
| 5129 |
static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 5130 |
GGML_ASSERT(src0);
|
| 5131 |
GGML_ASSERT(src0->extra);
|
|
|
|
| 6880 |
}
|
| 6881 |
ggml_cl_upscale(backend, tensor->src[0], tensor);
|
| 6882 |
return true;
|
| 6883 |
+
case GGML_OP_CONV_2D:
|
| 6884 |
+
if (!any_on_device) {
|
| 6885 |
+
return false;
|
| 6886 |
+
}
|
| 6887 |
+
func = ggml_cl_conv_2d;
|
| 6888 |
+
break;
|
| 6889 |
case GGML_OP_CONCAT:
|
| 6890 |
if (!any_on_device) {
|
| 6891 |
return false;
|
ggml/src/ggml-opencl/kernels/conv2d.cl
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifdef USE_FP16
|
| 2 |
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
| 3 |
+
#define T_FLOAT half
|
| 4 |
+
#define T_FLOAT4 half4
|
| 5 |
+
#define VSTORE_T_FLOAT4(data, offset, p) vstore_half4_rte(data, offset, p)
|
| 6 |
+
#else
|
| 7 |
+
#define T_FLOAT float
|
| 8 |
+
#define T_FLOAT4 float4
|
| 9 |
+
#define VSTORE_T_FLOAT4(data, offset, p) vstore4(data, offset, p)
|
| 10 |
+
#endif
|
| 11 |
+
|
| 12 |
+
#if defined(cl_qcom_reqd_sub_group_size)
|
| 13 |
+
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
| 14 |
+
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
| 15 |
+
#else
|
| 16 |
+
#define REQD_SUBGROUP_SIZE_128
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#define T_ACCUM float4
|
| 20 |
+
#define VEC_SIZE 4
|
| 21 |
+
|
| 22 |
+
#define BS_K 64
|
| 23 |
+
#define BS_NPQ 64
|
| 24 |
+
#define BS_CRS 16
|
| 25 |
+
|
| 26 |
+
#define TS_K 4
|
| 27 |
+
#define TS_NPQ 8
|
| 28 |
+
|
| 29 |
+
#define WG_K (BS_K / TS_K)
|
| 30 |
+
#define WG_NPQ (BS_NPQ / TS_NPQ)
|
| 31 |
+
|
| 32 |
+
#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
|
| 33 |
+
#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)
|
| 34 |
+
|
| 35 |
+
static inline uint splitWork(uint work_size, uint block_size){
|
| 36 |
+
return (work_size + block_size - 1) / block_size;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
REQD_SUBGROUP_SIZE_128
|
| 40 |
+
kernel void kernel_conv_2d(
|
| 41 |
+
global void* p_knl,
|
| 42 |
+
ulong off_knl,
|
| 43 |
+
global void* p_src,
|
| 44 |
+
ulong off_src,
|
| 45 |
+
global void* p_dst,
|
| 46 |
+
ulong off_dst,
|
| 47 |
+
local void* shared,
|
| 48 |
+
uint Cout, uint Cin, uint N,
|
| 49 |
+
uint KW, uint KH, uint W, uint H, uint OW, uint OH,
|
| 50 |
+
uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
|
| 51 |
+
uint nb01, uint nb02, uint nb03,
|
| 52 |
+
uint nb11, uint nb12, uint nb13,
|
| 53 |
+
uint nb1, uint nb2, uint nb3
|
| 54 |
+
) {
|
| 55 |
+
global T_FLOAT* knl_data = (global T_FLOAT*) ((global char*)p_knl + off_knl);
|
| 56 |
+
global T_FLOAT* src_data = (global T_FLOAT*) ((global char*)p_src + off_src);
|
| 57 |
+
global T_FLOAT* dst_data = (global T_FLOAT*) ((global char*)p_dst + off_dst);
|
| 58 |
+
|
| 59 |
+
const uint K = Cout;
|
| 60 |
+
const uint CRS = Cin*KH*KW;
|
| 61 |
+
const uint NPQ = N*OH*OW;
|
| 62 |
+
|
| 63 |
+
const uint lid_k = get_local_id(0);
|
| 64 |
+
const uint lid_npq = get_local_id(1);
|
| 65 |
+
const uint tid = lid_npq * WG_K + lid_k;
|
| 66 |
+
|
| 67 |
+
const uint B_idx_K = get_group_id(0);
|
| 68 |
+
const uint B_idx_NPQ = get_group_id(1);
|
| 69 |
+
|
| 70 |
+
const uint offset_k = B_idx_K * BS_K;
|
| 71 |
+
const uint offset_npq = B_idx_NPQ * BS_NPQ;
|
| 72 |
+
|
| 73 |
+
local T_FLOAT* Ash = (local T_FLOAT*)shared;
|
| 74 |
+
local T_FLOAT4* Bsh = (local T_FLOAT4*) &Ash[BS_K * BS_CRS];
|
| 75 |
+
|
| 76 |
+
T_ACCUM regC[TS_K][TS_NPQ_VEC];
|
| 77 |
+
for (int i = 0; i < TS_K; ++i) {
|
| 78 |
+
for (int j = 0; j < TS_NPQ_VEC; ++j) {
|
| 79 |
+
regC[i][j] = (T_ACCUM)(0.0f);
|
| 80 |
+
}
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
const uint NB_CRS = splitWork(CRS, BS_CRS);
|
| 84 |
+
|
| 85 |
+
for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {
|
| 86 |
+
const uint offset_crs = B_idx_CRS * BS_CRS;
|
| 87 |
+
|
| 88 |
+
for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {
|
| 89 |
+
const uint k_l = i / BS_CRS;
|
| 90 |
+
const uint crs_l = i % BS_CRS;
|
| 91 |
+
const uint k_g = offset_k + k_l;
|
| 92 |
+
const uint crs_g = offset_crs + crs_l;
|
| 93 |
+
|
| 94 |
+
if (k_g < K && crs_g < CRS) {
|
| 95 |
+
const uint Cin_idx = crs_g / (KW*KH);
|
| 96 |
+
const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW;
|
| 97 |
+
const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW;
|
| 98 |
+
const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;
|
| 99 |
+
Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];
|
| 100 |
+
} else {
|
| 101 |
+
Ash[k_l * BS_CRS + crs_l] = (T_FLOAT)0.0f;
|
| 102 |
+
}
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {
|
| 106 |
+
const uint crs_l = i / BS_NPQ_VEC;
|
| 107 |
+
const uint npq_l_vec = i % BS_NPQ_VEC;
|
| 108 |
+
const uint crs_g = offset_crs + crs_l;
|
| 109 |
+
|
| 110 |
+
T_FLOAT4 val = (T_FLOAT4)(0.0f);
|
| 111 |
+
if (crs_g < CRS) {
|
| 112 |
+
const uint Cin_idx = crs_g / (KW * KH);
|
| 113 |
+
const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW;
|
| 114 |
+
const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW;
|
| 115 |
+
for (int v = 0; v < VEC_SIZE; ++v) {
|
| 116 |
+
const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v;
|
| 117 |
+
if (npq_g < NPQ) {
|
| 118 |
+
const uint N_idx = npq_g / (OH * OW);
|
| 119 |
+
const uint pq_idx = npq_g % (OH * OW);
|
| 120 |
+
const uint OH_idx = pq_idx / OW;
|
| 121 |
+
const uint OW_idx = pq_idx % OW;
|
| 122 |
+
const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1);
|
| 123 |
+
const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0);
|
| 124 |
+
|
| 125 |
+
if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {
|
| 126 |
+
const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;
|
| 127 |
+
((T_FLOAT*)&val)[v] = src_data[src_idx];
|
| 128 |
+
}
|
| 129 |
+
}
|
| 130 |
+
}
|
| 131 |
+
}
|
| 132 |
+
Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 136 |
+
|
| 137 |
+
#pragma unroll
|
| 138 |
+
for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
|
| 139 |
+
T_FLOAT regA[TS_K];
|
| 140 |
+
for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
|
| 141 |
+
regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
|
| 145 |
+
T_FLOAT4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];
|
| 146 |
+
for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
|
| 147 |
+
regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), convert_float4(regB), regC[k_l_reg][npq_l_vec_reg]);
|
| 148 |
+
}
|
| 149 |
+
}
|
| 150 |
+
}
|
| 151 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
|
| 155 |
+
const uint k_g = offset_k + lid_k * TS_K + k_l_reg;
|
| 156 |
+
if (k_g >= K) continue;
|
| 157 |
+
|
| 158 |
+
for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
|
| 159 |
+
const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE;
|
| 160 |
+
|
| 161 |
+
const uint N_idx = npq_g_base / (OH * OW);
|
| 162 |
+
const uint pq_idx = npq_g_base % (OH * OW);
|
| 163 |
+
const uint OH_idx = pq_idx / OW;
|
| 164 |
+
const uint OW_idx = pq_idx % OW;
|
| 165 |
+
|
| 166 |
+
if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {
|
| 167 |
+
const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;
|
| 168 |
+
VSTORE_T_FLOAT4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);
|
| 169 |
+
} else {
|
| 170 |
+
T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];
|
| 171 |
+
for (int v = 0; v < VEC_SIZE; ++v) {
|
| 172 |
+
const uint npq_g = npq_g_base + v;
|
| 173 |
+
if (npq_g < NPQ) {
|
| 174 |
+
const uint N_idx_s = npq_g / (OH*OW);
|
| 175 |
+
const uint pq_idx_s = npq_g % (OH*OW);
|
| 176 |
+
const uint OH_idx_s = pq_idx_s / OW;
|
| 177 |
+
const uint OW_idx_s = pq_idx_s % OW;
|
| 178 |
+
const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;
|
| 179 |
+
dst_data[dst_idx_s] = (T_FLOAT)(((float*)&res)[v]);
|
| 180 |
+
}
|
| 181 |
+
}
|
| 182 |
+
}
|
| 183 |
+
}
|
| 184 |
+
}
|
| 185 |
+
}
|
ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
| 2 |
+
|
| 3 |
+
#if defined(cl_qcom_reqd_sub_group_size)
|
| 4 |
+
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
| 5 |
+
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
| 6 |
+
#else
|
| 7 |
+
#define REQD_SUBGROUP_SIZE_128
|
| 8 |
+
#endif
|
| 9 |
+
|
| 10 |
+
#define T_ACCUM float4
|
| 11 |
+
#define VEC_SIZE 4
|
| 12 |
+
|
| 13 |
+
#define BS_K 64
|
| 14 |
+
#define BS_NPQ 64
|
| 15 |
+
#define BS_CRS 16
|
| 16 |
+
|
| 17 |
+
#define TS_K 4
|
| 18 |
+
#define TS_NPQ 8
|
| 19 |
+
|
| 20 |
+
#define WG_K (BS_K / TS_K)
|
| 21 |
+
#define WG_NPQ (BS_NPQ / TS_NPQ)
|
| 22 |
+
|
| 23 |
+
#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
|
| 24 |
+
#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)
|
| 25 |
+
|
| 26 |
+
static inline uint splitWork(uint work_size, uint block_size){
|
| 27 |
+
return (work_size + block_size - 1) / block_size;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
REQD_SUBGROUP_SIZE_128
|
| 31 |
+
kernel void kernel_conv_2d(
|
| 32 |
+
global void* p_knl,
|
| 33 |
+
ulong off_knl,
|
| 34 |
+
global void* p_src,
|
| 35 |
+
ulong off_src,
|
| 36 |
+
global void* p_dst,
|
| 37 |
+
ulong off_dst,
|
| 38 |
+
local void* shared,
|
| 39 |
+
uint Cout, uint Cin, uint N,
|
| 40 |
+
uint KW, uint KH, uint W, uint H, uint OW, uint OH,
|
| 41 |
+
uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
|
| 42 |
+
uint nb01, uint nb02, uint nb03,
|
| 43 |
+
uint nb11, uint nb12, uint nb13,
|
| 44 |
+
uint nb1, uint nb2, uint nb3
|
| 45 |
+
) {
|
| 46 |
+
global half* knl_data = (global half*) ((global char*)p_knl + off_knl);
|
| 47 |
+
global float* src_data = (global float*) ((global char*)p_src + off_src);
|
| 48 |
+
global float* dst_data = (global float*) ((global char*)p_dst + off_dst);
|
| 49 |
+
|
| 50 |
+
const uint K = Cout;
|
| 51 |
+
const uint CRS = Cin*KH*KW;
|
| 52 |
+
const uint NPQ = N*OH*OW;
|
| 53 |
+
|
| 54 |
+
const uint lid_k = get_local_id(0);
|
| 55 |
+
const uint lid_npq = get_local_id(1);
|
| 56 |
+
const uint tid = lid_npq * WG_K + lid_k;
|
| 57 |
+
|
| 58 |
+
const uint B_idx_K = get_group_id(0);
|
| 59 |
+
const uint B_idx_NPQ = get_group_id(1);
|
| 60 |
+
|
| 61 |
+
const uint offset_k = B_idx_K * BS_K;
|
| 62 |
+
const uint offset_npq = B_idx_NPQ * BS_NPQ;
|
| 63 |
+
|
| 64 |
+
local half* Ash = (local half*)shared;
|
| 65 |
+
local float4* Bsh = (local float4*) &Ash[BS_K * BS_CRS];
|
| 66 |
+
|
| 67 |
+
T_ACCUM regC[TS_K][TS_NPQ_VEC];
|
| 68 |
+
for (int i = 0; i < TS_K; ++i) {
|
| 69 |
+
for (int j = 0; j < TS_NPQ_VEC; ++j) {
|
| 70 |
+
regC[i][j] = (T_ACCUM)(0.0f);
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
const uint NB_CRS = splitWork(CRS, BS_CRS);
|
| 75 |
+
|
| 76 |
+
for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {
|
| 77 |
+
const uint offset_crs = B_idx_CRS * BS_CRS;
|
| 78 |
+
|
| 79 |
+
for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {
|
| 80 |
+
const uint k_l = i / BS_CRS;
|
| 81 |
+
const uint crs_l = i % BS_CRS;
|
| 82 |
+
const uint k_g = offset_k + k_l;
|
| 83 |
+
const uint crs_g = offset_crs + crs_l;
|
| 84 |
+
|
| 85 |
+
if (k_g < K && crs_g < CRS) {
|
| 86 |
+
const uint Cin_idx = crs_g / (KW*KH);
|
| 87 |
+
const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW;
|
| 88 |
+
const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW;
|
| 89 |
+
const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;
|
| 90 |
+
Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];
|
| 91 |
+
} else {
|
| 92 |
+
Ash[k_l * BS_CRS + crs_l] = (half)0.0f;
|
| 93 |
+
}
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {
|
| 97 |
+
const uint crs_l = i / BS_NPQ_VEC;
|
| 98 |
+
const uint npq_l_vec = i % BS_NPQ_VEC;
|
| 99 |
+
const uint crs_g = offset_crs + crs_l;
|
| 100 |
+
|
| 101 |
+
float4 val = (float4)(0.0f);
|
| 102 |
+
if (crs_g < CRS) {
|
| 103 |
+
const uint Cin_idx = crs_g / (KW * KH);
|
| 104 |
+
const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW;
|
| 105 |
+
const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW;
|
| 106 |
+
for (int v = 0; v < VEC_SIZE; ++v) {
|
| 107 |
+
const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v;
|
| 108 |
+
if (npq_g < NPQ) {
|
| 109 |
+
const uint N_idx = npq_g / (OH * OW);
|
| 110 |
+
const uint pq_idx = npq_g % (OH * OW);
|
| 111 |
+
const uint OH_idx = pq_idx / OW;
|
| 112 |
+
const uint OW_idx = pq_idx % OW;
|
| 113 |
+
const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1);
|
| 114 |
+
const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0);
|
| 115 |
+
|
| 116 |
+
if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {
|
| 117 |
+
const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;
|
| 118 |
+
((float*)&val)[v] = src_data[src_idx];
|
| 119 |
+
}
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
}
|
| 123 |
+
Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 127 |
+
|
| 128 |
+
#pragma unroll
|
| 129 |
+
for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
|
| 130 |
+
half regA[TS_K];
|
| 131 |
+
for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
|
| 132 |
+
regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
|
| 136 |
+
float4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];
|
| 137 |
+
for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
|
| 138 |
+
regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), regB, regC[k_l_reg][npq_l_vec_reg]);
|
| 139 |
+
}
|
| 140 |
+
}
|
| 141 |
+
}
|
| 142 |
+
barrier(CLK_LOCAL_MEM_FENCE);
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
|
| 146 |
+
const uint k_g = offset_k + lid_k * TS_K + k_l_reg;
|
| 147 |
+
if (k_g >= K) continue;
|
| 148 |
+
|
| 149 |
+
for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
|
| 150 |
+
const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE;
|
| 151 |
+
|
| 152 |
+
const uint N_idx = npq_g_base / (OH * OW);
|
| 153 |
+
const uint pq_idx = npq_g_base % (OH * OW);
|
| 154 |
+
const uint OH_idx = pq_idx / OW;
|
| 155 |
+
const uint OW_idx = pq_idx % OW;
|
| 156 |
+
|
| 157 |
+
if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {
|
| 158 |
+
const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;
|
| 159 |
+
vstore4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);
|
| 160 |
+
} else {
|
| 161 |
+
T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];
|
| 162 |
+
for (int v = 0; v < VEC_SIZE; ++v) {
|
| 163 |
+
const uint npq_g = npq_g_base + v;
|
| 164 |
+
if (npq_g < NPQ) {
|
| 165 |
+
const uint N_idx_s = npq_g / (OH*OW);
|
| 166 |
+
const uint pq_idx_s = npq_g % (OH*OW);
|
| 167 |
+
const uint OH_idx_s = pq_idx_s / OW;
|
| 168 |
+
const uint OW_idx_s = pq_idx_s % OW;
|
| 169 |
+
const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;
|
| 170 |
+
dst_data[dst_idx_s] = ((float*)&res)[v];
|
| 171 |
+
}
|
| 172 |
+
}
|
| 173 |
+
}
|
| 174 |
+
}
|
| 175 |
+
}
|
| 176 |
+
}
|