Spaces:
Sleeping
Sleeping
CUDA: add bf16 and i32 to getrows (llama/14529)
Browse files
ggml/src/ggml-cuda/getrows.cu
CHANGED
|
@@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
|
|
| 168 |
get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
|
| 169 |
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
| 170 |
break;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
case GGML_TYPE_BF16:
|
| 172 |
get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
|
| 173 |
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
@@ -210,6 +214,10 @@ void get_rows_cuda(
|
|
| 210 |
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
|
| 211 |
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
| 212 |
break;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
case GGML_TYPE_F16:
|
| 214 |
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
|
| 215 |
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
|
|
| 168 |
get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
|
| 169 |
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
| 170 |
break;
|
| 171 |
+
case GGML_TYPE_I32:
|
| 172 |
+
get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,
|
| 173 |
+
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
| 174 |
+
break;
|
| 175 |
case GGML_TYPE_BF16:
|
| 176 |
get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
|
| 177 |
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
|
|
| 214 |
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
|
| 215 |
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
| 216 |
break;
|
| 217 |
+
case GGML_TYPE_I32:
|
| 218 |
+
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,
|
| 219 |
+
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
| 220 |
+
break;
|
| 221 |
case GGML_TYPE_F16:
|
| 222 |
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
|
| 223 |
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
ggml/src/ggml-cuda/ggml-cuda.cu
CHANGED
|
@@ -3200,6 +3200,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
| 3200 |
switch (op->src[0]->type) {
|
| 3201 |
case GGML_TYPE_F16:
|
| 3202 |
case GGML_TYPE_F32:
|
|
|
|
|
|
|
| 3203 |
case GGML_TYPE_Q4_0:
|
| 3204 |
case GGML_TYPE_Q4_1:
|
| 3205 |
case GGML_TYPE_Q5_0:
|
|
|
|
| 3200 |
switch (op->src[0]->type) {
|
| 3201 |
case GGML_TYPE_F16:
|
| 3202 |
case GGML_TYPE_F32:
|
| 3203 |
+
case GGML_TYPE_BF16:
|
| 3204 |
+
case GGML_TYPE_I32:
|
| 3205 |
case GGML_TYPE_Q4_0:
|
| 3206 |
case GGML_TYPE_Q4_1:
|
| 3207 |
case GGML_TYPE_Q5_0:
|