ggerganov commited on
Commit
b3b8141
·
unverified ·
1 Parent(s): 068424c

ggml : implement ggml_compute_forward_dup_f16() special cases

Browse files
Files changed (1) hide show
  1. ggml.c +83 -9
ggml.c CHANGED
@@ -3178,22 +3178,96 @@ void ggml_compute_forward_dup_f16(
3178
  return;
3179
  }
3180
 
3181
- //const int ne00 = src0->ne[0];
3182
- //const int ne01 = src0->ne[1];
3183
- //const int ne02 = src0->ne[2];
3184
- //const int ne03 = src0->ne[3];
3185
 
3186
- //const size_t nb00 = src0->nb[0];
3187
- //const size_t nb01 = src0->nb[1];
3188
- //const size_t nb02 = src0->nb[2];
3189
- //const size_t nb03 = src0->nb[3];
3190
 
3191
  if (ggml_is_contiguous(src0) && src0->type == dst->type) {
3192
  memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
3193
  return;
3194
  }
3195
 
3196
- GGML_ASSERT(false); // TODO: implement
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3197
  }
3198
 
3199
  void ggml_compute_forward_dup_f32(
 
3178
  return;
3179
  }
3180
 
3181
+ const int ne00 = src0->ne[0];
3182
+ const int ne01 = src0->ne[1];
3183
+ const int ne02 = src0->ne[2];
3184
+ const int ne03 = src0->ne[3];
3185
 
3186
+ const size_t nb00 = src0->nb[0];
3187
+ const size_t nb01 = src0->nb[1];
3188
+ const size_t nb02 = src0->nb[2];
3189
+ const size_t nb03 = src0->nb[3];
3190
 
3191
  if (ggml_is_contiguous(src0) && src0->type == dst->type) {
3192
  memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
3193
  return;
3194
  }
3195
 
3196
+ if (src0->nb[0] == sizeof(ggml_fp16_t)) {
3197
+ if (dst->type == GGML_TYPE_F16) {
3198
+ int id = 0;
3199
+ const size_t rs = ne00*nb00;
3200
+
3201
+ for (int i03 = 0; i03 < ne03; i03++) {
3202
+ for (int i02 = 0; i02 < ne02; i02++) {
3203
+ for (int i01 = 0; i01 < ne01; i01++) {
3204
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
3205
+ char * dst_ptr = (char *) dst->data + id*rs;
3206
+
3207
+ memcpy(dst_ptr, src0_ptr, rs);
3208
+
3209
+ id++;
3210
+ }
3211
+ }
3212
+ }
3213
+ } else if (dst->type == GGML_TYPE_F32) {
3214
+ int id = 0;
3215
+ float * dst_ptr = (float *) dst->data;
3216
+
3217
+ for (int i03 = 0; i03 < ne03; i03++) {
3218
+ for (int i02 = 0; i02 < ne02; i02++) {
3219
+ for (int i01 = 0; i01 < ne01; i01++) {
3220
+ for (int i00 = 0; i00 < ne00; i00++) {
3221
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
3222
+
3223
+ dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
3224
+ id++;
3225
+ }
3226
+ }
3227
+ }
3228
+ }
3229
+ } else {
3230
+ GGML_ASSERT(false); // TODO: implement
3231
+ }
3232
+ } else {
3233
+ //printf("%s: this is not optimal - fix me\n", __func__);
3234
+
3235
+ if (dst->type == GGML_TYPE_F32) {
3236
+ int id = 0;
3237
+ float * dst_ptr = (float *) dst->data;
3238
+
3239
+ for (int i03 = 0; i03 < ne03; i03++) {
3240
+ for (int i02 = 0; i02 < ne02; i02++) {
3241
+ for (int i01 = 0; i01 < ne01; i01++) {
3242
+ for (int i00 = 0; i00 < ne00; i00++) {
3243
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
3244
+
3245
+ dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
3246
+ id++;
3247
+ }
3248
+ }
3249
+ }
3250
+ }
3251
+ } else if (dst->type == GGML_TYPE_F16) {
3252
+ int id = 0;
3253
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
3254
+
3255
+ for (int i03 = 0; i03 < ne03; i03++) {
3256
+ for (int i02 = 0; i02 < ne02; i02++) {
3257
+ for (int i01 = 0; i01 < ne01; i01++) {
3258
+ for (int i00 = 0; i00 < ne00; i00++) {
3259
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
3260
+
3261
+ dst_ptr[id] = *src0_ptr;
3262
+ id++;
3263
+ }
3264
+ }
3265
+ }
3266
+ }
3267
+ } else {
3268
+ GGML_ASSERT(false); // TODO: implement
3269
+ }
3270
+ }
3271
  }
3272
 
3273
  void ggml_compute_forward_dup_f32(