lhez committed on
Commit 5629961 · 1 Parent(s): a91e2f3

opencl: add fused `rms_norm_mul` (llama/14841)


* opencl: add fused `rms_norm` + `mul`

* opencl: improve workgroup size for `rms_norm_mul`
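
The fusion replaces an `RMS_NORM` node immediately followed by a `MUL` of its result with a single kernel launch. Below is a minimal host-side sketch of the per-row semantics the fused kernel computes, assuming f32 data and row-wise broadcast of the multiplier; the helper name `rms_norm_mul_row` is illustrative only and is not part of the ggml API.

```cpp
#include <cmath>
#include <cstdio>

// Reference semantics for one row of ne00 floats:
//   y = (x / sqrt(mean(x^2) + eps)) * f
// where f has ne10 elements and is broadcast along the row,
// mirroring the f[i00 % (ne10/4)] indexing in the OpenCL kernel.
static void rms_norm_mul_row(const float * x, const float * f, float * y,
                             int ne00, int ne10, float eps) {
    float sumsq = 0.0f;
    for (int i = 0; i < ne00; ++i) {
        sumsq += x[i] * x[i];
    }
    const float scale = 1.0f / std::sqrt(sumsq / ne00 + eps);
    for (int i = 0; i < ne00; ++i) {
        y[i] = (x[i] * scale) * f[i % ne10];
    }
}

int main() {
    float x[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    float f[4] = {1.0f, 0.5f, 2.0f, 1.0f};   // broadcast multiplier (ne10 = 4)
    float y[8];
    rms_norm_mul_row(x, f, y, 8, 4, 1e-6f);
    for (float v : y) std::printf("%f ", v);
    std::printf("\n");
    return 0;
}
```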

ggml/src/ggml-opencl/ggml-opencl.cpp CHANGED
```diff
@@ -333,6 +333,7 @@ struct ggml_backend_opencl_context {
     size_t max_alloc_size;
     bool fp16_support;
     bool has_vector_subgroup_broadcast;
+    bool disable_fusion;
     ggml_cl_compiler_version adreno_cl_compiler_version;
 
     int adreno_wave_size;
@@ -411,7 +412,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
               kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
     cl_kernel kernel_norm;
-    cl_kernel kernel_rms_norm;
+    cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
     cl_kernel kernel_group_norm;
     cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
     cl_kernel kernel_soft_max, kernel_soft_max_4;
@@ -1100,7 +1101,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_rms_norm =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_rms_norm_mul = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm_mul", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -2110,6 +2112,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
     CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err));
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS
 
+    backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
+
     dev_ctx->backend_ctx = backend_ctx.release();
     return dev_ctx->backend_ctx;
 }
@@ -2279,7 +2283,45 @@ static void sync_with_other_backends(ggml_backend_t backend) {
     sync_with_other_backends(backend_ctx);
 }
 
+static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul      = cgraph->nodes[node_idx+1];
+
+        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+
+        // rms_norm only supports f32
+        if (mul->src[0]->type != GGML_TYPE_F32 ||
+            mul->src[1]->type != GGML_TYPE_F32 ||
+            mul->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        // if rms_norm is the B operand, then we don't handle broadcast
+        if (rms_norm == mul->src[1] &&
+            !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+            return false;
+        }
+
+        // rms_norm assumes contiguous rows
+        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
+
 static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -2292,6 +2334,12 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
             continue;
         }
 
+        if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+            ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
+            i++;
+            continue;
+        }
+
         bool ok = ggml_cl_compute_forward(backend, node);
         if (!ok) {
             GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
@@ -4455,6 +4503,117 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }
 
+static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor) {
+    GGML_ASSERT(mul_tensor);
+    GGML_ASSERT(rms_norm_tensor);
+
+    // src0 is the src of rms_norm, src1 is the other src of mul (one being rms_norm)
+    const ggml_tensor * src0 = rms_norm_tensor->src[0];
+    const ggml_tensor * src1;
+    if (mul_tensor->src[0] == rms_norm_tensor) {
+        src1 = mul_tensor->src[1];
+    } else if (mul_tensor->src[1] == rms_norm_tensor) {
+        src1 = mul_tensor->src[0];
+    } else {
+        GGML_ASSERT(false && "Invalid args for rms_norm and mul");
+    }
+    const ggml_tensor * dst = mul_tensor;
+
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    float eps;
+    memcpy(&eps, rms_norm_tensor->op_params, sizeof(float));
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3];
+
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];
+
+    GGML_ASSERT(ne00 % 4 == 0);
+
+    size_t sgs;
+    if (backend_ctx->gpu_family == ADRENO) {
+        sgs = 64;
+    } else if (backend_ctx->gpu_family == INTEL) {
+        sgs = 32;
+    } else {
+        GGML_ASSERT(false && "Unsupported GPU");
+    }
+
+    cl_kernel kernel = backend_ctx->kernel_rms_norm_mul;
+
+    int nth = sgs;
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    while (nth < ne00 && nth < max_workgroup_size) {
+        nth *= 2;
+    }
+    nth = MIN(nth, max_workgroup_size);
+    nth = MIN(nth, ne00);
+
+    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne11));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float),    &eps));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL));
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+}
+
 static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
```
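
The second commit bullet concerns how the launch dimensions are picked in `ggml_opencl_op_rms_norm_fused` above: start at the subgroup size, double until the row width or the kernel's maximum workgroup size is reached, then clamp, and allocate one float of local memory per subgroup for the partial sums. A standalone sketch of that heuristic, with illustrative helper names that are not part of the backend:

```cpp
#include <algorithm>
#include <cstdio>

// Double the workgroup size from the subgroup size until it covers the row
// (ne00 elements) or hits the kernel's workgroup-size limit, then clamp.
static size_t pick_nth(size_t sgs, size_t ne00, size_t max_wg) {
    size_t nth = sgs;
    while (nth < ne00 && nth < max_wg) {
        nth *= 2;
    }
    nth = std::min(nth, max_wg);
    nth = std::min(nth, ne00);
    return nth;
}

// One float of local memory per subgroup holds that subgroup's partial sum.
static size_t local_mem_bytes(size_t nth, size_t sgs) {
    return sizeof(float) * (nth / sgs);
}

int main() {
    size_t nth = pick_nth(/*sgs=*/32, /*ne00=*/4096, /*max_wg=*/256);
    std::printf("nth = %zu, local mem = %zu bytes\n", nth, local_mem_bytes(nth, 32));
    return 0;
}
```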
ggml/src/ggml-opencl/kernels/rms_norm.cl CHANGED
```diff
@@ -94,3 +94,82 @@ kernel void kernel_rms_norm(
         }
     }
 }
+
+//------------------------------------------------------------------------------
+// rms_norm_mul
+//------------------------------------------------------------------------------
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_32
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_rms_norm_mul(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        int ne13,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
+        float eps,
+        local float * sum
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst = dst + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    global float4 * x = (global float4 *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    global float4 * f = (global float4 *) (src1 + (i03%ne13)*nb13 + (i02%ne12)*nb12 + (i01%ne11)*nb11);
+
+    float sumf = 0;
+
+    // parallel sum
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        sumf += dot(x[i00], x[i00]);
+    }
+    sumf = sub_group_reduce_add(sumf);
+    if (get_sub_group_local_id() == 0) {
+        sum[get_sub_group_id()] = sumf;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
+        if (get_local_id(0) < i) {
+            sum[get_local_id(0)] += sum[get_local_id(0) + i];
+        }
+    }
+    if (get_local_id(0) == 0) {
+        sum[0] /= ne00;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    float mean = sum[0];
+    float scale = 1.0f/sqrt(mean + eps);
+
+    global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        y[i00] = (x[i00] * scale) * f[i00%(ne10/4)];
+    }
+}
```
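
Inside the kernel, each row's sum of squares is reduced in two stages: `sub_group_reduce_add` collapses each subgroup's partial sum, subgroup leaders store one value per subgroup in local memory, and a log2 tree over those values yields the row total, which is then divided by `ne00` to get the mean of squares. A minimal C++ model of the tree stage, assuming (as the kernel does) that the number of subgroups is a power of two:

```cpp
#include <cstdio>
#include <vector>

int main() {
    // e.g. 256 work-items with subgroup size 32 -> 8 per-subgroup partials,
    // each already produced by sub_group_reduce_add in the kernel
    std::vector<float> sum = {1, 2, 3, 4, 5, 6, 7, 8};

    // log2 tree over the per-subgroup partials (the kernel's second loop)
    for (size_t i = sum.size() / 2; i > 0; i /= 2) {
        for (size_t lid = 0; lid < i; ++lid) {
            sum[lid] += sum[lid + i];
        }
    }

    // the kernel then divides sum[0] by ne00 to obtain the mean of squares
    std::printf("row total = %g\n", sum[0]);   // prints 36
    return 0;
}
```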