slaren commited on
Commit
2314334
·
1 Parent(s): 5a33963

cuda : update supports_op for matrix multiplication (llama/8245)

Browse files
Files changed (1) hide show
  1. ggml/src/ggml-cuda.cu +30 -17
ggml/src/ggml-cuda.cu CHANGED
@@ -2715,27 +2715,40 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
2715
  case GGML_OP_MUL_MAT:
2716
  case GGML_OP_MUL_MAT_ID:
2717
  {
2718
- struct ggml_tensor * a;
2719
- struct ggml_tensor * b;
2720
  if (op->op == GGML_OP_MUL_MAT) {
2721
- a = op->src[0];
2722
- b = op->src[1];
2723
- } else {
2724
- a = op->src[2];
2725
- b = op->src[1];
2726
- }
2727
- if (a->ne[3] != b->ne[3]) {
2728
- return false;
2729
- }
2730
- ggml_type a_type = a->type;
2731
- if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
2732
- a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S ||
2733
- a_type == GGML_TYPE_IQ1_M || a_type == GGML_TYPE_IQ2_S || a_type == GGML_TYPE_IQ4_XS) {
2734
- if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
2735
  return false;
2736
  }
2737
  }
2738
- return true;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2739
  } break;
2740
  case GGML_OP_GET_ROWS:
2741
  {
 
2715
  case GGML_OP_MUL_MAT:
2716
  case GGML_OP_MUL_MAT_ID:
2717
  {
2718
+ struct ggml_tensor * a = op->src[0];
 
2719
  if (op->op == GGML_OP_MUL_MAT) {
2720
+ struct ggml_tensor * b = op->src[1];
2721
+ if (a->ne[3] != b->ne[3]) {
 
 
 
 
 
 
 
 
 
 
 
 
2722
  return false;
2723
  }
2724
  }
2725
+ switch (a->type) {
2726
+ case GGML_TYPE_F32:
2727
+ case GGML_TYPE_F16:
2728
+ case GGML_TYPE_Q4_0:
2729
+ case GGML_TYPE_Q4_1:
2730
+ case GGML_TYPE_Q5_0:
2731
+ case GGML_TYPE_Q5_1:
2732
+ case GGML_TYPE_Q8_0:
2733
+ case GGML_TYPE_Q2_K:
2734
+ case GGML_TYPE_Q3_K:
2735
+ case GGML_TYPE_Q4_K:
2736
+ case GGML_TYPE_Q5_K:
2737
+ case GGML_TYPE_Q6_K:
2738
+ case GGML_TYPE_Q8_K:
2739
+ case GGML_TYPE_IQ1_M:
2740
+ case GGML_TYPE_IQ1_S:
2741
+ case GGML_TYPE_IQ2_S:
2742
+ case GGML_TYPE_IQ2_XS:
2743
+ case GGML_TYPE_IQ2_XXS:
2744
+ case GGML_TYPE_IQ3_S:
2745
+ case GGML_TYPE_IQ3_XXS:
2746
+ case GGML_TYPE_IQ4_NL:
2747
+ case GGML_TYPE_IQ4_XS:
2748
+ return true;
2749
+ default:
2750
+ return false;
2751
+ }
2752
  } break;
2753
  case GGML_OP_GET_ROWS:
2754
  {