ggerganov committed
Commit f8e8d34 · unverified · 1 parent: 3c39d4b

sync : llama.cpp (ggml/0)

Files changed (8)
  1. examples/common-ggml.cpp +2 -0
  2. ggml-cuda.cu +97 -1
  3. ggml-metal.m +35 -0
  4. ggml-metal.metal +214 -1
  5. ggml-quants.c +233 -1
  6. ggml-quants.h +13 -0
  7. ggml.c +30 -0
  8. ggml.h +2 -0
examples/common-ggml.cpp CHANGED
@@ -66,6 +66,7 @@ bool ggml_common_quantize_0(
         case GGML_FTYPE_MOSTLY_IQ2_XS:
         case GGML_FTYPE_MOSTLY_IQ3_XXS:
         case GGML_FTYPE_MOSTLY_IQ1_S:
+        case GGML_FTYPE_MOSTLY_IQ4_NL:
             {
                 fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
                 return false;
@@ -199,6 +200,7 @@ bool ggml_common_quantize_0(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_COUNT:
             {
                 fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype));
ggml-cuda.cu CHANGED
@@ -528,6 +528,15 @@ typedef struct {
 } block_iq1_s;
 static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
 
+#define QK4_NL 32
+#define QR4_NL 2
+#define QI4_NL (QK4_NL / (4*QR4_NL))
+typedef struct {
+    half    d;
+    uint8_t qs[QK4_NL/2];
+} block_iq4_nl;
+static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
+
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -1987,6 +1996,26 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
 
 }
 
+static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i   = blockIdx.x;
+    const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
+
+    const int tid = threadIdx.x;
+    const int il  = tid/8; // 0...3
+    const int ib  = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t * q4 = x[ib].qs + 4*il;
+    const float d = (float)x[ib].d;
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
+    }
+
+}
 
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
@@ -4732,6 +4761,56 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
 #endif
 }
 
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values,
+        int & val1, int & val2) {
+
+    uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
+    aux32 = q4 & 0x0f0f0f0f;
+    uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
+    uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
+    val1 = v1 | (v2 << 16);
+    aux32 = (q4 >> 4) & 0x0f0f0f0f;
+    v1 = values[q8[0]] | (values[q8[1]] << 8);
+    v2 = values[q8[2]] | (values[q8[3]] << 8);
+    val2 = v1 | (v2 << 16);
+}
+#endif
+
+static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_iq4_nl * bq = (const block_iq4_nl *) vbq;
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
+    const int32_t  * q8 = (const int32_t  *)bq8_1->qs + iqs;
+
+    const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
+
+    int v1, v2;
+    int sumi1 = 0, sumi2 = 0;
+    for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
+        const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16);
+        get_int_from_table_16(aux, values, v1, v2);
+        sumi1 = __dp4a(v1, q8[l+0], sumi1);
+        sumi2 = __dp4a(v2, q8[l+4], sumi2);
+    }
+
+#else
+    const uint8_t * q4 = bq->qs + 4*iqs;
+    const int8_t  * q8 = bq8_1->qs + 4*iqs;
+
+    int sumi1 = 0, sumi2 = 0;
+    for (int l = 0; l < 4*VDR_Q4_0_Q8_1_MMVQ; ++l) {
+        sumi1 += q8[l+ 0] * kvalues_iq4nl[q4[l] & 0xf];
+        sumi2 += q8[l+16] * kvalues_iq4nl[q4[l] >>  4];
+    }
+#endif
+    const float d = (float)bq->d * __low2float(bq8_1->ds);
+    return d * (sumi1 + sumi2);
+}
+
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
           allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(
@@ -6777,6 +6856,12 @@ static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, c
     dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@@ -6818,6 +6903,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq3_xxs_cuda;
         case GGML_TYPE_IQ1_S:
             return dequantize_row_iq1_s_cuda;
+        case GGML_TYPE_IQ4_NL:
+            return dequantize_row_iq4_nl_cuda;
         case GGML_TYPE_F32:
             return convert_unary_cuda<float>;
         default:
@@ -6855,6 +6942,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq3_xxs_cuda;
         case GGML_TYPE_IQ1_S:
             return dequantize_row_iq1_s_cuda;
+        case GGML_TYPE_IQ4_NL:
+            return dequantize_row_iq4_nl_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cuda<half>;
         default:
@@ -8599,6 +8688,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
             return max_compute_capability >= CC_RDNA2 ? 128 : 64;
         default:
             GGML_ASSERT(false);
@@ -8623,6 +8713,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
            return max_compute_capability >= CC_VOLTA ? 128 : 64;
        case GGML_TYPE_Q6_K:
            return 64;
@@ -8724,6 +8815,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
            mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
+        case GGML_TYPE_IQ4_NL:
+            mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
+                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
        default:
            GGML_ASSERT(false);
            break;
@@ -11446,7 +11541,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                return false;
            }
            ggml_type a_type = a->type;
-            if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ1_S) {
+            if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
+                a_type == GGML_TYPE_IQ1_S   || a_type == GGML_TYPE_IQ4_NL) {
                if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
                    return false;
                }
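
For reference, the nibble expansion performed by get_int_from_table_16 can be checked on the host with a small C sketch (a hypothetical standalone test, not part of this commit): each of the eight 4-bit indices packed into one 32-bit word is replaced by its codebook value, yielding the two 32-bit words of four int8 values each that feed __dp4a.

// Host-side reference for get_int_from_table_16 (hypothetical test).
#include <assert.h>
#include <stdint.h>
#include <string.h>

static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

// Expand 8 packed 4-bit indices: the low nibble of each byte goes to val1,
// the high nibble to val2, each as four int8 codebook values.
static void get_int_from_table_16_ref(uint32_t q4, int32_t * val1, int32_t * val2) {
    int8_t v[8];
    for (int i = 0; i < 4; ++i) {
        v[i]     = kvalues_iq4nl[(q4 >> (8*i    )) & 0xf];
        v[i + 4] = kvalues_iq4nl[(q4 >> (8*i + 4)) & 0xf];
    }
    memcpy(val1, v,     4);
    memcpy(val2, v + 4, 4);
}

int main(void) {
    int32_t v1, v2;
    get_int_from_table_16_ref(0x10f02a5cu, &v1, &v2);
    // Byte 0 of v1 is kvalues_iq4nl[0xc] = 53 (low nibble of 0x5c).
    assert((int8_t)(v1 & 0xff) == 53);
    return 0;
}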
ggml-metal.m CHANGED
@@ -62,6 +62,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -85,6 +86,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
     //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -104,6 +106,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -120,6 +123,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -136,6 +140,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
     GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -448,6 +453,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,   get_rows_iq2_xs,   true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,  get_rows_iq3_xxs,  true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,    get_rows_iq1_s,    true);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,   get_rows_iq4_nl,   true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,      get_rows_i32,      true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,          rms_norm,          ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,        group_norm,        ctx->support_simdgroup_reduction);
@@ -471,6 +477,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,  mul_mv_iq2_xs_f32,  ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,   mul_mv_iq1_s_f32,   ctx->support_simdgroup_reduction);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,  mul_mv_iq4_nl_f32,  ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,  mul_mv_id_f32_f32,  ctx->support_simdgroup_reduction);
        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16,  ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,  mul_mv_id_f16_f32,  ctx->support_simdgroup_reduction);
@@ -490,6 +497,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,  mul_mv_id_iq2_xs_f32,  ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,   mul_mv_id_iq1_s_f32,   ctx->support_simdgroup_reduction);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,  mul_mv_id_iq4_nl_f32,  ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,        mul_mm_f32_f32,        ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,        mul_mm_f16_f32,        ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,       mul_mm_q4_0_f32,       ctx->support_simdgroup_mm);
@@ -506,6 +514,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,  mul_mm_iq2_xs_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,   mul_mm_iq1_s_f32,   ctx->support_simdgroup_mm);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,  mul_mm_iq4_nl_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,  mul_mm_id_f32_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,  mul_mm_id_f16_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
@@ -522,6 +531,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,  mul_mm_id_iq2_xs_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,   mul_mm_id_iq1_s_f32,   ctx->support_simdgroup_mm);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,  mul_mm_id_iq4_nl_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,              rope_f32,              true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,              rope_f16,              true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32,             alibi_f32,             true);
@@ -1338,6 +1348,7 @@ static bool ggml_metal_graph_compute(
                    case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
                    case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
                    case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32  ].pipeline; break;
+                   case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
                    default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
                }
 
@@ -1478,6 +1489,12 @@ static bool ggml_metal_graph_compute(
                        nth1 = 16;
                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
                    } break;
+                case GGML_TYPE_IQ4_NL:
+                    {
+                        nth0 = 4;
+                        nth1 = 16;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
+                    } break;
                default:
                    {
                        GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1525,6 +1542,11 @@ static bool ggml_metal_graph_compute(
                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
            }
+            else if (src0t == GGML_TYPE_IQ4_NL) {
+                const int mem_size = 32*sizeof(float);
+                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+            }
            else if (src0t == GGML_TYPE_Q4_K) {
                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
            }
@@ -1619,6 +1641,7 @@ static bool ggml_metal_graph_compute(
                    case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
                    case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
                    case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32  ].pipeline; break;
+                   case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
                    default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
                }
 
@@ -1762,6 +1785,12 @@ static bool ggml_metal_graph_compute(
                        nth1 = 16;
                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
                    } break;
+                case GGML_TYPE_IQ4_NL:
+                    {
+                        nth0 = 4;
+                        nth1 = 16;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
+                    } break;
                default:
                    {
                        GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1825,6 +1854,11 @@ static bool ggml_metal_graph_compute(
                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
            }
+            else if (src2t == GGML_TYPE_IQ4_NL) {
+                const int mem_size = 32*sizeof(float);
+                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+            }
            else if (src2t == GGML_TYPE_Q4_K) {
                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
            }
@@ -1867,6 +1901,7 @@ static bool ggml_metal_graph_compute(
                case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
                case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
                case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S  ].pipeline; break;
+               case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
                case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
                default: GGML_ASSERT(false && "not implemented");
            }
ggml-metal.metal CHANGED
@@ -2531,6 +2531,12 @@ typedef struct {
     uint8_t  scales[QK_K/16];
 } block_iq1_s;
 
+// Non-linear quants
+#define QK4_NL 32
+typedef struct {
+    half    d;
+    uint8_t qs[QK4_NL/2];
+} block_iq4_nl;
 
 //====================================== dot products =========================
 
@@ -4384,7 +4390,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
     const uint i13 = im/ne12;
 
     const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
-
     device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
     device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
 
@@ -4447,6 +4452,103 @@ void kernel_mul_mv_iq1_s_f32_impl(
     }
 }
 
+constexpr constant static float kvalues_iq4nl_f[16] = {
+    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
+};
+
+void kernel_mul_mv_iq4_nl_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup float  * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const int nb = ne00/QK4_NL;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * 2 + sgitg) * 2;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
+    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+
+    const int ix = tiisg/2;  // 0...15
+    const int it = tiisg%2;  // 0 or 1
+
+    shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float4 yl[4];
+    float sumf[2]={0.f}, all_sum;
+
+    device const float * yb = y + ix * QK4_NL + it * 8;
+
+    uint32_t aux32[2];
+    thread const uint8_t * q8 = (thread const uint8_t *)aux32;
+
+    float4 qf1, qf2;
+
+    for (int ib = ix; ib < nb; ib += 16) {
+
+        device const float4 * y4 = (device const float4 *)yb;
+        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+
+        for (int row = 0; row < 2; ++row) {
+
+            device const block_iq4_nl & xb = x[row*nb + ib];
+            device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
+
+            float4 acc1 = {0.f}, acc2 = {0.f};
+
+            aux32[0] = q4[0] | (q4[1] << 16);
+            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+            aux32[0] &= 0x0f0f0f0f;
+            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+            acc1 += yl[0] * qf1;
+            acc2 += yl[1] * qf2;
+
+            aux32[0] = q4[2] | (q4[3] << 16);
+            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+            aux32[0] &= 0x0f0f0f0f;
+            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+            acc1 += yl[2] * qf1;
+            acc2 += yl[3] * qf2;
+
+            acc1 += acc2;
+
+            sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
+
+        }
+
+        yb += 16 * QK4_NL;
+    }
+
+    for (int row = 0; row < 2; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+        }
+    }
+}
+
 [[host_name("kernel_mul_mv_iq1_s_f32")]]
 kernel void kernel_mul_mv_iq1_s_f32(
     device const void * src0,
@@ -4475,6 +4577,34 @@ kernel void kernel_mul_mv_iq1_s_f32(
     kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
 }
 
+[[host_name("kernel_mul_mv_iq4_nl_f32")]]
+kernel void kernel_mul_mv_iq4_nl_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup float  * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
 
 //============================= templates and their specializations =============================
 
@@ -4838,6 +4968,21 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 &
     }
 }
 
+template <typename type4x4>
+void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
+    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
+    const float d = xb->d;
+    uint32_t aux32;
+    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+    for (int i = 0; i < 4; ++i) {
+        aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
+        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+    }
+}
+
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows(
     device const void * src0,
@@ -5381,6 +5526,7 @@ template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_r
 template [[host_name("kernel_get_rows_iq2_xs")]]  kernel get_rows_t kernel_get_rows<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
 template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 template [[host_name("kernel_get_rows_iq1_s")]]   kernel get_rows_t kernel_get_rows<block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_t kernel_get_rows<block_iq4_nl,  2,     dequantize_iq4_nl>;
 
 //
 // matrix-matrix multiplication
@@ -5421,6 +5567,7 @@ template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_m
 template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
 template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_nl,  2,     dequantize_iq4_nl>;
 
 //
 // indirect matrix-matrix multiplication
@@ -5473,6 +5620,7 @@ template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel
 template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
 template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
 template [[host_name("kernel_mul_mm_id_iq1_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl,  2,     dequantize_iq4_nl>;
 
 //
 // matrix-vector multiplication
@@ -6503,3 +6651,68 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
         tiisg,
         sgitg);
 }
+
+[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
+kernel void kernel_mul_mv_id_iq4_nl_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         float * dst,
+        constant    uint64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant    uint64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        threadgroup float * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_iq4_nl_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        dst + bid*ne0,
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        shared_values,
+        tgpig,
+        tiisg,
+        sgitg);
+}
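
The templated dequantize_iq4_nl above decodes a block as two 16-value tiles selected by il. A small host-side C sketch (a hypothetical check, not part of this commit) shows that the il = 0/1 tiles reproduce exactly the row layout of dequantize_row_iq4_nl in ggml-quants.c: low nibbles fill the first 16 outputs, high nibbles the last 16.

// Hypothetical host-side check of the tile vs. row decode layouts.
#include <assert.h>
#include <stdint.h>

static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

// Mirrors the Metal dequantize_iq4_nl: decode 16 values selected by il.
static void dequant_tile(const uint8_t * qs, int il, float d, float * out) {
    const uint16_t * q4 = (const uint16_t *) qs;
    for (int i = 0; i < 4; ++i) {
        uint32_t aux32 = (((uint32_t)q4[2*i] | ((uint32_t)q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
        for (int j = 0; j < 4; ++j) {
            out[4*i + j] = d * kvalues_iq4nl[(aux32 >> 8*j) & 0xf];
        }
    }
}

int main(void) {
    uint8_t qs[16];
    for (int j = 0; j < 16; ++j) qs[j] = (uint8_t)(j | ((15 - j) << 4));
    float tile[32], row[32];
    dequant_tile(qs, 0, 1.0f, tile);       // il = 0 -> values 0..15
    dequant_tile(qs, 1, 1.0f, tile + 16);  // il = 1 -> values 16..31
    // Row layout used by dequantize_row_iq4_nl.
    for (int j = 0; j < 16; ++j) {
        row[j]      = (float)kvalues_iq4nl[qs[j] & 0xf];
        row[j + 16] = (float)kvalues_iq4nl[qs[j] >> 4];
    }
    for (int j = 0; j < 32; ++j) assert(tile[j] == row[j]);
    return 0;
}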
ggml-quants.c CHANGED
@@ -3754,6 +3754,26 @@ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, in
     }
 }
 
+static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int k) {
+    assert(k % QK4_NL == 0);
+    const int nb = k / QK4_NL;
+
+    for (int i = 0; i < nb; i++) {
+
+        const uint8_t * qs = x[i].qs;
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        for (int j = 0; j < QK4_NL/2; ++j) {
+            y[j+       0] = d * kvalues_iq4nl[qs[j] & 0xf];
+            y[j+QK4_NL/2] = d * kvalues_iq4nl[qs[j] >>  4];
+        }
+        y  += QK4_NL;
+        qs += QK4_NL/2;
+    }
+}
+
 //===================================== Q8_K ==============================================
 
 void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@@ -9148,7 +9168,6 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
 #endif
 }
 
-// TODO
 void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
@@ -9452,7 +9471,100 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
     *s = sumf;
 
 #endif
+}
+
+void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK4_NL == 0);
+    static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
+
+    const block_iq4_nl * restrict x = vx;
+    const block_q8_0   * restrict y = vy;
+
+    const int nb = n / QK4_NL;
+
+#if defined __ARM_NEON
+    const int8x16_t values = vld1q_s8(kvalues_iq4nl);
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+    uint8x16x2_t q4bits;
+    int8x16x4_t q4b;
+    int8x16x4_t q8b;
+    int32x4_t prod_1, prod_2;
 
+    float sumf = 0;
+
+    for (int ib = 0; ib < nb; ib += 2) {
+
+        q4bits.val[0] = vld1q_u8(x[ib+0].qs);
+        q4bits.val[1] = vld1q_u8(x[ib+1].qs);
+        q8b.val[0]    = vld1q_s8(y[ib+0].qs);
+        q8b.val[1]    = vld1q_s8(y[ib+0].qs + 16);
+        q8b.val[2]    = vld1q_s8(y[ib+1].qs);
+        q8b.val[3]    = vld1q_s8(y[ib+1].qs + 16);
+
+        q4b.val[0] = vqtbl1q_s8(values, vandq_u8(q4bits.val[0], m4b));
+        q4b.val[1] = vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+        q4b.val[2] = vqtbl1q_s8(values, vandq_u8(q4bits.val[1], m4b));
+        q4b.val[3] = vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+        prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+        prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+        sumf += (float)x[ib+0].d * (float)y[ib+0].d * vaddvq_s32(prod_1) + (float)x[ib+1].d * (float)y[ib+1].d * vaddvq_s32(prod_2);
+
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+    const __m256i mone = _mm256_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (int ib = 0; ib < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[1].qs);
+        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[0].qs);
+        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[1].qs);
+        const __m256i q4b_1 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+                                               _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+        const __m256i q4b_2 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+                                               _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
+        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
+        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
+                _mm256_cvtepi32_ps(p_1), accum1);
+        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
+                _mm256_cvtepi32_ps(p_2), accum2);
+
+        y += 2;
+        x += 2;
+    }
+
+    *s = hsum_float_8(_mm256_add_ps(accum1, accum2));
+
+#else
+    float sumf = 0;
+    for (int ib = 0; ib < nb; ++ib) {
+        const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
+        int sumi1 = 0, sumi2 = 0;
+        for (int j = 0; j < QK4_NL/2; ++j) {
+            sumi1 += y[ib].qs[j+       0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >>  4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+#endif
 }
 
 // ================================ IQ2 quantization =============================================
@@ -10729,3 +10841,123 @@ size_t quantize_iq1_s(const float * src, void * dst, int nrow, int n_per_row, in
     }
     return nrow * nblock * sizeof(block_iq1_s);
 }
+
+// ============================ 4-bit non-linear quants
+
+static inline int best_index_int8(int n, const int8_t * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
+static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RESTRICT x,
+        ggml_fp16_t * dh, uint8_t * q4,
+        float * weight, uint8_t * L,
+        const int8_t * values,
+        const float * quant_weights) {
+
+    const int ntry = 7;
+
+    float sigma2 = 0;
+    for (int j = 0; j < QK4_NL; ++j) sigma2 += x[j]*x[j];
+    sigma2 *= 2.f/QK4_NL;
+
+    const int nb = QK4_NL/block_size;
+
+    memset(q4, 0, QK4_NL/2);
+    for (int ib = 0; ib < nb; ++ib) {
+        dh[ib] = GGML_FP32_TO_FP16(0.f);
+        const float * xb = x + ib*block_size;
+        if (quant_weights) {
+            const float * qw = quant_weights + ib*block_size;
+            for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+        } else {
+            for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
+        }
+        float amax = 0, max = 0;
+        for (int j = 0; j < block_size; ++j) {
+            float ax = fabsf(xb[j]);
+            if (ax > amax) {
+                amax = ax; max = xb[j];
+            }
+        }
+        if (!amax) {
+            continue;
+        }
+        float d = -max/values[0];
+        float id = 1/d;
+        float sumqx = 0, sumq2 = 0;
+        for (int j = 0; j < block_size; ++j) {
+            float al = id*xb[j];
+            int l = best_index_int8(16, values, al);
+            float q = values[l];
+            float w = weight[j];
+            sumqx += w*q*xb[j];
+            sumq2 += w*q*q;
+        }
+        float best_id = id;
+        d = sumqx/sumq2;
+        float best = d*sumqx;
+        for (int itry = -ntry; itry <= ntry; ++itry) {
+            id = (itry + values[0])/max;
+            sumqx = sumq2 = 0;
+            for (int j = 0; j < block_size; ++j) {
+                float al = id*xb[j];
+                int l = best_index_int8(16, values, al);
+                float q = values[l];
+                float w = weight[j];
+                sumqx += w*q*xb[j];
+                sumq2 += w*q*q;
+            }
+            if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                d = sumqx/sumq2; best = d * sumqx;
+                best_id = id;
+            }
+        }
+        dh[ib] = GGML_FP32_TO_FP16(d);
+        for (int j = 0; j < block_size; ++j) {
+            L[ib*block_size + j] = best_index_int8(16, values, best_id*xb[j]);
+        }
+    }
+    for (int i = 0; i < QK4_NL/32; ++i) {
+        for (int j = 0; j < 16; ++j) {
+            q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);
+        }
+    }
+}
+
+size_t quantize_iq4_nl(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    GGML_ASSERT(n_per_row%QK4_NL == 0);
+    int nblock = n_per_row/QK4_NL;
+    char * qrow = (char *)dst;
+    uint8_t L[QK4_NL];
+    float weight[32];
+    for (int row = 0; row < nrow; ++row) {
+        block_iq4_nl * iq4 = (block_iq4_nl *)qrow;
+        for (int ibl = 0; ibl < nblock; ++ibl) {
+            const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
+            quantize_row_iq4_nl_impl(32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, weight, L, kvalues_iq4nl, qw);
+        }
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq4_nl);
+    }
+    return nrow * nblock * sizeof(block_iq4_nl);
+}
+
+void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK4_NL == 0);
+    block_iq4_nl * restrict y = vy;
+    quantize_row_iq4_nl_reference(x, y, k);
+}
+
+void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) {
+    assert(k % QK4_NL == 0);
+    quantize_iq4_nl(x, y, 1, k, NULL, NULL);
+}
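
A minimal round-trip sketch for the functions above (a hypothetical test harness, assuming it is compiled against ggml-quants.h and linked with ggml; not part of this commit): quantize one row to IQ4_NL with quantize_iq4_nl, decode it back with dequantize_row_iq4_nl, and report the RMS error.

// Hypothetical IQ4_NL round-trip test; build with the ggml sources, link -lm.
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "ggml-quants.h"

int main(void) {
    const int n = 256;                                   // must be a multiple of QK4_NL (32)
    float src[256], dst[256];
    for (int i = 0; i < 256; ++i) src[i] = sinf(0.1f*i); // synthetic data

    const int nblock = n / QK4_NL;
    block_iq4_nl * q = malloc(nblock * sizeof(block_iq4_nl));

    quantize_iq4_nl(src, q, /*nrow=*/1, /*n_per_row=*/n, /*hist=*/NULL, /*quant_weights=*/NULL);
    dequantize_row_iq4_nl(q, dst, n);

    double err = 0;
    for (int i = 0; i < n; ++i) err += (src[i] - dst[i])*(src[i] - dst[i]);
    printf("rms error: %g\n", sqrt(err/n));
    free(q);
    return 0;
}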
ggml-quants.h CHANGED
@@ -198,6 +198,14 @@ typedef struct {
 } block_iq1_s;
 static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
 
+// Non-linear quants
+#define QK4_NL 32
+typedef struct {
+    ggml_fp16_t d;
+    uint8_t qs[QK4_NL/2];
+} block_iq4_nl;
+static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -217,6 +225,7 @@ void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGM
 void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int k);
 void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k);
 void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k);
+void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int k);
 
 void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
@@ -232,6 +241,7 @@ void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
 void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
+void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
 
 // Dequantization
 void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
@@ -251,6 +261,7 @@ void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_
 void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
+void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
 
 // Dot product
 void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -268,6 +279,7 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
 void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 //
 // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
@@ -276,6 +288,7 @@ size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row,
 size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_iq1_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+size_t quantize_iq4_nl (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
 size_t quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
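The header makes the storage cost easy to verify: one block covers QK4_NL = 32 weights in sizeof(ggml_fp16_t) + QK4_NL/2 = 18 bytes, i.e. 4.5 bits per weight. A hypothetical standalone check, assuming the header above is on the include path:

// Hypothetical size check for block_iq4_nl.
#include <stdio.h>
#include "ggml-quants.h"

int main(void) {
    // 2 bytes (fp16 scale) + 16 bytes (32 x 4-bit indices) = 18 bytes per 32 weights
    printf("sizeof(block_iq4_nl) = %zu bytes\n", sizeof(block_iq4_nl));
    printf("bits per weight      = %.2f\n", 8.0 * sizeof(block_iq4_nl) / QK4_NL);
    return 0;
}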
ggml.c CHANGED
@@ -690,6 +690,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
     },
+    [GGML_TYPE_IQ4_NL] = {
+        .type_name = "iq4_nl",
+        .blck_size = QK4_NL,
+        .type_size = sizeof(block_iq4_nl),
+        .is_quantized = true,
+        .to_float = (ggml_to_float_t) dequantize_row_iq4_nl,
+        .from_float = quantize_row_iq4_nl,
+        .from_float_reference = (ggml_from_float_t)quantize_row_iq4_nl_reference,
+        .vec_dot = ggml_vec_dot_iq4_nl_q8_0,
+        .vec_dot_type = GGML_TYPE_Q8_0,
+        .nrows = 1,
+    },
     [GGML_TYPE_Q8_K] = {
         .type_name = "q8_K",
         .blck_size = QK_K,
@@ -2291,6 +2303,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_IQ2_XS:  wtype = GGML_TYPE_IQ2_XS;  break;
         case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
         case GGML_FTYPE_MOSTLY_IQ1_S:   wtype = GGML_TYPE_IQ1_S;   break;
+        case GGML_FTYPE_MOSTLY_IQ4_NL:  wtype = GGML_TYPE_IQ4_NL;  break;
         case GGML_FTYPE_UNKNOWN:        wtype = GGML_TYPE_COUNT; break;
         case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
     }
@@ -7724,6 +7737,7 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
             {
                 ggml_compute_forward_add_q_f32(params, dst);
             } break;
@@ -8002,6 +8016,7 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
             {
                 ggml_compute_forward_add1_q_f32(params, dst);
             } break;
@@ -8125,6 +8140,7 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
         default:
             {
                 GGML_ASSERT(false);
@@ -11022,6 +11038,7 @@ static void ggml_compute_forward_out_prod(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
             {
                 ggml_compute_forward_out_prod_q_f32(params, dst);
             } break;
@@ -11209,6 +11226,7 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
         default:
             {
                 GGML_ASSERT(false);
@@ -11410,6 +11428,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
             {
                 ggml_compute_forward_get_rows_q(params, dst);
             } break;
@@ -12109,6 +12128,7 @@ static void ggml_compute_forward_alibi(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
@@ -12191,6 +12211,7 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
@@ -19725,6 +19746,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 result = quantize_iq1_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
                 GGML_ASSERT(result == row_size * nrows);
             } break;
+        case GGML_TYPE_IQ4_NL:
+            {
+                GGML_ASSERT(start % QK4_NL == 0);
+                GGML_ASSERT(start % n_per_row == 0);
+                size_t start_row = start / n_per_row;
+                size_t row_size = ggml_row_size(type, n_per_row);
+                result = quantize_iq4_nl(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+                GGML_ASSERT(result == row_size * nrows);
+            } break;
         case GGML_TYPE_F16:
             {
                 size_t elemsize = sizeof(ggml_fp16_t);
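A hedged usage sketch of the new ggml_quantize_chunk case: the buffer names (src_f32, dst_buf) and row dimensions below are illustrative, and the parameter order is inferred from the variables the case body uses, since the hunk header above truncates the full signature.

// Illustrative only: quantize nrows x n_per_row float weights to IQ4_NL.
int64_t hist[16] = {0};          // histogram output (not populated by the iq4_nl path)
size_t written = ggml_quantize_chunk(GGML_TYPE_IQ4_NL,
        src_f32,                 // example buffer: nrows*n_per_row floats
        dst_buf,                 // >= nrows * ggml_row_size(GGML_TYPE_IQ4_NL, n_per_row) bytes
        /*start*/     0,         // must be a multiple of QK4_NL and of n_per_row
        /*nrows*/     nrows,
        /*n_per_row*/ n_per_row,
        hist,
        /*imatrix*/   NULL);     // NULL -> quantize without an importance matrix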
ggml.h CHANGED
@@ -355,6 +355,7 @@ extern "C" {
         GGML_TYPE_IQ2_XS  = 17,
         GGML_TYPE_IQ3_XXS = 18,
         GGML_TYPE_IQ1_S   = 19,
+        GGML_TYPE_IQ4_NL  = 20,
         GGML_TYPE_I8,
         GGML_TYPE_I16,
         GGML_TYPE_I32,
@@ -393,6 +394,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ2_XS  = 16, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ1_S   = 18, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ4_NL  = 19, // except 1d tensors
     };
 
     // available tensor operations:
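Because the type_traits entry keys off this enum value, the new type can be sanity-checked through accessors that ggml.h already exposes; a hypothetical snippet:

// Hypothetical check via existing public accessors in ggml.h.
#include <stdio.h>
#include "ggml.h"

int main(void) {
    const enum ggml_type t = GGML_TYPE_IQ4_NL;
    printf("%s: blck_size=%d type_size=%zu\n",
           ggml_type_name(t),    // "iq4_nl" from the traits entry above
           ggml_blck_size(t),    // QK4_NL = 32
           ggml_type_size(t));   // sizeof(block_iq4_nl) = 18
    return 0;
}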