JohannesGaessler committed
Commit e1e87a3 · 1 Parent(s): ba483f7

feat: ref. cross entropy, add CUDA, fix grad test (ggml/929)

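The forward pass now evaluates log(softmax(logits)) directly instead of computing softmax and rescaling it by eps to avoid log(0). A sketch of the identity this relies on, with x one row of logits and m = \max_j x_j:

    \log \mathrm{softmax}(x)_i = (x_i - m) - \log \sum_j e^{x_j - m}

so the per-row loss against labels y is -\sum_i y_i \big[(x_i - m) - \log \sum_j e^{x_j - m}\big], averaged over all rows. This is what both the CPU path (ggml_vec_log_soft_max_f32) and the new CUDA kernel below compute.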
ggml/include/ggml-backend.h CHANGED
@@ -63,6 +63,7 @@ extern "C" {
     GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
     GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
 
+    // "offset" refers to the offset of the tensor data for setting/getting data
     GGML_API GGML_CALL void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
     GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
 
ggml/src/ggml-cuda.cu CHANGED
@@ -9,8 +9,10 @@
 #include "ggml-cuda/binbcast.cuh"
 #include "ggml-cuda/clamp.cuh"
 #include "ggml-cuda/concat.cuh"
+#include "ggml-cuda/conv-transpose-1d.cuh"
 #include "ggml-cuda/convert.cuh"
 #include "ggml-cuda/cpy.cuh"
+#include "ggml-cuda/cross-entropy-loss.cuh"
 #include "ggml-cuda/diagmask.cuh"
 #include "ggml-cuda/dmmv.cuh"
 #include "ggml-cuda/fattn.cuh"
@@ -29,7 +31,6 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/conv-transpose-1d.cuh"
 
 #include <algorithm>
 #include <array>
@@ -2312,6 +2313,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_FLASH_ATTN_EXT:
             ggml_cuda_flash_attn_ext(ctx, dst);
             break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            ggml_cuda_cross_entropy_loss(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2619,6 +2623,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
             assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
             for (int j = 0; j < GGML_MAX_SRC; j++) {
                 if (node->src[j] != nullptr) {
+                    assert(node->src[j]->buffer);
                     assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
                 }
             }
@@ -2902,6 +2907,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             }
             return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
                 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            return true;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
ggml/src/ggml-cuda/cross-entropy-loss.cu ADDED
@@ -0,0 +1,106 @@
+#include "common.cuh"
+#include "cross-entropy-loss.cuh"
+#include "sumrows.cuh"
+
+#include <cmath>
+#include <cstdint>
+
+static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) {
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+    const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE;
+
+    const int ne_tmp = WARP_SIZE*nclasses;
+
+    extern __shared__ float tmp_all[];
+    float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
+    float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;
+
+    // Each warp first loads ne_tmp logits/labels into shared memory:
+    for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
+        const int ig = i0*nclasses + i; // ig == i global
+
+        tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f;
+        tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f;
+    }
+
+    // Each thread in the warp then calculates the cross entropy loss for a single row.
+    // TODO: pad in order to avoid shared memory bank conflicts.
+
+    // Find maximum for softmax:
+    float max = -INFINITY;
+    for (int i = 0; i < nclasses; ++i) {
+        max = fmaxf(max, tmp_logits[lane_id*nclasses + i]);
+    }
+
+    // Calculate log(softmax(logits)) which is just logits - max:
+    float sum = 0.0f;
+    for (int i = 0; i < nclasses; ++i) {
+        float val = tmp_logits[lane_id*nclasses + i] - max;
+        sum += expf(val);
+        tmp_logits[lane_id*nclasses + i] = val;
+    }
+    sum = logf(sum);
+
+    // log(exp(logits - max) / sum) = (logits - max) - log(sum)
+    float loss = 0.0f;
+    for (int i = 0; i < nclasses; ++i) {
+        loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i];
+    }
+    loss = -warp_reduce_sum(loss) / (float)k;
+
+    __syncthreads();
+
+    if (lane_id == 0) {
+        tmp_all[warp_id] = loss;
+    }
+
+    __syncthreads();
+
+    if (warp_id != 0) {
+        return;
+    }
+
+    loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f;
+    loss = warp_reduce_sum(loss);
+
+    if (lane_id != 0) {
+        return;
+    }
+
+    dst[blockIdx.x] = loss;
+}
+
+void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    const int64_t ne00  = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       * dst_d  = (float       *) dst->data;
+
+    ggml_cuda_pool & pool = ctx.pool();
+    cudaStream_t stream = ctx.stream();
+
+    const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
+    const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
+    const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float);
+
+    ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
+
+    cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+
+    // Combine results from individual blocks:
+    sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream);
+}
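For reference, a minimal host-side sketch of the per-row computation the kernel performs: plain C, illustration only, and the helper name cross_entropy_loss_ref is not part of this commit.

#include <math.h>

// Hypothetical reference (not in the commit): mean over k rows of
// -sum_i labels[i] * log_softmax(logits)[i], using the same
// max-subtraction trick as cross_entropy_loss_f32 above.
static float cross_entropy_loss_ref(const float * logits, const float * labels, int nclasses, int k) {
    float loss = 0.0f;
    for (int row = 0; row < k; ++row) {
        const float * x = logits + row*nclasses;
        const float * y = labels + row*nclasses;

        float max = -INFINITY; // row-wise maximum for numerical stability
        for (int i = 0; i < nclasses; ++i) {
            max = fmaxf(max, x[i]);
        }

        float sum = 0.0f; // sum_j exp(x_j - max)
        for (int i = 0; i < nclasses; ++i) {
            sum += expf(x[i] - max);
        }
        const float log_sum = logf(sum);

        for (int i = 0; i < nclasses; ++i) { // accumulate y_i * log_softmax(x)_i
            loss += ((x[i] - max) - log_sum) * y[i];
        }
    }
    return -loss / (float)k; // negate and average over rows
}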
ggml/src/ggml-cuda/cross-entropy-loss.cuh ADDED
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256
+
+void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/sumrows.cu CHANGED
@@ -16,7 +16,7 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc
     }
 }
 
-static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     const dim3 block_dims(WARP_SIZE, 1, 1);
     const dim3 block_nums(nrows, 1, 1);
     k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
@@ -32,7 +32,6 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
ggml/src/ggml-cuda/sumrows.cuh CHANGED
@@ -1,3 +1,5 @@
 #include "common.cuh"
 
+void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
+
 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml.c CHANGED
@@ -2671,6 +2671,19 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x,
     return sum;
 }
 
+static ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
+    // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i)
+
+    int i = 0;
+    ggml_float sum = 0;
+    for (; i < n; ++i) {
+        float val = x[i] - max;
+        y[i] = val;
+        sum += (ggml_float)expf(val);
+    }
+    return sum = (ggml_float)logf(sum);
+}
+
 inline static float ggml_silu_backward_f32(float x, float dy) {
     const float s = 1.0f/(1.0f + expf(-x));
     return dy*s*(1.0f + x*(1.0f - s));
@@ -17022,8 +17035,6 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
     }
     ggml_barrier(params->shared);
 
-    const double eps = 1e-9;
-
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -17044,20 +17055,15 @@
         }
 #endif
 
-        // soft_max
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, s0);
-        ggml_float sum = ggml_vec_soft_max_f32(nc, st, s0, max);
-        assert(sum > 0.0);
-        sum = (1.0 - eps) / sum;
+        ggml_float sum = ggml_vec_log_soft_max_f32(nc, st, s0, max);
+        assert(sum >= 0.0);
 
-        // avoid log(0) by rescaling from [0..1] to [eps..1]
-        ggml_vec_scale_f32(nc, st, sum);
-        ggml_vec_add1_f32(nc, st, st, eps);
-        ggml_vec_log_f32(nc, st, st);
+        ggml_vec_add1_f32(nc, st, st, -sum);
         ggml_vec_mul_f32(nc, st, st, s1);
 
-        float st_sum = 0;
+        float st_sum = 0.0f;
         ggml_vec_sum_f32(nc, &st_sum, st);
         sums[ith] += st_sum;
 
@@ -17114,8 +17120,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
     const int64_t ith = params->ith;
     const int64_t nth = params->nth;
 
-    const double eps = 1e-9;
-
     // TODO: handle transposed/permuted matrices
     const int64_t nc = src0->ne[0];
     const int64_t nr = ggml_nrows(src0);
@@ -17147,11 +17151,9 @@
         ggml_vec_max_f32(nc, &max, s0);
         ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
         assert(sum > 0.0);
-        sum = (1.0 - eps) / sum;
+        ggml_vec_scale_f32(nc, ds0, 1.0/sum);
 
         // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
-        ggml_vec_scale_f32(nc, ds0, sum);
-        ggml_vec_add1_f32(nc, ds0, ds0, eps);
         ggml_vec_sub_f32(nc, ds0, ds0, s1);
         ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr);
 
@@ -20287,6 +20289,7 @@ static enum ggml_opt_result ggml_opt_adam(
         ggml_opt_callback callback,
         void * callback_data) {
     GGML_ASSERT(ggml_is_scalar(f));
+    GGML_ASSERT(f->type == GGML_TYPE_F32);
 
     // these will store the parameters we want to optimize
     struct ggml_tensor * ps[GGML_MAX_PARAMS];
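The backward pass likewise drops the eps rescaling: ggml_vec_soft_max_f32 fills ds0 with exp(x_i - max), the new ggml_vec_scale_f32(nc, ds0, 1.0/sum) normalizes it to softmax, and the labels are then subtracted. As a reference for why this is the gradient (a standard result, assuming each row of labels sums to 1):

    \frac{\partial}{\partial x_i}\Big(-\sum_j y_j \log \mathrm{softmax}(x)_j\Big) = \mathrm{softmax}(x)_i - y_i

which the function then scales by d[0] / nr, the incoming gradient divided by the number of rows.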