am17an committed on
Commit
014494c
·
1 Parent(s): effd61f

CUDA: add bf16 and i32 to getrows (llama/14529)

Browse files
ggml/src/ggml-cuda/getrows.cu CHANGED
@@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
168
  get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
169
  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
170
  break;
 
 
 
 
171
  case GGML_TYPE_BF16:
172
  get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
173
  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
@@ -210,6 +214,10 @@ void get_rows_cuda(
210
  ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
211
  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
212
  break;
 
 
 
 
213
  case GGML_TYPE_F16:
214
  ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
215
  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
 
168
  get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
169
  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
170
  break;
171
+ case GGML_TYPE_I32:
172
+ get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,
173
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
174
+ break;
175
  case GGML_TYPE_BF16:
176
  get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
177
  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
 
214
  ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
215
  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
216
  break;
217
+ case GGML_TYPE_I32:
218
+ ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,
219
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
220
+ break;
221
  case GGML_TYPE_F16:
222
  ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
223
  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -3200,6 +3200,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3200
  switch (op->src[0]->type) {
3201
  case GGML_TYPE_F16:
3202
  case GGML_TYPE_F32:
 
 
3203
  case GGML_TYPE_Q4_0:
3204
  case GGML_TYPE_Q4_1:
3205
  case GGML_TYPE_Q5_0:
 
3200
  switch (op->src[0]->type) {
3201
  case GGML_TYPE_F16:
3202
  case GGML_TYPE_F32:
3203
+ case GGML_TYPE_BF16:
3204
+ case GGML_TYPE_I32:
3205
  case GGML_TYPE_Q4_0:
3206
  case GGML_TYPE_Q4_1:
3207
  case GGML_TYPE_Q5_0: