JohannesGaessler committed
Commit 4751b2f · 1 Parent: 5b36f0b

tests: add gradient tests for all backends (ggml/932)

* tests: add gradient checking to test-backend-ops

* remove old comment

* reorder includes

* adjust SIN/COS parameters

* add documentation, use supports_op if possible
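
For context, the core idea behind gradient checking is to compare the analytical gradient from the backward pass against a central finite difference of the forward pass. The helper below is a minimal, hypothetical sketch of that idea (names, tolerances, and the eval_loss callback are made up for illustration; this is not the actual test-backend-ops code):

    // Hypothetical finite-difference gradient check.
    // eval_loss stands in for "run the forward graph on the backend and read back the scalar loss";
    // analytical_grad[i] stands in for the value produced by the backward pass.
    #include <cmath>
    #include <cstdio>
    #include <functional>
    #include <vector>

    static bool check_gradient(
            std::vector<float>       & params,
            const std::vector<float> & analytical_grad,
            const std::function<double(const std::vector<float> &)> & eval_loss,
            float eps = 1e-2f, double rtol = 5e-2) {
        for (size_t i = 0; i < params.size(); ++i) {
            const float orig = params[i];

            params[i] = orig + eps;
            const double loss_plus  = eval_loss(params);
            params[i] = orig - eps;
            const double loss_minus = eval_loss(params);
            params[i] = orig;

            const double numerical = (loss_plus - loss_minus) / (2.0*eps); // central difference
            const double err = std::fabs(numerical - analytical_grad[i]);
            if (err > rtol*(std::fabs(numerical) + std::fabs(analytical_grad[i])) + 1e-4) {
                printf("gradient mismatch at index %zu: numerical=%g analytical=%g\n",
                       i, numerical, analytical_grad[i]);
                return false;
            }
        }
        return true;
    }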

ggml/include/ggml.h CHANGED
@@ -1234,7 +1234,7 @@ extern "C" {
            size_t                nb1,
            size_t                nb2,
            size_t                nb3,
-           size_t                offset);
+           size_t                offset); // in bytes

    // b -> view(a,offset,nb1,nb2,3), return view(a)
    GGML_API struct ggml_tensor * ggml_set_inplace(
@@ -1244,19 +1244,19 @@ extern "C" {
            size_t                nb1,
            size_t                nb2,
            size_t                nb3,
-           size_t                offset);
+           size_t                offset); // in bytes

    GGML_API struct ggml_tensor * ggml_set_1d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
-           size_t                offset);
+           size_t                offset); // in bytes

    GGML_API struct ggml_tensor * ggml_set_1d_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
-           size_t                offset);
+           size_t                offset); // in bytes

    // b -> view(a,offset,nb1,nb2,3), return modified a
    GGML_API struct ggml_tensor * ggml_set_2d(
@@ -1264,7 +1264,7 @@ extern "C" {
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            size_t                nb1,
-           size_t                offset);
+           size_t                offset); // in bytes

    // b -> view(a,offset,nb1,nb2,3), return view(a)
    GGML_API struct ggml_tensor * ggml_set_2d_inplace(
@@ -1272,7 +1272,7 @@ extern "C" {
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            size_t                nb1,
-           size_t                offset);
+           size_t                offset); // in bytes

    // a -> b, return view(b)
    GGML_API struct ggml_tensor * ggml_cpy(
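
A quick note on the comments added above: the offset argument of the ggml_set_* family is a byte offset into a, not an element index. A minimal usage sketch (tensor shapes are made up; ggml_set_1d, ggml_new_tensor_1d, and ggml_element_size are the real API):

    // Write the 3-element vector b into a starting at element index 5.
    // The offset is in bytes, hence the scaling by the element size.
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3);
    struct ggml_tensor * r = ggml_set_1d(ctx, a, b, 5*ggml_element_size(a)); // not just "5"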
ggml/src/ggml-backend.c CHANGED
@@ -825,6 +825,10 @@ GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const
                op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
        case GGML_OP_MUL_MAT:
            return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
+       case GGML_OP_ROPE_BACK:
+           return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
+       case GGML_OP_IM2COL_BACK:
+           return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
        default:
            return true;
    }
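
The commit message mentions using supports_op where possible; the natural pattern is for the test harness to query each backend before trying to run (and differentiate) an op it cannot execute. A hypothetical sketch of that check, using the real ggml_backend_supports_op / ggml_get_first_tensor / ggml_get_next_tensor API (the surrounding test loop and the "skip" handling are assumptions):

    // Hypothetical: skip a test graph if the backend cannot execute one of its ops.
    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
        if (!ggml_backend_supports_op(backend, t)) {
            printf("  %s: not supported by backend, skipping\n", ggml_op_desc(t));
            return true; // counted as skipped, not failed (assumption)
        }
    }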
ggml/src/ggml-cuda.cu CHANGED
@@ -27,6 +27,7 @@
#include "ggml-cuda/rope.cuh"
#include "ggml-cuda/scale.cuh"
#include "ggml-cuda/softmax.cuh"
+#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
@@ -2180,6 +2181,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
            ggml_cuda_dup(ctx, dst);
            break;
        case GGML_OP_ADD:
+       case GGML_OP_ADD1: // TODO: more efficient implementation
            ggml_cuda_op_add(ctx, dst);
            break;
        case GGML_OP_SUB:
@@ -2196,6 +2198,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
            break;
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(dst)) {
+               case GGML_UNARY_OP_NEG:
+                   ggml_cuda_op_neg(ctx, dst);
+                   break;
                case GGML_UNARY_OP_GELU:
                    ggml_cuda_op_gelu(ctx, dst);
                    break;
@@ -2304,6 +2309,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
        case GGML_OP_POOL_2D:
            ggml_cuda_op_pool2d(ctx, dst);
            break;
+       case GGML_OP_SUM:
+           ggml_cuda_op_sum(ctx, dst);
+           break;
        case GGML_OP_SUM_ROWS:
            ggml_cuda_op_sum_rows(ctx, dst);
            break;
@@ -2741,6 +2749,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
    switch (op->op) {
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(op)) {
+               case GGML_UNARY_OP_NEG:
                case GGML_UNARY_OP_GELU:
                case GGML_UNARY_OP_SILU:
                case GGML_UNARY_OP_RELU:
@@ -2867,6 +2876,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
        case GGML_OP_TRANSPOSE:
        case GGML_OP_NORM:
        case GGML_OP_ADD:
+       case GGML_OP_ADD1:
        case GGML_OP_SUB:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
@@ -2886,7 +2896,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
        case GGML_OP_ROPE:
            return ggml_is_contiguous(op->src[0]);
        case GGML_OP_IM2COL:
+           return op->src[0]->type == GGML_TYPE_F16;
        case GGML_OP_POOL_2D:
+       case GGML_OP_SUM:
        case GGML_OP_SUM_ROWS:
        case GGML_OP_ARGSORT:
        case GGML_OP_ACC:
ggml/src/ggml-cuda/cross-entropy-loss.cu CHANGED
@@ -1,6 +1,6 @@
#include "common.cuh"
#include "cross-entropy-loss.cuh"
-#include "sumrows.cuh"
+#include "sum.cuh"

#include <cmath>
#include <cstdint>
@@ -102,5 +102,5 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
    cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);

    // Combine results from individual blocks:
-   sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream);
+   sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
}
ggml/src/ggml-cuda/sum.cu ADDED
@@ -0,0 +1,41 @@
+#include "sumrows.cuh"
+#include "sum.cuh"
+
+#include <cstdint>
+
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+
+void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+    size_t tmp_size = 0;
+    DeviceReduce::Sum(nullptr, tmp_size, x, dst, ne, stream);
+    ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+    DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, x, dst, ne, stream);
+#else
+    // Use (inefficient) sum_rows implementation as a fallback.
+    // For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14.
+    sum_rows_f32_cuda(x, dst, ne, 1, stream);
+    GGML_UNUSED(pool);
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+}
+
+void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+
+    const int64_t ne = ggml_nelements(src0);
+
+    ggml_cuda_pool & pool = ctx.pool();
+    cudaStream_t stream = ctx.stream();
+
+    sum_f32_cuda(pool, src0_d, dst_d, ne, stream);
+}
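
For reference, sum_f32_cuda above uses the standard two-call CUB pattern: the first DeviceReduce::Sum call with a null temporary buffer only reports the required scratch size, and the second call performs the reduction. A standalone sketch of the same pattern without the ggml pool (plain cudaMalloc/cudaFree, illustrative only):

    // Standalone illustration of cub::DeviceReduce::Sum; not ggml code.
    #include <cub/cub.cuh>

    static void device_sum_f32(const float * d_in, float * d_out, int n, cudaStream_t stream) {
        void * d_tmp     = nullptr;
        size_t tmp_bytes = 0;
        // 1st call: d_tmp == nullptr, only the required temporary storage size is written to tmp_bytes.
        cub::DeviceReduce::Sum(d_tmp, tmp_bytes, d_in, d_out, n, stream);
        cudaMalloc(&d_tmp, tmp_bytes);
        // 2nd call: perform the actual reduction on the given stream.
        cub::DeviceReduce::Sum(d_tmp, tmp_bytes, d_in, d_out, n, stream);
        cudaStreamSynchronize(stream); // ensure the reduction finished before freeing the scratch buffer
        cudaFree(d_tmp);
    }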
ggml/src/ggml-cuda/sum.cuh ADDED
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream);
+
+void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/unary.cu CHANGED
@@ -1,5 +1,15 @@
#include "unary.cuh"

+static __global__ void neg_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = -x[i];
+}
+
static __global__ void gelu_f32(const float * x, float * dst, const int k) {
    const float GELU_COEF_A = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -119,6 +129,11 @@ static __global__ void cos_f32(const float * x, float * dst, const int k) {
    dst[i] = cosf(x[i]);
}

+static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
+    neg_f32<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
    gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -184,6 +199,20 @@ static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
    cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

+void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
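
For completeness, the new NEG path is reached from user code through the existing ggml_neg API: it creates a GGML_OP_UNARY node tagged with GGML_UNARY_OP_NEG, which ggml_cuda_compute_forward now dispatches to ggml_cuda_op_neg. A minimal sketch (the context ctx and the tensor size are placeholders):

    // Hypothetical graph snippet exercising the new CUDA kernel.
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    struct ggml_tensor * y = ggml_neg(ctx, x); // element-wise -x, lowered to GGML_UNARY_OP_NEG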
ggml/src/ggml-cuda/unary.cuh CHANGED
@@ -1,5 +1,6 @@
#include "common.cuh"

+#define CUDA_NEG_BLOCK_SIZE 256
#define CUDA_GELU_BLOCK_SIZE 256
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_TANH_BLOCK_SIZE 256
@@ -12,6 +13,8 @@
#define CUDA_SIN_BLOCK_SIZE 256
#define CUDA_COS_BLOCK_SIZE 256

+void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml.c CHANGED
@@ -5131,6 +5131,7 @@ struct ggml_tensor * ggml_concat(
    bool is_node = false;

    if (a->grad || b->grad) {
+       GGML_ABORT("fatal error"); // TODO: implement
        is_node = true;
    }

@@ -5252,6 +5253,7 @@ struct ggml_tensor * ggml_leaky_relu(
    bool is_node = false;

    if (!inplace && (a->grad)) {
+       GGML_ABORT("fatal error"); // TODO: not implemented
        is_node = true;
    }

@@ -5677,6 +5679,7 @@ static struct ggml_tensor * ggml_set_impl(
    // make a view of the destination
    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

+   GGML_ASSERT(offset < (size_t)(1 << 30));
    int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
    ggml_set_op_params(result, params, sizeof(params));

@@ -6634,14 +6637,12 @@ struct ggml_tensor * ggml_rope_back(
    GGML_ASSERT(ggml_is_vector(b));
    GGML_ASSERT(b->type == GGML_TYPE_I32);
    GGML_ASSERT(a->ne[2] == b->ne[0]);
-   GGML_ASSERT(c == NULL && "freq factors not implemented yet");
-
-   GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");

    bool is_node = false;

    if (a->grad) {
-       is_node = false; // TODO: implement backward
+       GGML_ASSERT(false && "backwards pass not implemented");
+       is_node = false;
    }

    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
@@ -6659,6 +6660,7 @@ struct ggml_tensor * ggml_rope_back(
    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
    result->src[0] = a;
    result->src[1] = b;
+   result->src[2] = c;

    return result;
}
@@ -7212,6 +7214,11 @@ struct ggml_tensor * ggml_argsort(
        enum ggml_sort_order order) {
    bool is_node = false;

+   if (a->grad) {
+       GGML_ABORT("fatal error"); // TODO: not implemented
+       is_node = true;
+   }
+
    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);

    ggml_set_op_params_i32(result, 0, (int32_t) order);
@@ -10745,9 +10752,6 @@ static void ggml_compute_forward_sum_f32(
        return;
    }

-   assert(ggml_is_scalar(dst));
-
-
    assert(ggml_is_scalar(dst));
    assert(src0->nb[0] == sizeof(float));

@@ -18000,14 +18004,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                if (src0->grad || src1->grad) {
                    GGML_ASSERT(src0->type == tensor->type);
                    GGML_ASSERT(tensor->grad->type == tensor->type);
-                   GGML_ASSERT(tensor->grad->type == src1->grad->type);
+                   GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type);

                    tensor_grad_view = ggml_view_4d(ctx,
-                       tensor->grad,
-                       src1->grad->ne[0],
-                       src1->grad->ne[1],
-                       src1->grad->ne[2],
-                       src1->grad->ne[3],
+                       tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
                        nb1, nb2, nb3, offset);
                }

@@ -18076,9 +18076,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor

                    memcpy(&offset, tensor->op_params, sizeof(offset));

-                   size_t nb1 = tensor->nb[1];
-                   size_t nb2 = tensor->nb[2];
-                   size_t nb3 = tensor->nb[3];
+                   size_t nb1 = tensor->nb[1];
+                   size_t nb2 = tensor->nb[2];
+                   size_t nb3 = tensor->nb[3];

                    if (src0->type != src0->grad->type) {
                        // gradient is typically F32, but src0 could be other type