Spaces:
Running
Running
ggml : implement ggml_compute_forward_dup_f16() special cases
Browse files
ggml.c
CHANGED
|
@@ -3178,22 +3178,96 @@ void ggml_compute_forward_dup_f16(
|
|
| 3178 |
return;
|
| 3179 |
}
|
| 3180 |
|
| 3181 |
-
|
| 3182 |
-
|
| 3183 |
-
|
| 3184 |
-
|
| 3185 |
|
| 3186 |
-
|
| 3187 |
-
|
| 3188 |
-
|
| 3189 |
-
|
| 3190 |
|
| 3191 |
if (ggml_is_contiguous(src0) && src0->type == dst->type) {
|
| 3192 |
memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
|
| 3193 |
return;
|
| 3194 |
}
|
| 3195 |
|
| 3196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3197 |
}
|
| 3198 |
|
| 3199 |
void ggml_compute_forward_dup_f32(
|
|
|
|
| 3178 |
return;
|
| 3179 |
}
|
| 3180 |
|
| 3181 |
+
const int ne00 = src0->ne[0];
|
| 3182 |
+
const int ne01 = src0->ne[1];
|
| 3183 |
+
const int ne02 = src0->ne[2];
|
| 3184 |
+
const int ne03 = src0->ne[3];
|
| 3185 |
|
| 3186 |
+
const size_t nb00 = src0->nb[0];
|
| 3187 |
+
const size_t nb01 = src0->nb[1];
|
| 3188 |
+
const size_t nb02 = src0->nb[2];
|
| 3189 |
+
const size_t nb03 = src0->nb[3];
|
| 3190 |
|
| 3191 |
if (ggml_is_contiguous(src0) && src0->type == dst->type) {
|
| 3192 |
memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
|
| 3193 |
return;
|
| 3194 |
}
|
| 3195 |
|
| 3196 |
+
if (src0->nb[0] == sizeof(ggml_fp16_t)) {
|
| 3197 |
+
if (dst->type == GGML_TYPE_F16) {
|
| 3198 |
+
int id = 0;
|
| 3199 |
+
const size_t rs = ne00*nb00;
|
| 3200 |
+
|
| 3201 |
+
for (int i03 = 0; i03 < ne03; i03++) {
|
| 3202 |
+
for (int i02 = 0; i02 < ne02; i02++) {
|
| 3203 |
+
for (int i01 = 0; i01 < ne01; i01++) {
|
| 3204 |
+
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
| 3205 |
+
char * dst_ptr = (char *) dst->data + id*rs;
|
| 3206 |
+
|
| 3207 |
+
memcpy(dst_ptr, src0_ptr, rs);
|
| 3208 |
+
|
| 3209 |
+
id++;
|
| 3210 |
+
}
|
| 3211 |
+
}
|
| 3212 |
+
}
|
| 3213 |
+
} else if (dst->type == GGML_TYPE_F32) {
|
| 3214 |
+
int id = 0;
|
| 3215 |
+
float * dst_ptr = (float *) dst->data;
|
| 3216 |
+
|
| 3217 |
+
for (int i03 = 0; i03 < ne03; i03++) {
|
| 3218 |
+
for (int i02 = 0; i02 < ne02; i02++) {
|
| 3219 |
+
for (int i01 = 0; i01 < ne01; i01++) {
|
| 3220 |
+
for (int i00 = 0; i00 < ne00; i00++) {
|
| 3221 |
+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
| 3222 |
+
|
| 3223 |
+
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
|
| 3224 |
+
id++;
|
| 3225 |
+
}
|
| 3226 |
+
}
|
| 3227 |
+
}
|
| 3228 |
+
}
|
| 3229 |
+
} else {
|
| 3230 |
+
GGML_ASSERT(false); // TODO: implement
|
| 3231 |
+
}
|
| 3232 |
+
} else {
|
| 3233 |
+
//printf("%s: this is not optimal - fix me\n", __func__);
|
| 3234 |
+
|
| 3235 |
+
if (dst->type == GGML_TYPE_F32) {
|
| 3236 |
+
int id = 0;
|
| 3237 |
+
float * dst_ptr = (float *) dst->data;
|
| 3238 |
+
|
| 3239 |
+
for (int i03 = 0; i03 < ne03; i03++) {
|
| 3240 |
+
for (int i02 = 0; i02 < ne02; i02++) {
|
| 3241 |
+
for (int i01 = 0; i01 < ne01; i01++) {
|
| 3242 |
+
for (int i00 = 0; i00 < ne00; i00++) {
|
| 3243 |
+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
| 3244 |
+
|
| 3245 |
+
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
|
| 3246 |
+
id++;
|
| 3247 |
+
}
|
| 3248 |
+
}
|
| 3249 |
+
}
|
| 3250 |
+
}
|
| 3251 |
+
} else if (dst->type == GGML_TYPE_F16) {
|
| 3252 |
+
int id = 0;
|
| 3253 |
+
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
| 3254 |
+
|
| 3255 |
+
for (int i03 = 0; i03 < ne03; i03++) {
|
| 3256 |
+
for (int i02 = 0; i02 < ne02; i02++) {
|
| 3257 |
+
for (int i01 = 0; i01 < ne01; i01++) {
|
| 3258 |
+
for (int i00 = 0; i00 < ne00; i00++) {
|
| 3259 |
+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
| 3260 |
+
|
| 3261 |
+
dst_ptr[id] = *src0_ptr;
|
| 3262 |
+
id++;
|
| 3263 |
+
}
|
| 3264 |
+
}
|
| 3265 |
+
}
|
| 3266 |
+
}
|
| 3267 |
+
} else {
|
| 3268 |
+
GGML_ASSERT(false); // TODO: implement
|
| 3269 |
+
}
|
| 3270 |
+
}
|
| 3271 |
}
|
| 3272 |
|
| 3273 |
void ggml_compute_forward_dup_f32(
|