Commit 6201c74 by Nekotekina · 1 Parent(s): aca04d5

cuda: add q8_0->f32 cpy operation (llama/9571)

llama: enable K-shift for quantized KV cache
It will fail on unsupported backends or quant types.
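
For context: q8_0 stores data in blocks of QK8_0 = 32 signed 8-bit quants that share a single fp16 scale, so a q8_0->f32 copy is a per-block dequantization in which each output float is quant * scale. That is exactly what the new cpy_blck_q8_0_f32 device function below computes. As a minimal host-side sketch of the same arithmetic (not part of the commit; block_q8_0_ref is a simplified stand-in for ggml's block_q8_0, holding the scale as a plain float instead of fp16):

#include <cstdint>

constexpr int QK8_0 = 32;                // q8_0 block size used by ggml

struct block_q8_0_ref {                  // simplified stand-in for ggml's block_q8_0
    float  d;                            // per-block scale (an fp16 ggml_half in ggml)
    int8_t qs[QK8_0];                    // 32 signed 8-bit quants
};

// Dequantize n_blocks contiguous q8_0 blocks; y must hold n_blocks * QK8_0 floats.
void dequantize_q8_0_ref(const block_q8_0_ref * x, float * y, int n_blocks) {
    for (int b = 0; b < n_blocks; ++b) {
        for (int j = 0; j < QK8_0; ++j) {
            y[b*QK8_0 + j] = x[b].qs[j] * x[b].d;   // value = quant * scale
        }
    }
}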

Files changed (2)
  1. ggml/src/ggml-cuda.cu +3 -0
  2. ggml/src/ggml-cuda/cpy.cu +51 -0
ggml/src/ggml-cuda.cu CHANGED
@@ -2899,6 +2899,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
                     return true;
                 }
ggml/src/ggml-cuda/cpy.cu CHANGED
@@ -81,6 +81,17 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
     }
 }
 
+static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
+    const block_q8_0 * xi = (const block_q8_0 *) cxi;
+    float * dsti = (float *) cdsti;
+
+    const float d = (float)xi->d;
+
+    for (int j = 0; j < QK8_0; j++) {
+        dsti[j] = xi->qs[j] * d;
+    }
+}
+
 static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
     const float * xi = (const float *) cxi;
     block_q4_0 * dsti = (block_q4_0 *) cdsti;
@@ -288,6 +299,32 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
+                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                 const int nb12, const int nb13) {
+    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+    if (i >= ne) {
+        return;
+    }
+
+    const int i03 = i/(ne00 * ne01 * ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+    const int i13 = i/(ne10 * ne11 * ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+    cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
 static void ggml_cpy_f16_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -329,6 +366,16 @@ static void ggml_cpy_f32_q8_0_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q8_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q4_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -437,6 +484,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
@@ -471,6 +520,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
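
A note on the launch configuration above: ggml_cpy_q8_0_f32_cuda launches ne one-thread CUDA blocks, and with blockDim.x = 1 the kernel computes i = blockIdx.x * QK8_0, so only the first ne/QK8_0 blocks (rounding up) pass the i >= ne guard; each of those dequantizes one source q8_0 block into QK8_0 destination floats. For a fully contiguous 1-D tensor the indexing reduces to the following host-side restatement (illustrative only, not code from the commit; it reuses block_q8_0_ref and QK8_0 from the sketch near the top of this page):

void cpy_q8_0_f32_contiguous_ref(const block_q8_0_ref * src, float * dst, int ne) {
    for (int b = 0; b < ne; ++b) {                    // one iteration per launched CUDA block
        const int i = b * QK8_0;                      // first destination element for this block
        if (i >= ne) {
            continue;                                 // blocks past ne/QK8_0 exit at the guard
        }
        const block_q8_0_ref * x = &src[i / QK8_0];   // x_offset = (i00/qk)*nb00 for contiguous data
        for (int j = 0; j < QK8_0; ++j) {
            dst[i + j] = x->qs[j] * x->d;             // dst_offset advances in whole floats
        }
    }
}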