Spaces:
Running
Running
slaren
committed on
Commit
·
2314334
1
Parent(s):
5a33963
cuda : update supports_op for matrix multiplication (llama/8245)
Browse files
- ggml/src/ggml-cuda.cu +30 -17
ggml/src/ggml-cuda.cu
CHANGED
|
@@ -2715,27 +2715,40 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|
| 2715 |
case GGML_OP_MUL_MAT:
|
| 2716 |
case GGML_OP_MUL_MAT_ID:
|
| 2717 |
{
|
| 2718 |
-
struct ggml_tensor * a;
|
| 2719 |
-
struct ggml_tensor * b;
|
| 2720 |
if (op->op == GGML_OP_MUL_MAT) {
|
| 2721 |
-
|
| 2722 |
-
|
| 2723 |
-
} else {
|
| 2724 |
-
a = op->src[2];
|
| 2725 |
-
b = op->src[1];
|
| 2726 |
-
}
|
| 2727 |
-
if (a->ne[3] != b->ne[3]) {
|
| 2728 |
-
return false;
|
| 2729 |
-
}
|
| 2730 |
-
ggml_type a_type = a->type;
|
| 2731 |
-
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
|
| 2732 |
-
a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S ||
|
| 2733 |
-
a_type == GGML_TYPE_IQ1_M || a_type == GGML_TYPE_IQ2_S || a_type == GGML_TYPE_IQ4_XS) {
|
| 2734 |
-
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
|
| 2735 |
return false;
|
| 2736 |
}
|
| 2737 |
}
|
| 2738 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2739 |
} break;
|
| 2740 |
case GGML_OP_GET_ROWS:
|
| 2741 |
{
|
|
|
|
| 2715 |
case GGML_OP_MUL_MAT:
|
| 2716 |
case GGML_OP_MUL_MAT_ID:
|
| 2717 |
{
|
| 2718 |
+
struct ggml_tensor * a = op->src[0];
|
|
|
|
| 2719 |
if (op->op == GGML_OP_MUL_MAT) {
|
| 2720 |
+
struct ggml_tensor * b = op->src[1];
|
| 2721 |
+
if (a->ne[3] != b->ne[3]) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2722 |
return false;
|
| 2723 |
}
|
| 2724 |
}
|
| 2725 |
+
switch (a->type) {
|
| 2726 |
+
case GGML_TYPE_F32:
|
| 2727 |
+
case GGML_TYPE_F16:
|
| 2728 |
+
case GGML_TYPE_Q4_0:
|
| 2729 |
+
case GGML_TYPE_Q4_1:
|
| 2730 |
+
case GGML_TYPE_Q5_0:
|
| 2731 |
+
case GGML_TYPE_Q5_1:
|
| 2732 |
+
case GGML_TYPE_Q8_0:
|
| 2733 |
+
case GGML_TYPE_Q2_K:
|
| 2734 |
+
case GGML_TYPE_Q3_K:
|
| 2735 |
+
case GGML_TYPE_Q4_K:
|
| 2736 |
+
case GGML_TYPE_Q5_K:
|
| 2737 |
+
case GGML_TYPE_Q6_K:
|
| 2738 |
+
case GGML_TYPE_Q8_K:
|
| 2739 |
+
case GGML_TYPE_IQ1_M:
|
| 2740 |
+
case GGML_TYPE_IQ1_S:
|
| 2741 |
+
case GGML_TYPE_IQ2_S:
|
| 2742 |
+
case GGML_TYPE_IQ2_XS:
|
| 2743 |
+
case GGML_TYPE_IQ2_XXS:
|
| 2744 |
+
case GGML_TYPE_IQ3_S:
|
| 2745 |
+
case GGML_TYPE_IQ3_XXS:
|
| 2746 |
+
case GGML_TYPE_IQ4_NL:
|
| 2747 |
+
case GGML_TYPE_IQ4_XS:
|
| 2748 |
+
return true;
|
| 2749 |
+
default:
|
| 2750 |
+
return false;
|
| 2751 |
+
}
|
| 2752 |
} break;
|
| 2753 |
case GGML_OP_GET_ROWS:
|
| 2754 |
{
|