Ouadie EL FAROUKI commited on
Commit
ae75124
·
1 Parent(s): 8dade62

Minor arithmetic improvement to mmvq wrapper kernel (llama/7172)

Browse files
Files changed (1) hide show
  1. ggml-sycl.cpp +11 -9
ggml-sycl.cpp CHANGED
@@ -8330,24 +8330,26 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
8330
  const int blocks_per_row = ncols / qk;
8331
  const int blocks_per_warp = vdr * WARP_SIZE / qi;
8332
 
8333
- // partial sum for each thread
 
 
8334
  float tmp = 0.0f;
8335
 
8336
  const block_q_t * x = (const block_q_t *) vx;
8337
  const block_q8_1 * y = (const block_q8_1 *) vy;
8338
 
8339
- for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
8340
  i += blocks_per_warp) {
8341
- const int ibx = row*blocks_per_row + i; // x block index
8342
 
8343
- const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
8344
 
8345
- const int iqs =
8346
- vdr *
8347
- (item_ct1.get_local_id(2) %
8348
- (qi / vdr)); // x block quant index when casting the quants to int
8349
 
8350
- tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
8351
  }
8352
 
8353
  // sum up partial sums and write back result
 
8330
  const int blocks_per_row = ncols / qk;
8331
  const int blocks_per_warp = vdr * WARP_SIZE / qi;
8332
 
8333
+ const int qi_vdr = (qi / vdr); // N_threads processing 1 qk block
8334
+
8335
+ // partial sum for each thread
8336
  float tmp = 0.0f;
8337
 
8338
  const block_q_t * x = (const block_q_t *) vx;
8339
  const block_q8_1 * y = (const block_q8_1 *) vy;
8340
 
8341
+ for (int i = item_ct1.get_local_id(2) / qi_vdr; i < blocks_per_row;
8342
  i += blocks_per_warp) {
8343
+ const int ibx = row * blocks_per_row + i; // x block index
8344
 
8345
+ const int iby = i * (qk / QK8_1); // y block index that aligns with ibx
8346
 
8347
+ const int iqs =
8348
+ vdr *
8349
+ (item_ct1.get_local_id(2) -
8350
+ i * qi_vdr); // x block quant index when casting the quants to int
8351
 
8352
+ tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
8353
  }
8354
 
8355
  // sum up partial sums and write back result