mollysama committed
Commit d003891 · 1 Parent(s): a12468a

ggml : add epsilon as a parameter for group_norm (llama/8818)

ggml/include/ggml.h CHANGED
@@ -1140,16 +1140,17 @@ extern "C" {
 
     // group normalize along ne0*ne1*n_groups
    // used in stable-diffusion
-    // TODO: eps is hardcoded to 1e-6 for now
     GGML_API struct ggml_tensor * ggml_group_norm(
             struct ggml_context * ctx,
             struct ggml_tensor * a,
-            int n_groups);
+            int n_groups,
+            float eps);
 
     GGML_API struct ggml_tensor * ggml_group_norm_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor * a,
-            int n_groups);
+            int n_groups,
+            float eps);
 
     // a - x
     // b - dy
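
A minimal caller-side sketch of the updated API (the wrapper name and group count are hypothetical; the context and input tensor come from elsewhere). Existing call sites now have to pass eps explicitly, and 1e-6f reproduces the previously hardcoded behavior:

    #include "ggml.h"

    // Hypothetical wrapper: group norm with the old default epsilon.
    static struct ggml_tensor * group_norm_default(
            struct ggml_context * ctx,
            struct ggml_tensor * x) {
        const int   n_groups = 32;    // assumption: a typical group count
        const float eps      = 1e-6f; // the old hardcoded value, now explicit
        return ggml_group_norm(ctx, x, n_groups, eps);
    }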
ggml/src/ggml-cuda/norm.cu CHANGED
@@ -142,8 +142,7 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
     }
 }
 
-static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
-    static const float eps = 1e-6f;
+static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
     if (group_size < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
         group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
@@ -196,8 +195,12 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     int num_groups = dst->op_params[0];
+
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
     int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], group_size, ggml_nelements(src0), stream);
+    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
 }
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml/src/ggml-metal.m CHANGED
@@ -2236,10 +2236,8 @@ static enum ggml_status ggml_metal_graph_compute(
                 GGML_ASSERT(ne00 % 4 == 0);
                 GGML_ASSERT(ggml_is_contiguous(src0));
 
-                //float eps;
-                //memcpy(&eps, dst->op_params, sizeof(float));
-
-                const float eps = 1e-6f; // TODO: temporarily hardcoded
+                float eps;
+                memcpy(&eps, dst->op_params + 1, sizeof(float));
 
                 const int32_t n_groups = ((int32_t *) dst->op_params)[0];
 
ggml/src/ggml-sycl/norm.cpp CHANGED
@@ -225,9 +225,8 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
 }
 
 static void group_norm_f32_sycl(const float* x, float* dst,
-    const int num_groups, const int group_size,
+    const int num_groups, const float eps, const int group_size,
     const int ne_elements, queue_ptr stream, int device) {
-    static const float eps = 1e-6f;
     if (group_size < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
         stream->submit([&](sycl::handler& cgh) {
@@ -343,8 +342,12 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     int num_groups = dst->op_params[0];
+
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
     int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
+    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
 
     (void)src1;
     (void)dst;
ggml/src/ggml.c CHANGED
@@ -5377,6 +5377,7 @@ static struct ggml_tensor * ggml_group_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         int n_groups,
+        float eps,
         bool inplace) {
 
     bool is_node = false;
@@ -5387,7 +5388,8 @@ static struct ggml_tensor * ggml_group_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op_params[0] = n_groups;
+    ggml_set_op_params_i32(result, 0, n_groups);
+    ggml_set_op_params_f32(result, 1, eps);
 
     result->op = GGML_OP_GROUP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5399,15 +5401,17 @@ static struct ggml_tensor * ggml_group_norm_impl(
 struct ggml_tensor * ggml_group_norm(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int n_groups) {
-    return ggml_group_norm_impl(ctx, a, n_groups, false);
+        int n_groups,
+        float eps) {
+    return ggml_group_norm_impl(ctx, a, n_groups, eps, false);
 }
 
 struct ggml_tensor * ggml_group_norm_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int n_groups) {
-    return ggml_group_norm_impl(ctx, a, n_groups, true);
+        int n_groups,
+        float eps) {
+    return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
 }
 
 // ggml_mul_mat
@@ -12098,10 +12102,11 @@ static void ggml_compute_forward_group_norm_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    const float eps = 1e-6f; // TODO: make this a parameter
-
     // TODO: optimize
 
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
     int n_channels = src0->ne[2];
     int n_groups = dst->op_params[0];
     int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
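
For reference, the packing scheme the hunks above agree on: op_params is an array of int32_t on the tensor, n_groups goes into slot 0 as a plain integer, and eps is bit-copied into slot 1 (ggml_set_op_params_f32 on the write side), so every backend recovers it with memcpy rather than a pointer cast, avoiding strict-aliasing trouble. A self-contained sketch of that round trip, assuming only that sizeof(float) == sizeof(int32_t):

    #include <stdint.h>
    #include <string.h>
    #include <stdio.h>

    int main(void) {
        int32_t op_params[2] = {0};

        // write side, as in ggml_group_norm_impl
        const int32_t n_groups = 32;
        const float   eps      = 1e-6f;
        op_params[0] = n_groups;
        memcpy(&op_params[1], &eps, sizeof(float));

        // read side, as in each backend's compute path
        float eps_out;
        memcpy(&eps_out, op_params + 1, sizeof(float));
        printf("n_groups = %d, eps = %g\n", op_params[0], (double) eps_out);
        return 0;
    }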