ggerganov committed
Commit 192bda4 · 1 Parent(s): 6cb3028

ggml : full ALiBi support (llama/7192)


* ggml : full ALiBi support

* ggml : update ggml_soft_max_ext() CUDA, SYCL

* ggml : ggml_flash_attn_ext() support ALiBi (CPU)

* ggml : ggml_flash_attn_ext() support ALiBi (Metal)

* ggml : fix warning

* ggml : ggml_flash_attn_ext() support ALiBi (CUDA)

ggml-ci

* ggml : fix assert message

* vulkan : add dev notes

* ggml : require mask when using ALiBi

ggml-ci

* convert : fix convert for refact models
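Context for the changes below: this commit drops the dedicated GGML_OP_ALIBI operator and instead applies ALiBi as a per-head slope on the attention mask inside ggml_soft_max_ext() and ggml_flash_attn_ext(). A minimal C sketch of the slope computation that each backend repeats in the diffs below (the helper name alibi_slope is illustrative, not part of the ggml API):

    #include <math.h>
    #include <stdint.h>

    // per-head ALiBi slope, as recomputed by each backend in this commit
    // (a sketch for illustration, not the ggml implementation itself)
    static float alibi_slope(uint32_t h, uint32_t n_head, float max_bias) {
        // largest power of two that does not exceed n_head
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

        // geometric bases for the first n_head_log2 heads and for the remaining heads
        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

        const float base = h < n_head_log2 ? m0 : m1;
        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

        return powf(base, exph); // the bias added to a logit is slope * mask[i]
    }

For example, with n_head = 8 and max_bias = 8.0f this yields the classic ALiBi slopes 1/2, 1/4, ..., 1/256 for heads 0 through 7.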

Files changed (10)
  1. ggml-cuda.cu +0 -5
  2. ggml-cuda/fattn.cu +62 -10
  3. ggml-cuda/softmax.cu +21 -34
  4. ggml-kompute.cpp +9 -3
  5. ggml-metal.m +54 -94
  6. ggml-metal.metal +49 -71
  7. ggml-sycl.cpp +19 -119
  8. ggml-vulkan.cpp +4 -2
  9. ggml.c +40 -269
  10. ggml.h +3 -15
ggml-cuda.cu CHANGED
@@ -4,7 +4,6 @@
4
 
5
  #include "ggml-cuda/common.cuh"
6
  #include "ggml-cuda/acc.cuh"
7
- #include "ggml-cuda/alibi.cuh"
8
  #include "ggml-cuda/arange.cuh"
9
  #include "ggml-cuda/argsort.cuh"
10
  #include "ggml-cuda/binbcast.cuh"
@@ -2280,9 +2279,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2280
  case GGML_OP_ROPE:
2281
  ggml_cuda_op_rope(ctx, dst);
2282
  break;
2283
- case GGML_OP_ALIBI:
2284
- ggml_cuda_op_alibi(ctx, dst);
2285
- break;
2286
  case GGML_OP_IM2COL:
2287
  ggml_cuda_op_im2col(ctx, dst);
2288
  break;
@@ -2833,7 +2829,6 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
2833
  case GGML_OP_DIAG_MASK_INF:
2834
  case GGML_OP_SOFT_MAX:
2835
  case GGML_OP_ROPE:
2836
- case GGML_OP_ALIBI:
2837
  case GGML_OP_IM2COL:
2838
  case GGML_OP_POOL_2D:
2839
  case GGML_OP_SUM_ROWS:
 
4
 
5
  #include "ggml-cuda/common.cuh"
6
  #include "ggml-cuda/acc.cuh"
 
7
  #include "ggml-cuda/arange.cuh"
8
  #include "ggml-cuda/argsort.cuh"
9
  #include "ggml-cuda/binbcast.cuh"
 
2279
  case GGML_OP_ROPE:
2280
  ggml_cuda_op_rope(ctx, dst);
2281
  break;
 
 
 
2282
  case GGML_OP_IM2COL:
2283
  ggml_cuda_op_im2col(ctx, dst);
2284
  break;
 
2829
  case GGML_OP_DIAG_MASK_INF:
2830
  case GGML_OP_SOFT_MAX:
2831
  case GGML_OP_ROPE:
 
2832
  case GGML_OP_IM2COL:
2833
  case GGML_OP_POOL_2D:
2834
  case GGML_OP_SUM_ROWS:
ggml-cuda/fattn.cu CHANGED
@@ -23,6 +23,10 @@ static __global__ void flash_attn_vec_ext_f16(
23
  float * __restrict__ dst,
24
  float2 * __restrict__ dst_meta,
25
  const float scale,
 
 
 
 
26
  const int ne00,
27
  const int ne01,
28
  const int ne02,
@@ -58,6 +62,18 @@ static __global__ void flash_attn_vec_ext_f16(
58
  const int stride_KV = nb11 / sizeof(half);
59
  const int stride_KV2 = nb11 / sizeof(half2);
60
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
62
  constexpr int nwarps = D / WARP_SIZE;
63
  const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
@@ -141,7 +157,7 @@ static __global__ void flash_attn_vec_ext_f16(
141
  for (int j = 0; j < ncols; ++j) {
142
  sum2[j] = warp_reduce_sum(sum2[j]);
143
  half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
144
- sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
145
 
146
  if (ncols == 1) {
147
  kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
@@ -249,6 +265,10 @@ static __global__ void flash_attn_ext_f16(
249
  float * __restrict__ dst,
250
  float2 * __restrict__ dst_meta,
251
  const float scale,
 
 
 
 
252
  const int ne00,
253
  const int ne01,
254
  const int ne02,
@@ -305,6 +325,20 @@ static __global__ void flash_attn_ext_f16(
305
  const int stride_Q = nb01 / sizeof(float);
306
  const int stride_KV = nb11 / sizeof(half);
307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  frag_b Q_b[D/16][ncols/frag_n];
309
 
310
  // A single buffer for temporarily holding tiles of KQ and VKQ parts:
@@ -421,7 +455,7 @@ static __global__ void flash_attn_ext_f16(
421
  for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
422
  const int k = k0 + threadIdx.x;
423
 
424
- KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
425
  KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
426
  }
427
  KQ_max_new = warp_reduce_max(KQ_max_new);
@@ -464,7 +498,7 @@ static __global__ void flash_attn_ext_f16(
464
  for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
465
  const int k = k0 + threadIdx.x;
466
 
467
- KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
468
  KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
469
  }
470
  KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
@@ -710,8 +744,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
710
  const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
711
  const int shmem = 0;
712
 
713
- float scale;
714
- memcpy(&scale, KQV->op_params, sizeof(float));
 
 
 
 
 
 
 
 
 
715
 
716
  flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
717
  <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -720,7 +763,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
720
  (const char *) V->data,
721
  mask ? ((const char *) mask->data) : nullptr,
722
  parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
723
- scale,
724
  Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
725
  K->ne[0], K->ne[1], K->ne[2], K->ne[3],
726
  mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
@@ -761,8 +804,17 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
761
  const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
762
  const int shmem = 0;
763
 
764
- float scale;
765
- memcpy(&scale, KQV->op_params, sizeof(float));
 
 
 
 
 
 
 
 
 
766
 
767
  flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
768
  <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -771,7 +823,7 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
771
  (const char *) V->data,
772
  mask ? ((const char *) mask->data) : nullptr,
773
  (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
774
- scale,
775
  Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
776
  K->ne[0], K->ne[1], K->ne[2], K->ne[3],
777
  mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
@@ -837,7 +889,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
837
  const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
838
  const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
839
 
840
- const int32_t precision = KQV->op_params[1];
841
 
842
  if (!fp16_mma_available(cc)) {
843
  GGML_ASSERT(precision == GGML_PREC_DEFAULT);
 
23
  float * __restrict__ dst,
24
  float2 * __restrict__ dst_meta,
25
  const float scale,
26
+ const float max_bias,
27
+ const float m0,
28
+ const float m1,
29
+ const uint32_t n_head_log2,
30
  const int ne00,
31
  const int ne01,
32
  const int ne02,
 
62
  const int stride_KV = nb11 / sizeof(half);
63
  const int stride_KV2 = nb11 / sizeof(half2);
64
 
65
+ half slopeh = __float2half(1.0f);
66
+
67
+ // ALiBi
68
+ if (max_bias > 0.0f) {
69
+ const int h = blockIdx.y;
70
+
71
+ const float base = h < n_head_log2 ? m0 : m1;
72
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
73
+
74
+ slopeh = __float2half(powf(base, exph));
75
+ }
76
+
77
  static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
78
  constexpr int nwarps = D / WARP_SIZE;
79
  const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
 
157
  for (int j = 0; j < ncols; ++j) {
158
  sum2[j] = warp_reduce_sum(sum2[j]);
159
  half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
160
+ sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
161
 
162
  if (ncols == 1) {
163
  kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
 
265
  float * __restrict__ dst,
266
  float2 * __restrict__ dst_meta,
267
  const float scale,
268
+ const float max_bias,
269
+ const float m0,
270
+ const float m1,
271
+ const uint32_t n_head_log2,
272
  const int ne00,
273
  const int ne01,
274
  const int ne02,
 
325
  const int stride_Q = nb01 / sizeof(float);
326
  const int stride_KV = nb11 / sizeof(half);
327
 
328
+ half slopeh = __float2half(1.0f);
329
+ half2 slope2 = make_half2(1.0f, 1.0f);
330
+
331
+ // ALiBi
332
+ if (max_bias > 0.0f) {
333
+ const int h = blockIdx.y;
334
+
335
+ const float base = h < n_head_log2 ? m0 : m1;
336
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
337
+
338
+ slopeh = __float2half(powf(base, exph));
339
+ slope2 = make_half2(slopeh, slopeh);
340
+ }
341
+
342
  frag_b Q_b[D/16][ncols/frag_n];
343
 
344
  // A single buffer for temporarily holding tiles of KQ and VKQ parts:
 
455
  for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
456
  const int k = k0 + threadIdx.x;
457
 
458
+ KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
459
  KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
460
  }
461
  KQ_max_new = warp_reduce_max(KQ_max_new);
 
498
  for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
499
  const int k = k0 + threadIdx.x;
500
 
501
+ KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
502
  KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
503
  }
504
  KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
 
744
  const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
745
  const int shmem = 0;
746
 
747
+ float scale = 1.0f;
748
+ float max_bias = 0.0f;
749
+
750
+ memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
751
+ memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
752
+
753
+ const uint32_t n_head = Q->ne[2];
754
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
755
+
756
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
757
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
758
 
759
  flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
760
  <<<blocks_num, block_dim, shmem, main_stream>>> (
 
763
  (const char *) V->data,
764
  mask ? ((const char *) mask->data) : nullptr,
765
  parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
766
+ scale, max_bias, m0, m1, n_head_log2,
767
  Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
768
  K->ne[0], K->ne[1], K->ne[2], K->ne[3],
769
  mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
 
804
  const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
805
  const int shmem = 0;
806
 
807
+ float scale = 1.0f;
808
+ float max_bias = 0.0f;
809
+
810
+ memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
811
+ memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
812
+
813
+ const uint32_t n_head = Q->ne[2];
814
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
815
+
816
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
817
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
818
 
819
  flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
820
  <<<blocks_num, block_dim, shmem, main_stream>>> (
 
823
  (const char *) V->data,
824
  mask ? ((const char *) mask->data) : nullptr,
825
  (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
826
+ scale, max_bias, m0, m1, n_head_log2,
827
  Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
828
  K->ne[0], K->ne[1], K->ne[2], K->ne[3],
829
  mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
 
889
  const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
890
  const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
891
 
892
+ const int32_t precision = KQV->op_params[2];
893
 
894
  if (!fp16_mma_available(cc)) {
895
  GGML_ASSERT(precision == GGML_PREC_DEFAULT);
ggml-cuda/softmax.cu CHANGED
@@ -11,7 +11,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
11
  }
12
 
13
  template <bool vals_smem, int ncols_template, int block_size_template, typename T>
14
- static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
15
  const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
16
 
17
  const int tid = threadIdx.x;
@@ -23,16 +23,16 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
23
  const int warp_id = threadIdx.x / WARP_SIZE;
24
  const int lane_id = threadIdx.x % WARP_SIZE;
25
 
26
- float slope = 0.0f;
27
 
28
  // ALiBi
29
  if (max_bias > 0.0f) {
30
  const int h = rowx/nrows_y; // head index
31
 
32
  const float base = h < n_head_log2 ? m0 : m1;
33
- const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
34
 
35
- slope = powf(base, exp);
36
  }
37
 
38
  extern __shared__ float data_soft_max_f32[];
@@ -53,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
53
  const int64_t ix = (int64_t)rowx*ncols + col;
54
  const int64_t iy = (int64_t)rowy*ncols + col;
55
 
56
- const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);
57
 
58
  vals[col] = val;
59
  max_val = max(max_val, val);
@@ -125,7 +125,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
125
  }
126
 
127
  template<typename T>
128
- static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
129
  int nth = WARP_SIZE;
130
  while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
131
  const dim3 block_dims(nth, 1, 1);
@@ -133,8 +133,8 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl
133
  const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
134
  static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
135
 
136
- const uint32_t n_head_kv = nrows_x/nrows_y;
137
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
138
 
139
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
140
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -142,43 +142,42 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl
142
  if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
143
  switch (ncols_x) {
144
  case 32:
145
- soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
146
  break;
147
  case 64:
148
- soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
149
  break;
150
  case 128:
151
- soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
152
  break;
153
  case 256:
154
- soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
155
  break;
156
  case 512:
157
- soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
158
  break;
159
  case 1024:
160
- soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
161
  break;
162
  case 2048:
163
- soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
164
  break;
165
  case 4096:
166
- soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
167
  break;
168
  default:
169
- soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
170
  break;
171
  }
172
  } else {
173
  const size_t shmem_low = WARP_SIZE*sizeof(float);
174
- soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
175
  }
176
  }
177
 
178
  void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
179
  const ggml_tensor * src0 = dst->src[0];
180
  const ggml_tensor * src1 = dst->src[1];
181
- const ggml_tensor * src2 = dst->src[2];
182
 
183
  const float * src0_d = (const float *)src0->data;
184
  const void * src1_d = src1 ? (const void *)src1->data : nullptr;
@@ -190,7 +189,6 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
190
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
191
 
192
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
193
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
194
 
195
  const int64_t ne00 = src0->ne[0];
196
  const int64_t nrows_x = ggml_nrows(src0);
@@ -202,26 +200,15 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
202
  memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
203
  memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
204
 
205
- // positions tensor
206
- void * src2_d = nullptr;
207
-
208
- const bool use_src2 = src2 != nullptr;
209
-
210
- if (use_src2) {
211
- src2_d = (void *)src2->data;
212
- }
213
-
214
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
215
 
216
  if (use_f16) {
217
  const half * src1_dd = (const half *)src1_d;
218
- const half * src2_dd = (const half *)src2_d;
219
 
220
- soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
221
  } else {
222
  const float * src1_dd = (const float *)src1_d;
223
- const float * src2_dd = (const float *)src2_d;
224
 
225
- soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
226
  }
227
  }
 
11
  }
12
 
13
  template <bool vals_smem, int ncols_template, int block_size_template, typename T>
14
+ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
15
  const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
16
 
17
  const int tid = threadIdx.x;
 
23
  const int warp_id = threadIdx.x / WARP_SIZE;
24
  const int lane_id = threadIdx.x % WARP_SIZE;
25
 
26
+ float slope = 1.0f;
27
 
28
  // ALiBi
29
  if (max_bias > 0.0f) {
30
  const int h = rowx/nrows_y; // head index
31
 
32
  const float base = h < n_head_log2 ? m0 : m1;
33
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
34
 
35
+ slope = powf(base, exph);
36
  }
37
 
38
  extern __shared__ float data_soft_max_f32[];
 
53
  const int64_t ix = (int64_t)rowx*ncols + col;
54
  const int64_t iy = (int64_t)rowy*ncols + col;
55
 
56
+ const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
57
 
58
  vals[col] = val;
59
  max_val = max(max_val, val);
 
125
  }
126
 
127
  template<typename T>
128
+ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
129
  int nth = WARP_SIZE;
130
  while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
131
  const dim3 block_dims(nth, 1, 1);
 
133
  const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
134
  static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
135
 
136
+ const uint32_t n_head = nrows_x/nrows_y;
137
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
138
 
139
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
140
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
142
  if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
143
  switch (ncols_x) {
144
  case 32:
145
+ soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
146
  break;
147
  case 64:
148
+ soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
149
  break;
150
  case 128:
151
+ soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
152
  break;
153
  case 256:
154
+ soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
155
  break;
156
  case 512:
157
+ soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
158
  break;
159
  case 1024:
160
+ soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
161
  break;
162
  case 2048:
163
+ soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
164
  break;
165
  case 4096:
166
+ soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
167
  break;
168
  default:
169
+ soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
170
  break;
171
  }
172
  } else {
173
  const size_t shmem_low = WARP_SIZE*sizeof(float);
174
+ soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
175
  }
176
  }
177
 
178
  void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
179
  const ggml_tensor * src0 = dst->src[0];
180
  const ggml_tensor * src1 = dst->src[1];
 
181
 
182
  const float * src0_d = (const float *)src0->data;
183
  const void * src1_d = src1 ? (const void *)src1->data : nullptr;
 
189
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
190
 
191
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
192
 
193
  const int64_t ne00 = src0->ne[0];
194
  const int64_t nrows_x = ggml_nrows(src0);
 
200
  memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
201
  memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
202
 
203
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
 
 
 
 
 
 
 
 
204
 
205
  if (use_f16) {
206
  const half * src1_dd = (const half *)src1_d;
 
207
 
208
+ soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
209
  } else {
210
  const float * src1_dd = (const float *)src1_d;
 
211
 
212
+ soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
213
  }
214
  }
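For reference, a hedged CPU sketch of what soft_max_f32 now computes per row after this change: the separate positions tensor (src2) is gone, and the ALiBi bias rides on the mask via the per-head slope (1.0f when max_bias == 0; the mask is now required when ALiBi is used). The helper name below is illustrative.

    #include <math.h>
    #include <stddef.h>

    // reference per-row softmax with fused scale and slope-weighted mask
    // (a sketch under the assumptions above, not the ggml implementation)
    static void soft_max_row_ref(const float * x, const float * mask, float * dst,
                                 size_t n, float scale, float slope) {
        float max_val = -INFINITY;
        for (size_t i = 0; i < n; ++i) {
            const float v = x[i]*scale + (mask ? slope*mask[i] : 0.0f);
            dst[i]  = v;
            max_val = fmaxf(max_val, v);
        }

        float sum = 0.0f;
        for (size_t i = 0; i < n; ++i) {
            dst[i] = expf(dst[i] - max_val);
            sum   += dst[i];
        }

        for (size_t i = 0; i < n; ++i) {
            dst[i] /= sum;
        }
    }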
ggml-kompute.cpp CHANGED
@@ -1559,12 +1559,18 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
1559
  case GGML_OP_SOFT_MAX:
1560
  {
1561
  float scale;
1562
- memcpy(&scale, dst->op_params, sizeof(float));
1563
 
1564
- #pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
 
 
 
1565
  #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
1566
  GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
1567
- GGML_ASSERT(src2 == nullptr);
 
 
 
1568
 
1569
  ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
1570
  } break;
 
1559
  case GGML_OP_SOFT_MAX:
1560
  {
1561
  float scale;
1562
+ float max_bias;
1563
 
1564
+ memcpy(&scale, (float *)dst->op_params + 0, sizeof(float));
1565
+ memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
1566
+
1567
+ #pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
1568
  #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
1569
  GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
1570
+
1571
+ #pragma message("TODO: add ALiBi support")
1572
+ #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
1573
+ GGML_ASSERT(max_bias == 0.0f);
1574
 
1575
  ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
1576
  } break;
ggml-metal.m CHANGED
@@ -170,7 +170,6 @@ enum ggml_metal_kernel_type {
170
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
171
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
172
  GGML_METAL_KERNEL_TYPE_ROPE_F16,
173
- GGML_METAL_KERNEL_TYPE_ALIBI_F32,
174
  GGML_METAL_KERNEL_TYPE_IM2COL_F16,
175
  GGML_METAL_KERNEL_TYPE_IM2COL_F32,
176
  GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -625,7 +624,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
625
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
626
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
627
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
628
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
629
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
630
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
631
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -762,7 +760,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
762
  case GGML_OP_GROUP_NORM:
763
  return ctx->support_simdgroup_reduction;
764
  case GGML_OP_NORM:
765
- case GGML_OP_ALIBI:
766
  case GGML_OP_ROPE:
767
  case GGML_OP_IM2COL:
768
  return true;
@@ -1373,13 +1370,12 @@ static enum ggml_status ggml_metal_graph_compute(
1373
  case GGML_OP_SOFT_MAX:
1374
  {
1375
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
1376
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
1377
 
1378
  int nth = 32; // SIMD width
1379
 
1380
  id<MTLComputePipelineState> pipeline = nil;
1381
 
1382
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
1383
 
1384
  if (ne00%4 == 0) {
1385
  while (nth < ne00/4 && nth < 256) {
@@ -1410,8 +1406,8 @@ static enum ggml_status ggml_metal_graph_compute(
1410
  const int64_t nrows_x = ggml_nrows(src0);
1411
  const int64_t nrows_y = src0->ne[1];
1412
 
1413
- const uint32_t n_head_kv = nrows_x/nrows_y;
1414
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
1415
 
1416
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1417
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -1423,20 +1419,15 @@ static enum ggml_status ggml_metal_graph_compute(
1423
  } else {
1424
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1425
  }
1426
- if (id_src2) {
1427
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
1428
- } else {
1429
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
1430
- }
1431
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
1432
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
1433
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
1434
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
1435
- [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
1436
- [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
1437
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
1438
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
1439
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
1440
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1441
 
1442
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -2241,49 +2232,6 @@ static enum ggml_status ggml_metal_graph_compute(
2241
 
2242
  [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2243
  } break;
2244
- case GGML_OP_ALIBI:
2245
- {
2246
- GGML_ASSERT((src0t == GGML_TYPE_F32));
2247
-
2248
- const int nth = MIN(1024, ne00);
2249
-
2250
- //const int n_past = ((int32_t *) dst->op_params)[0];
2251
- const int n_head = ((int32_t *) dst->op_params)[1];
2252
-
2253
- float max_bias;
2254
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
2255
-
2256
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
2257
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
2258
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
2259
-
2260
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
2261
-
2262
- [encoder setComputePipelineState:pipeline];
2263
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2264
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2265
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2266
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2267
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2268
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
2269
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
2270
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
2271
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
2272
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
2273
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
2274
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
2275
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
2276
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
2277
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
2278
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
2279
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
2280
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
2281
- [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
2282
- [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
2283
- [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
2284
-
2285
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2286
- } break;
2287
  case GGML_OP_ROPE:
2288
  {
2289
  GGML_ASSERT(ne10 == ne02);
@@ -2581,7 +2529,7 @@ static enum ggml_status ggml_metal_graph_compute(
2581
  "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2582
 
2583
  const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
2584
- const int64_t ne31 = src3 ? src3->ne[1] : 0;
2585
  const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
2586
  const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
2587
 
@@ -2593,7 +2541,16 @@ static enum ggml_status ggml_metal_graph_compute(
2593
  const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
2594
 
2595
  float scale;
2596
- memcpy(&scale, dst->op_params, sizeof(float));
 
 
 
 
 
 
 
 
 
2597
 
2598
  id<MTLComputePipelineState> pipeline = nil;
2599
 
@@ -2630,34 +2587,37 @@ static enum ggml_status ggml_metal_graph_compute(
2630
  }
2631
 
2632
  [encoder setComputePipelineState:pipeline];
2633
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2634
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2635
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2636
- [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2637
- [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2638
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
2639
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
2640
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
2641
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
2642
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
2643
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
2644
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
2645
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
2646
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
2647
- [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
2648
- [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
2649
- [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
2650
- [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
2651
- [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
2652
- [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
2653
- [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
2654
- [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
2655
- [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
2656
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23];
2657
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24];
2658
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25];
2659
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
2660
- [encoder setBytes:&scale length:sizeof( float) atIndex:27];
 
 
 
2661
 
2662
  if (!use_vec_kernel) {
2663
  // half8x8 kernel
 
170
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
171
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
172
  GGML_METAL_KERNEL_TYPE_ROPE_F16,
 
173
  GGML_METAL_KERNEL_TYPE_IM2COL_F16,
174
  GGML_METAL_KERNEL_TYPE_IM2COL_F32,
175
  GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
 
624
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
625
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
626
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
 
627
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
628
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
629
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
 
760
  case GGML_OP_GROUP_NORM:
761
  return ctx->support_simdgroup_reduction;
762
  case GGML_OP_NORM:
 
763
  case GGML_OP_ROPE:
764
  case GGML_OP_IM2COL:
765
  return true;
 
1370
  case GGML_OP_SOFT_MAX:
1371
  {
1372
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
 
1373
 
1374
  int nth = 32; // SIMD width
1375
 
1376
  id<MTLComputePipelineState> pipeline = nil;
1377
 
1378
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
1379
 
1380
  if (ne00%4 == 0) {
1381
  while (nth < ne00/4 && nth < 256) {
 
1406
  const int64_t nrows_x = ggml_nrows(src0);
1407
  const int64_t nrows_y = src0->ne[1];
1408
 
1409
+ const uint32_t n_head = nrows_x/nrows_y;
1410
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
1411
 
1412
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1413
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
1419
  } else {
1420
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1421
  }
1422
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1423
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1424
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1425
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1426
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1427
+ [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
1428
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
1429
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
1430
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
 
 
 
 
 
1431
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1432
 
1433
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 
2232
 
2233
  [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2234
  } break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2235
  case GGML_OP_ROPE:
2236
  {
2237
  GGML_ASSERT(ne10 == ne02);
 
2529
  "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2530
 
2531
  const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
2532
+ //const int64_t ne31 = src3 ? src3->ne[1] : 0;
2533
  const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
2534
  const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
2535
 
 
2541
  const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
2542
 
2543
  float scale;
2544
+ float max_bias;
2545
+
2546
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
2547
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
2548
+
2549
+ const uint32_t n_head = src0->ne[2];
2550
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
2551
+
2552
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
2553
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
2554
 
2555
  id<MTLComputePipelineState> pipeline = nil;
2556
 
 
2587
  }
2588
 
2589
  [encoder setComputePipelineState:pipeline];
2590
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2591
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2592
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2593
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2594
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2595
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
2596
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
2597
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
2598
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
2599
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
2600
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
2601
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
2602
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
2603
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
2604
+ [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
2605
+ [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
2606
+ [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
2607
+ [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
2608
+ [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
2609
+ [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
2610
+ [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
2611
+ [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:21];
2612
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:22];
2613
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:23];
2614
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:24];
2615
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:25];
2616
+ [encoder setBytes:&scale length:sizeof( float) atIndex:26];
2617
+ [encoder setBytes:&max_bias length:sizeof( float) atIndex:27];
2618
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:28];
2619
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:29];
2620
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:30];
2621
 
2622
  if (!use_vec_kernel) {
2623
  // half8x8 kernel
ggml-metal.metal CHANGED
@@ -363,7 +363,6 @@ template<typename T>
363
  kernel void kernel_soft_max(
364
  device const char * src0,
365
  device const char * src1,
366
- device const char * src2,
367
  device char * dst,
368
  constant int64_t & ne00,
369
  constant int64_t & ne01,
@@ -385,10 +384,9 @@ kernel void kernel_soft_max(
385
 
386
  device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
387
  device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
388
- device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
389
  device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
390
 
391
- float slope = 0.0f;
392
 
393
  // ALiBi
394
  if (max_bias > 0.0f) {
@@ -404,7 +402,7 @@ kernel void kernel_soft_max(
404
  float lmax = -INFINITY;
405
 
406
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
407
- lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
408
  }
409
 
410
  // find the max value in the block
@@ -429,7 +427,7 @@ kernel void kernel_soft_max(
429
  // parallel sum
430
  float lsum = 0.0f;
431
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
432
- const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
433
  lsum += exp_psrc0;
434
  pdst[i00] = exp_psrc0;
435
  }
@@ -468,7 +466,6 @@ template<typename T>
468
  kernel void kernel_soft_max_4(
469
  device const char * src0,
470
  device const char * src1,
471
- device const char * src2,
472
  device char * dst,
473
  constant int64_t & ne00,
474
  constant int64_t & ne01,
@@ -490,10 +487,9 @@ kernel void kernel_soft_max_4(
490
 
491
  device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
492
  device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
493
- device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
494
  device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
495
 
496
- float slope = 0.0f;
497
 
498
  if (max_bias > 0.0f) {
499
  const int64_t h = i02;
@@ -508,7 +504,7 @@ kernel void kernel_soft_max_4(
508
  float4 lmax4 = -INFINITY;
509
 
510
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
511
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
512
  }
513
 
514
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -534,7 +530,7 @@ kernel void kernel_soft_max_4(
534
  // parallel sum
535
  float4 lsum4 = 0.0f;
536
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
537
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
538
  lsum4 += exp_psrc4;
539
  pdst4[i00] = exp_psrc4;
540
  }
@@ -1602,60 +1598,6 @@ kernel void kernel_mul_mv_f16_f32_l4(
1602
  }
1603
  }
1604
 
1605
- kernel void kernel_alibi_f32(
1606
- device const float * src0,
1607
- device float * dst,
1608
- constant int64_t & ne00,
1609
- constant int64_t & ne01,
1610
- constant int64_t & ne02,
1611
- constant int64_t & ne03,
1612
- constant uint64_t & nb00,
1613
- constant uint64_t & nb01,
1614
- constant uint64_t & nb02,
1615
- constant uint64_t & nb03,
1616
- constant int64_t & ne0,
1617
- constant int64_t & ne1,
1618
- constant int64_t & ne2,
1619
- constant int64_t & ne3,
1620
- constant uint64_t & nb0,
1621
- constant uint64_t & nb1,
1622
- constant uint64_t & nb2,
1623
- constant uint64_t & nb3,
1624
- constant float & m0,
1625
- constant float & m1,
1626
- constant int & n_heads_log2_floor,
1627
- uint3 tgpig[[threadgroup_position_in_grid]],
1628
- uint3 tpitg[[thread_position_in_threadgroup]],
1629
- uint3 ntg[[threads_per_threadgroup]]) {
1630
- const int64_t i03 = tgpig[2];
1631
- const int64_t i02 = tgpig[1];
1632
- const int64_t i01 = tgpig[0];
1633
-
1634
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1635
-
1636
- const int64_t i3 = n / (ne2*ne1*ne0);
1637
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1638
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1639
- //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1640
-
1641
- const int64_t k = i3*ne3 + i2;
1642
-
1643
- float m_k;
1644
- if (k < n_heads_log2_floor) {
1645
- m_k = pow(m0, k + 1);
1646
- } else {
1647
- m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
1648
- }
1649
-
1650
- device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
1651
- device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
1652
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1653
- const float src_v = *(device float *)(src_row + i00*nb00);
1654
- device float * dst_v = (device float *)(dst_row + i00*nb0);
1655
- *dst_v = i00 * m_k + src_v;
1656
- }
1657
- }
1658
-
1659
  static float rope_yarn_ramp(const float low, const float high, const int i0) {
1660
  const float y = (i0 / 2 - low) / max(0.001f, high - low);
1661
  return 1.0f - min(1.0f, max(0.0f, y));
@@ -2123,13 +2065,16 @@ typedef void (flash_attn_ext_f16_t)(
2123
  constant uint64_t & nb11,
2124
  constant uint64_t & nb12,
2125
  constant uint64_t & nb13,
2126
- constant int64_t & ne31,
2127
  constant uint64_t & nb31,
2128
  constant int64_t & ne0,
2129
  constant int64_t & ne1,
2130
  constant int64_t & ne2,
2131
  constant int64_t & ne3,
2132
  constant float & scale,
 
 
 
 
2133
  threadgroup half * shared,
2134
  uint3 tgpig[[threadgroup_position_in_grid]],
2135
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2161,13 +2106,16 @@ kernel void kernel_flash_attn_ext_f16(
2161
  constant uint64_t & nb11,
2162
  constant uint64_t & nb12,
2163
  constant uint64_t & nb13,
2164
- constant int64_t & ne31,
2165
  constant uint64_t & nb31,
2166
  constant int64_t & ne0,
2167
  constant int64_t & ne1,
2168
  constant int64_t & ne2,
2169
  constant int64_t & ne3,
2170
  constant float & scale,
 
 
 
 
2171
  threadgroup half * shared [[threadgroup(0)]],
2172
  uint3 tgpig[[threadgroup_position_in_grid]],
2173
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2264,6 +2212,19 @@ kernel void kernel_flash_attn_ext_f16(
2264
  // prepare diagonal scale matrix
2265
  simdgroup_float8x8 mscale(scale);
2266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2267
  // loop over the KV cache
2268
  // each simdgroup handles blocks of Q rows and C columns
2269
  for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
@@ -2286,9 +2247,10 @@ kernel void kernel_flash_attn_ext_f16(
2286
  simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
2287
  }
2288
 
2289
- // mqk = mqk*scale + mask
2290
  simdgroup_half8x8 mm;
2291
  simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
 
2292
  simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
2293
 
2294
  simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
@@ -2479,13 +2441,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
2479
  constant uint64_t & nb11,
2480
  constant uint64_t & nb12,
2481
  constant uint64_t & nb13,
2482
- constant int64_t & ne31,
2483
  constant uint64_t & nb31,
2484
  constant int64_t & ne0,
2485
  constant int64_t & ne1,
2486
  constant int64_t & ne2,
2487
  constant int64_t & ne3,
2488
  constant float & scale,
 
 
 
 
2489
  threadgroup half * shared [[threadgroup(0)]],
2490
  uint3 tgpig[[threadgroup_position_in_grid]],
2491
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2504,6 +2469,18 @@ kernel void kernel_flash_attn_ext_vec_f16(
2504
 
2505
  const short T = D + 2*nsg*SH; // shared memory size per query in (half)
2506
 
 
 
 
 
 
 
 
 
 
 
 
 
2507
  //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
2508
  threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
2509
  threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
@@ -2610,10 +2587,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
2610
  mqk += simd_shuffle_down(mqk, 2);
2611
  mqk += simd_shuffle_down(mqk, 1);
2612
 
2613
- // mqk = mqk*scale + mask
2614
  if (tiisg == 0) {
2615
  float4 mm = (float4) mp4[ic/4 + cc];
2616
- mqk = mqk*scale + mm;
2617
 
2618
  ss4[cc] = mqk;
2619
  }
@@ -2847,7 +2824,8 @@ kernel void kernel_cpy_f32_f16(
2847
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2848
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2849
 
2850
- dst_data[i00] = src[0];
 
2851
  }
2852
  }
2853
 
 
363
  kernel void kernel_soft_max(
364
  device const char * src0,
365
  device const char * src1,
 
366
  device char * dst,
367
  constant int64_t & ne00,
368
  constant int64_t & ne01,
 
384
 
385
  device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
386
  device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
 
387
  device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
388
 
389
+ float slope = 1.0f;
390
 
391
  // ALiBi
392
  if (max_bias > 0.0f) {
 
402
  float lmax = -INFINITY;
403
 
404
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
405
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
406
  }
407
 
408
  // find the max value in the block
 
427
  // parallel sum
428
  float lsum = 0.0f;
429
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
430
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
431
  lsum += exp_psrc0;
432
  pdst[i00] = exp_psrc0;
433
  }
 
466
  kernel void kernel_soft_max_4(
467
  device const char * src0,
468
  device const char * src1,
 
469
  device char * dst,
470
  constant int64_t & ne00,
471
  constant int64_t & ne01,
 
487
 
488
  device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
489
  device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
 
490
  device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
491
 
492
+ float slope = 1.0f;
493
 
494
  if (max_bias > 0.0f) {
495
  const int64_t h = i02;
 
504
  float4 lmax4 = -INFINITY;
505
 
506
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
507
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
508
  }
509
 
510
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
530
  // parallel sum
531
  float4 lsum4 = 0.0f;
532
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
533
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
534
  lsum4 += exp_psrc4;
535
  pdst4[i00] = exp_psrc4;
536
  }
 
1598
  }
1599
  }
1600
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1601
  static float rope_yarn_ramp(const float low, const float high, const int i0) {
1602
  const float y = (i0 / 2 - low) / max(0.001f, high - low);
1603
  return 1.0f - min(1.0f, max(0.0f, y));
 
2065
  constant uint64_t & nb11,
2066
  constant uint64_t & nb12,
2067
  constant uint64_t & nb13,
 
2068
  constant uint64_t & nb31,
2069
  constant int64_t & ne0,
2070
  constant int64_t & ne1,
2071
  constant int64_t & ne2,
2072
  constant int64_t & ne3,
2073
  constant float & scale,
2074
+ constant float & max_bias,
2075
+ constant float & m0,
2076
+ constant float & m1,
2077
+ constant uint32_t & n_head_log2,
2078
  threadgroup half * shared,
2079
  uint3 tgpig[[threadgroup_position_in_grid]],
2080
  uint3 tpitg[[thread_position_in_threadgroup]],
 
2106
  constant uint64_t & nb11,
2107
  constant uint64_t & nb12,
2108
  constant uint64_t & nb13,
 
2109
  constant uint64_t & nb31,
2110
  constant int64_t & ne0,
2111
  constant int64_t & ne1,
2112
  constant int64_t & ne2,
2113
  constant int64_t & ne3,
2114
  constant float & scale,
2115
+ constant float & max_bias,
2116
+ constant float & m0,
2117
+ constant float & m1,
2118
+ constant uint32_t & n_head_log2,
2119
  threadgroup half * shared [[threadgroup(0)]],
2120
  uint3 tgpig[[threadgroup_position_in_grid]],
2121
  uint3 tpitg[[thread_position_in_threadgroup]],
 
2212
  // prepare diagonal scale matrix
2213
  simdgroup_float8x8 mscale(scale);
2214
 
2215
+ // prepare diagonal slope matrix
2216
+ simdgroup_float8x8 mslope(1.0f);
2217
+
2218
+ // ALiBi
2219
+ if (max_bias > 0.0f) {
2220
+ const short h = iq2;
2221
+
2222
+ const float base = h < n_head_log2 ? m0 : m1;
2223
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
2224
+
2225
+ mslope = simdgroup_float8x8(pow(base, exph));
2226
+ }
2227
+
2228
  // loop over the KV cache
2229
  // each simdgroup handles blocks of Q rows and C columns
2230
  for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
 
2247
  simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
2248
  }
2249
 
2250
+ // mqk = mqk*scale + mask*slope
2251
  simdgroup_half8x8 mm;
2252
  simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
2253
+ simdgroup_multiply(mm, mslope, mm);
2254
  simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
2255
 
2256
  simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
 
2441
  constant uint64_t & nb11,
2442
  constant uint64_t & nb12,
2443
  constant uint64_t & nb13,
 
2444
  constant uint64_t & nb31,
2445
  constant int64_t & ne0,
2446
  constant int64_t & ne1,
2447
  constant int64_t & ne2,
2448
  constant int64_t & ne3,
2449
  constant float & scale,
2450
+ constant float & max_bias,
2451
+ constant float & m0,
2452
+ constant float & m1,
2453
+ constant uint32_t & n_head_log2,
2454
  threadgroup half * shared [[threadgroup(0)]],
2455
  uint3 tgpig[[threadgroup_position_in_grid]],
2456
  uint3 tpitg[[thread_position_in_threadgroup]],
 
2469
 
2470
  const short T = D + 2*nsg*SH; // shared memory size per query in (half)
2471
 
2472
+ float slope = 1.0f;
2473
+
2474
+ // ALiBi
2475
+ if (max_bias > 0.0f) {
2476
+ const short h = iq2;
2477
+
2478
+ const float base = h < n_head_log2 ? m0 : m1;
2479
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
2480
+
2481
+ slope = pow(base, exp);
2482
+ }
2483
+
2484
  //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
2485
  threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
2486
  threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
 
2587
  mqk += simd_shuffle_down(mqk, 2);
2588
  mqk += simd_shuffle_down(mqk, 1);
2589
 
2590
+ // mqk = mqk*scale + mask*slope
2591
  if (tiisg == 0) {
2592
  float4 mm = (float4) mp4[ic/4 + cc];
2593
+ mqk = mqk*scale + mm*slope;
2594
 
2595
  ss4[cc] = mqk;
2596
  }
 
2824
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2825
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2826
 
2827
+ // TODO: is there a better way to handle -INFINITY?
2828
+ dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
2829
  }
2830
  }
2831
 
ggml-sycl.cpp CHANGED
@@ -3154,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)(
3154
  #define SYCL_SCALE_BLOCK_SIZE 256
3155
  #define SYCL_CLAMP_BLOCK_SIZE 256
3156
  #define SYCL_ROPE_BLOCK_SIZE 256
3157
- #define SYCL_ALIBI_BLOCK_SIZE 32
3158
  #define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
3159
  #define SYCL_QUANTIZE_BLOCK_SIZE 256
3160
  #define SYCL_DEQUANTIZE_BLOCK_SIZE 256
@@ -9316,32 +9315,6 @@ static void rope_glm_f32(
9316
  dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
9317
  }
9318
 
9319
- static void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
9320
- const int n_heads_log2_floor, const float m0, const float m1,
9321
- const sycl::nd_item<3> &item_ct1) {
9322
- const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
9323
- item_ct1.get_local_id(2);
9324
-
9325
- if (col >= ncols) {
9326
- return;
9327
- }
9328
-
9329
- const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
9330
- item_ct1.get_local_id(1);
9331
- const int i = row*ncols + col;
9332
-
9333
- const int k = row/k_rows;
9334
-
9335
- float m_k;
9336
- if (k < n_heads_log2_floor) {
9337
- m_k = dpct::pow(m0, k + 1);
9338
- } else {
9339
- m_k = dpct::pow(m1, 2 * (k - n_heads_log2_floor) + 1);
9340
- }
9341
-
9342
- dst[i] = col * m_k + x[i];
9343
- }
9344
-
9345
  static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
9346
  const sycl::nd_item<3> &item_ct1) {
9347
  const int row = item_ct1.get_group(1);
@@ -9443,7 +9416,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
9443
 
9444
 
9445
  template <bool vals_smem, int ncols_template, int block_size_template>
9446
- static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
9447
  const int nrows_y, const float scale, const float max_bias, const float m0,
9448
  const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
9449
  const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@@ -9457,7 +9430,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
9457
  const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
9458
  const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
9459
 
9460
- float slope = 0.0f;
9461
 
9462
  // ALiBi
9463
  if (max_bias > 0.0f) {
@@ -9482,7 +9455,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
9482
  const int ix = rowx*ncols + col;
9483
  const int iy = rowy*ncols + col;
9484
 
9485
- const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
9486
 
9487
  vals[col] = val;
9488
  max_val = sycl::max(max_val, val);
@@ -12964,20 +12937,6 @@ static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
12964
  });
12965
  }
12966
 
12967
- static void alibi_f32_sycl(const float *x, float *dst, const int ncols,
12968
- const int nrows, const int k_rows,
12969
- const int n_heads_log2_floor, const float m0,
12970
- const float m1, dpct::queue_ptr stream) {
12971
- const sycl::range<3> block_dims(1, 1, SYCL_ALIBI_BLOCK_SIZE);
12972
- const int num_blocks_x = (ncols + SYCL_ALIBI_BLOCK_SIZE - 1) / (SYCL_ALIBI_BLOCK_SIZE);
12973
- const sycl::range<3> block_nums(1, nrows, num_blocks_x);
12974
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
12975
- [=](sycl::nd_item<3> item_ct1) {
12976
- alibi_f32(x, dst, ncols, k_rows,
12977
- n_heads_log2_floor, m0, m1, item_ct1);
12978
- });
12979
- }
12980
-
12981
  static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
12982
  const int nrows, dpct::queue_ptr stream) {
12983
  const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -13058,7 +13017,7 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
13058
  }
13059
 
13060
  template <bool vals_smem, int ncols_template, int block_size_template>
13061
- static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
13062
  const int nrows_y, const float scale, const float max_bias, const float m0,
13063
  const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
13064
  const size_t n_local_scratch, dpct::queue_ptr stream) {
@@ -13068,7 +13027,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl
13068
  cgh.parallel_for(
13069
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
13070
  [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
13071
- soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, pos, dst, ncols_par,
13072
  nrows_y, scale, max_bias, m0,
13073
  m1, n_head_log2, item_ct1,
13074
  local_buf_acc.get_pointer());
@@ -13076,7 +13035,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl
13076
  });
13077
  }
13078
 
13079
- static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos,
13080
  float * dst, const int ncols_x, const int nrows_x,
13081
  const int nrows_y, const float scale, const float max_bias,
13082
  dpct::queue_ptr stream) {
@@ -13098,60 +13057,60 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
13098
  const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
13099
  if (n_local_scratch*sizeof(float) < local_mem_size) {
13100
  if (ncols_x > max_block_size) {
13101
- soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13102
  max_bias, m0, m1, n_head_log2, block_nums,
13103
  block_dims, n_local_scratch, stream);
13104
  return;
13105
  }
13106
  switch (ncols_x) {
13107
  case 32:
13108
- soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13109
  max_bias, m0, m1, n_head_log2, block_nums,
13110
  block_dims, n_local_scratch, stream);
13111
  break;
13112
  case 64:
13113
- soft_max_f32_submitter<true, 64, 64>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13114
  max_bias, m0, m1, n_head_log2, block_nums,
13115
  block_dims, n_local_scratch, stream);
13116
  break;
13117
  case 128:
13118
- soft_max_f32_submitter<true, 128, 128>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13119
  max_bias, m0, m1, n_head_log2, block_nums,
13120
  block_dims, n_local_scratch, stream);
13121
  break;
13122
  case 256:
13123
- soft_max_f32_submitter<true, 256, 256>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13124
  max_bias, m0, m1, n_head_log2, block_nums,
13125
  block_dims, n_local_scratch, stream);
13126
  break;
13127
  case 512:
13128
- soft_max_f32_submitter<true, 512, 512>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13129
  max_bias, m0, m1, n_head_log2, block_nums,
13130
  block_dims, n_local_scratch, stream);
13131
  break;
13132
  case 1024:
13133
- soft_max_f32_submitter<true, 1024, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13134
  max_bias, m0, m1, n_head_log2, block_nums,
13135
  block_dims, n_local_scratch, stream);
13136
  break;
13137
  case 2048:
13138
- soft_max_f32_submitter<true, 2048, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13139
  max_bias, m0, m1, n_head_log2, block_nums,
13140
  block_dims, n_local_scratch, stream);
13141
  break;
13142
  case 4096:
13143
- soft_max_f32_submitter<true, 4096, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13144
  max_bias, m0, m1, n_head_log2, block_nums,
13145
  block_dims, n_local_scratch, stream);
13146
  break;
13147
  default:
13148
- soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13149
  max_bias, m0, m1, n_head_log2, block_nums,
13150
  block_dims, n_local_scratch, stream);
13151
  break;
13152
  }
13153
  } else {
13154
- soft_max_f32_submitter<false, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13155
  max_bias, m0, m1, n_head_log2, block_nums,
13156
  block_dims, WARP_SIZE, stream);
13157
  }
@@ -14562,36 +14521,6 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
14562
  (void) src1_dd;
14563
  }
14564
 
14565
- inline void ggml_sycl_op_alibi(const ggml_tensor *src0, const ggml_tensor *src1,
14566
- ggml_tensor *dst, const float *src0_dd,
14567
- const float *src1_dd, float *dst_dd,
14568
- const dpct::queue_ptr &main_stream) {
14569
-
14570
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
14571
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
14572
-
14573
- GGML_TENSOR_LOCALS_3(int64_t, ne0, src0, ne);
14574
- const int64_t nrows = ggml_nrows(src0);
14575
-
14576
- //const int n_past = ((int32_t *) dst->op_params)[0];
14577
- const int n_head = ((int32_t *) dst->op_params)[1];
14578
- float max_bias;
14579
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
14580
-
14581
- //GGML_ASSERT(ne01 + n_past == ne00);
14582
- GGML_ASSERT(n_head == ne02);
14583
-
14584
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
14585
-
14586
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
14587
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
14588
-
14589
- alibi_f32_sycl(src0_dd, dst_dd, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, main_stream);
14590
-
14591
- (void) src1;
14592
- (void) src1_dd;
14593
- }
14594
-
14595
  static void ggml_sycl_op_pool2d(const ggml_tensor *src0,
14596
  const ggml_tensor *src1, ggml_tensor *dst,
14597
  const float *src0_dd, const float *src1_dd,
@@ -14746,12 +14675,9 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
14746
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
14747
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
14748
 
14749
- const ggml_tensor * src2 = dst->src[2];
14750
-
14751
- #pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
14752
  #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
14753
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
14754
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
14755
 
14756
  const int64_t ne00 = src0->ne[0];
14757
  const int64_t nrows_x = ggml_nrows(src0);
@@ -14763,25 +14689,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
14763
  memcpy(&scale, dst->op_params + 0, sizeof(float));
14764
  memcpy(&max_bias, dst->op_params + 1, sizeof(float));
14765
 
14766
- // positions tensor
14767
- float * src2_dd = nullptr;
14768
- sycl_pool_alloc<float> src2_f;
14769
-
14770
- const bool use_src2 = src2 != nullptr;
14771
-
14772
- if (use_src2) {
14773
- const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
14774
-
14775
- if (src2_on_device) {
14776
- ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
14777
- src2_dd = (float *) src2_extra->data_device[g_main_device];
14778
- } else {
14779
- src2_dd = src2_f.alloc(ggml_nelements(src2));
14780
- SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
14781
- }
14782
- }
14783
-
14784
- soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00,
14785
  nrows_x, nrows_y, scale, max_bias, main_stream);
14786
  }
14787
 
@@ -16232,10 +16140,6 @@ static void ggml_sycl_rope(const ggml_tensor * src0, const ggml_tensor * src1, g
16232
  ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_rope);
16233
  }
16234
 
16235
- static void ggml_sycl_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
16236
- ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_alibi);
16237
- }
16238
-
16239
  static void ggml_sycl_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
16240
  ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_pool2d);
16241
  }
@@ -16612,9 +16516,6 @@ bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_
16612
  case GGML_OP_ROPE:
16613
  func = ggml_sycl_rope;
16614
  break;
16615
- case GGML_OP_ALIBI:
16616
- func = ggml_sycl_alibi;
16617
- break;
16618
  case GGML_OP_IM2COL:
16619
  func = ggml_sycl_im2col;
16620
  break;
@@ -17744,7 +17645,6 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
17744
  case GGML_OP_DIAG_MASK_INF:
17745
  case GGML_OP_SOFT_MAX:
17746
  case GGML_OP_ROPE:
17747
- case GGML_OP_ALIBI:
17748
  case GGML_OP_IM2COL:
17749
  case GGML_OP_POOL_2D:
17750
  case GGML_OP_SUM_ROWS:
 
3154
  #define SYCL_SCALE_BLOCK_SIZE 256
3155
  #define SYCL_CLAMP_BLOCK_SIZE 256
3156
  #define SYCL_ROPE_BLOCK_SIZE 256
 
3157
  #define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
3158
  #define SYCL_QUANTIZE_BLOCK_SIZE 256
3159
  #define SYCL_DEQUANTIZE_BLOCK_SIZE 256
 
9315
  dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
9316
  }
9317
 
9318
  static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
9319
  const sycl::nd_item<3> &item_ct1) {
9320
  const int row = item_ct1.get_group(1);
 
9416
 
9417
 
9418
  template <bool vals_smem, int ncols_template, int block_size_template>
9419
+ static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
9420
  const int nrows_y, const float scale, const float max_bias, const float m0,
9421
  const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
9422
  const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
9430
  const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
9431
  const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
9432
 
9433
+ float slope = 1.0f;
9434
 
9435
  // ALiBi
9436
  if (max_bias > 0.0f) {
 
9455
  const int ix = rowx*ncols + col;
9456
  const int iy = rowy*ncols + col;
9457
 
9458
+ const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
9459
 
9460
  vals[col] = val;
9461
  max_val = sycl::max(max_val, val);
 
12937
  });
12938
  }
12939
 
12940
  static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
12941
  const int nrows, dpct::queue_ptr stream) {
12942
  const sycl::range<3> block_dims(1, 1, WARP_SIZE);
 
13017
  }
13018
 
13019
  template <bool vals_smem, int ncols_template, int block_size_template>
13020
+ static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
13021
  const int nrows_y, const float scale, const float max_bias, const float m0,
13022
  const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
13023
  const size_t n_local_scratch, dpct::queue_ptr stream) {
 
13027
  cgh.parallel_for(
13028
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
13029
  [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
13030
+ soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
13031
  nrows_y, scale, max_bias, m0,
13032
  m1, n_head_log2, item_ct1,
13033
  local_buf_acc.get_pointer());
 
13035
  });
13036
  }
13037
 
13038
+ static void soft_max_f32_sycl(const float * x, const float * mask,
13039
  float * dst, const int ncols_x, const int nrows_x,
13040
  const int nrows_y, const float scale, const float max_bias,
13041
  dpct::queue_ptr stream) {
 
13057
  const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
13058
  if (n_local_scratch*sizeof(float) < local_mem_size) {
13059
  if (ncols_x > max_block_size) {
13060
+ soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
13061
  max_bias, m0, m1, n_head_log2, block_nums,
13062
  block_dims, n_local_scratch, stream);
13063
  return;
13064
  }
13065
  switch (ncols_x) {
13066
  case 32:
13067
+ soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
13068
  max_bias, m0, m1, n_head_log2, block_nums,
13069
  block_dims, n_local_scratch, stream);
13070
  break;
13071
  case 64:
13072
+ soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
13073
  max_bias, m0, m1, n_head_log2, block_nums,
13074
  block_dims, n_local_scratch, stream);
13075
  break;
13076
  case 128:
13077
+ soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
13078
  max_bias, m0, m1, n_head_log2, block_nums,
13079
  block_dims, n_local_scratch, stream);
13080
  break;
13081
  case 256:
13082
+ soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
13083
  max_bias, m0, m1, n_head_log2, block_nums,
13084
  block_dims, n_local_scratch, stream);
13085
  break;
13086
  case 512:
13087
+ soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
13088
  max_bias, m0, m1, n_head_log2, block_nums,
13089
  block_dims, n_local_scratch, stream);
13090
  break;
13091
  case 1024:
13092
+ soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
13093
  max_bias, m0, m1, n_head_log2, block_nums,
13094
  block_dims, n_local_scratch, stream);
13095
  break;
13096
  case 2048:
13097
+ soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
13098
  max_bias, m0, m1, n_head_log2, block_nums,
13099
  block_dims, n_local_scratch, stream);
13100
  break;
13101
  case 4096:
13102
+ soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
13103
  max_bias, m0, m1, n_head_log2, block_nums,
13104
  block_dims, n_local_scratch, stream);
13105
  break;
13106
  default:
13107
+ soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
13108
  max_bias, m0, m1, n_head_log2, block_nums,
13109
  block_dims, n_local_scratch, stream);
13110
  break;
13111
  }
13112
  } else {
13113
+ soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
13114
  max_bias, m0, m1, n_head_log2, block_nums,
13115
  block_dims, WARP_SIZE, stream);
13116
  }
 
14521
  (void) src1_dd;
14522
  }
14523
 
14524
  static void ggml_sycl_op_pool2d(const ggml_tensor *src0,
14525
  const ggml_tensor *src1, ggml_tensor *dst,
14526
  const float *src0_dd, const float *src1_dd,
 
14675
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
14676
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
14677
 
14678
+ #pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
 
 
14679
  #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
14680
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
14681
 
14682
  const int64_t ne00 = src0->ne[0];
14683
  const int64_t nrows_x = ggml_nrows(src0);
 
14689
  memcpy(&scale, dst->op_params + 0, sizeof(float));
14690
  memcpy(&max_bias, dst->op_params + 1, sizeof(float));
14691
 
14692
+ soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
14693
  nrows_x, nrows_y, scale, max_bias, main_stream);
14694
  }
14695
 
 
16140
  ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_rope);
16141
  }
16142
 
16143
  static void ggml_sycl_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
16144
  ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_pool2d);
16145
  }
 
16516
  case GGML_OP_ROPE:
16517
  func = ggml_sycl_rope;
16518
  break;
 
 
 
16519
  case GGML_OP_IM2COL:
16520
  func = ggml_sycl_im2col;
16521
  break;
 
17645
  case GGML_OP_DIAG_MASK_INF:
17646
  case GGML_OP_SOFT_MAX:
17647
  case GGML_OP_ROPE:
 
17648
  case GGML_OP_IM2COL:
17649
  case GGML_OP_POOL_2D:
17650
  case GGML_OP_SUM_ROWS:
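Since ggml_soft_max_ext() no longer takes a pos tensor, the SYCL kernel above folds the ALiBi bias into the mask term. A plain scalar sketch of what one row now computes (an illustrative reference only, not an actual kernel) might be:

    #include <math.h>

    // Scalar sketch (not part of ggml) of the fused row update:
    // dst = softmax(x*scale + slope*mask), with slope == 1.0f when max_bias == 0.0f
    // and mask optional.
    static void soft_max_row_ref(const float * x, const float * mask, float * dst,
                                 int ncols, float scale, float slope) {
        float max_val = -INFINITY;
        for (int i = 0; i < ncols; ++i) {
            dst[i] = x[i]*scale + (mask ? slope*mask[i] : 0.0f);
            max_val = fmaxf(max_val, dst[i]);
        }

        float sum = 0.0f;
        for (int i = 0; i < ncols; ++i) {
            dst[i] = expf(dst[i] - max_val);
            sum += dst[i];
        }

        for (int i = 0; i < ncols; ++i) {
            dst[i] /= sum;
        }
    }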
ggml-vulkan.cpp CHANGED
@@ -3830,9 +3830,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
3830
  return nullptr;
3831
  case GGML_OP_SOFT_MAX:
3832
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
3833
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16);
3834
 
3835
- if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
3836
  return ctx->device->pipeline_soft_max_f32;
3837
  }
3838
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
@@ -4286,6 +4285,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx,
4286
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
4287
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
4288
 
 
 
 
4289
  ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
4290
  ncols,
4291
  src1 != nullptr ? nrows_y : (uint32_t)0,
 
3830
  return nullptr;
3831
  case GGML_OP_SOFT_MAX:
3832
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
 
3833
 
3834
+ if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
3835
  return ctx->device->pipeline_soft_max_f32;
3836
  }
3837
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
 
4285
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
4286
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
4287
 
4288
+ #pragma message("TODO: src2 is no longer used in soft_max - should be removed and ALiBi calculation should be updated")
4289
+ #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
4290
+
4291
  ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
4292
  ncols,
4293
  src1 != nullptr ? nrows_y : (uint32_t)0,
ggml.c CHANGED
@@ -2186,7 +2186,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2186
  "SOFT_MAX_BACK",
2187
  "ROPE",
2188
  "ROPE_BACK",
2189
- "ALIBI",
2190
  "CLAMP",
2191
  "CONV_TRANSPOSE_1D",
2192
  "IM2COL",
@@ -2228,7 +2227,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2228
  "CROSS_ENTROPY_LOSS_BACK",
2229
  };
2230
 
2231
- static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
2232
 
2233
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2234
  "none",
@@ -2277,7 +2276,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2277
  "soft_max_back(x)",
2278
  "rope(x)",
2279
  "rope_back(x)",
2280
- "alibi(x)",
2281
  "clamp(x)",
2282
  "conv_transpose_1d(x)",
2283
  "im2col(x)",
@@ -2319,7 +2317,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2319
  "cross_entropy_loss_back(x,y)",
2320
  };
2321
 
2322
- static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
2323
 
2324
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
2325
 
@@ -5662,7 +5660,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
5662
  struct ggml_context * ctx,
5663
  struct ggml_tensor * a,
5664
  struct ggml_tensor * mask,
5665
- struct ggml_tensor * pos,
5666
  float scale,
5667
  float max_bias,
5668
  bool inplace) {
@@ -5676,18 +5673,8 @@ static struct ggml_tensor * ggml_soft_max_impl(
5676
  GGML_ASSERT(mask->ne[1] >= a->ne[1]);
5677
  }
5678
 
5679
- if (pos) {
5680
- GGML_ASSERT(ggml_is_vector(pos));
5681
- GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
5682
- GGML_ASSERT(pos->ne[0] == a->ne[0]);
5683
- }
5684
-
5685
- if (pos && mask) {
5686
- GGML_ASSERT(pos->type == mask->type);
5687
- }
5688
-
5689
  if (max_bias > 0.0f) {
5690
- GGML_ASSERT(pos);
5691
  }
5692
 
5693
  bool is_node = false;
@@ -5705,7 +5692,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
5705
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5706
  result->src[0] = a;
5707
  result->src[1] = mask;
5708
- result->src[2] = pos;
5709
 
5710
  return result;
5711
  }
@@ -5713,23 +5699,22 @@ static struct ggml_tensor * ggml_soft_max_impl(
5713
  struct ggml_tensor * ggml_soft_max(
5714
  struct ggml_context * ctx,
5715
  struct ggml_tensor * a) {
5716
- return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false);
5717
  }
5718
 
5719
  struct ggml_tensor * ggml_soft_max_inplace(
5720
  struct ggml_context * ctx,
5721
  struct ggml_tensor * a) {
5722
- return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true);
5723
  }
5724
 
5725
  struct ggml_tensor * ggml_soft_max_ext(
5726
  struct ggml_context * ctx,
5727
  struct ggml_tensor * a,
5728
  struct ggml_tensor * mask,
5729
- struct ggml_tensor * pos,
5730
  float scale,
5731
  float max_bias) {
5732
- return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false);
5733
  }
5734
 
5735
  // ggml_soft_max_back
@@ -5944,37 +5929,6 @@ struct ggml_tensor * ggml_rope_back(
5944
  return result;
5945
  }
5946
 
5947
- // ggml_alibi
5948
-
5949
- struct ggml_tensor * ggml_alibi(
5950
- struct ggml_context * ctx,
5951
- struct ggml_tensor * a,
5952
- int n_past,
5953
- int n_head,
5954
- float bias_max) {
5955
- GGML_ASSERT(n_past >= 0);
5956
- bool is_node = false;
5957
-
5958
- if (a->grad) {
5959
- GGML_ASSERT(false); // TODO: implement backward
5960
- is_node = true;
5961
- }
5962
-
5963
- // TODO: when implement backward, fix this:
5964
- //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5965
- struct ggml_tensor * result = ggml_view_tensor(ctx, a);
5966
-
5967
- int32_t op_params[3] = { n_past, n_head };
5968
- memcpy(op_params + 2, &bias_max, sizeof(float));
5969
- ggml_set_op_params(result, op_params, sizeof(op_params));
5970
-
5971
- result->op = GGML_OP_ALIBI;
5972
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5973
- result->src[0] = a;
5974
-
5975
- return result;
5976
- }
5977
-
5978
  // ggml_clamp
5979
 
5980
  struct ggml_tensor * ggml_clamp(
@@ -6502,9 +6456,11 @@ struct ggml_tensor * ggml_flash_attn_ext(
6502
  struct ggml_tensor * k,
6503
  struct ggml_tensor * v,
6504
  struct ggml_tensor * mask,
6505
- float scale) {
 
6506
  GGML_ASSERT(ggml_can_mul_mat(k, q));
6507
  // TODO: check if vT can be multiplied by (k*qT)
 
6508
  if (mask) {
6509
  GGML_ASSERT(ggml_is_contiguous(mask));
6510
  GGML_ASSERT(mask->ne[2] == 1);
@@ -6514,6 +6470,10 @@ struct ggml_tensor * ggml_flash_attn_ext(
6514
  //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
6515
  }
6516
 
6517
  bool is_node = false;
6518
 
6519
  if (q->grad || k->grad || v->grad) {
@@ -6524,7 +6484,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
6524
  int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
6525
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6526
 
6527
- float params[] = { scale };
6528
  ggml_set_op_params(result, params, sizeof(params));
6529
 
6530
  result->op = GGML_OP_FLASH_ATTN_EXT;
@@ -6544,7 +6504,7 @@ void ggml_flash_attn_ext_set_prec(
6544
 
6545
  const int32_t prec_i32 = (int32_t) prec;
6546
 
6547
- ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
6548
  }
6549
 
6550
  // ggml_flash_ff
@@ -13395,7 +13355,6 @@ static void ggml_compute_forward_soft_max_f32(
13395
 
13396
  const struct ggml_tensor * src0 = dst->src[0];
13397
  const struct ggml_tensor * src1 = dst->src[1];
13398
- const struct ggml_tensor * src2 = dst->src[2];
13399
 
13400
  assert(ggml_is_contiguous(dst));
13401
  assert(ggml_are_same_shape(src0, dst));
@@ -13421,8 +13380,8 @@ static void ggml_compute_forward_soft_max_f32(
13421
 
13422
  // TODO: is this supposed to be ceil instead of floor?
13423
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
13424
- const uint32_t n_head_kv = ne02;
13425
- const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
13426
 
13427
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
13428
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -13439,13 +13398,13 @@ static void ggml_compute_forward_soft_max_f32(
13439
 
13440
  float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
13441
 
13442
- // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
13443
- ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
13444
- float * pos_f32 = src2 ? (float *) src2->data : src0->data;
13445
-
13446
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
13447
 
13448
  for (int i1 = ir0; i1 < ir1; i1++) {
13449
  float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
13450
  float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
13451
 
@@ -13458,27 +13417,11 @@ static void ggml_compute_forward_soft_max_f32(
13458
  if (mp_f32) {
13459
  if (use_f16) {
13460
  for (int i = 0; i < nc; ++i) {
13461
- wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
13462
  }
13463
  } else {
13464
  for (int i = 0; i < nc; ++i) {
13465
- wp[i] += mp_f32[i];
13466
- }
13467
- }
13468
- }
13469
-
13470
- // ALiBi bias
13471
- if (max_bias > 0.0f) {
13472
- const uint32_t h = (i1/ne01)%ne02; // head
13473
- const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
13474
-
13475
- if (use_f16) {
13476
- for (int i = 0; i < nc; ++i) {
13477
- wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
13478
- }
13479
- } else {
13480
- for (int i = 0; i < nc; ++i) {
13481
- wp[i] += slope*pos_f32[i];
13482
  }
13483
  }
13484
  }
@@ -13640,178 +13583,6 @@ static void ggml_compute_forward_soft_max_back(
13640
  }
13641
  }
13642
 
13643
- // ggml_compute_forward_alibi
13644
-
13645
- static void ggml_compute_forward_alibi_f32(
13646
- const struct ggml_compute_params * params,
13647
- struct ggml_tensor * dst) {
13648
-
13649
- const struct ggml_tensor * src0 = dst->src[0];
13650
-
13651
- assert(params->ith == 0);
13652
-
13653
- if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
13654
- return;
13655
- }
13656
-
13657
- //const int n_past = ((int32_t *) dst->op_params)[0];
13658
- const int n_head = ((int32_t *) dst->op_params)[1];
13659
- float max_bias;
13660
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
13661
-
13662
- const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
13663
- const int64_t ne1 = src0->ne[1]; // seq_len_without_past
13664
- const int64_t ne2 = src0->ne[2]; // n_head -> this is k
13665
- //const int64_t ne3 = src0->ne[3]; // 1 -> bsz
13666
-
13667
- const int64_t n = ggml_nrows(src0);
13668
- const int64_t ne2_ne3 = n/ne1; // ne2*ne3
13669
-
13670
- const size_t nb0 = src0->nb[0];
13671
- const size_t nb1 = src0->nb[1];
13672
- const size_t nb2 = src0->nb[2];
13673
- //const int nb3 = src0->nb[3];
13674
-
13675
- GGML_ASSERT(nb0 == sizeof(float));
13676
- GGML_ASSERT(n_head == ne2);
13677
-
13678
- // add alibi to src0 (KQ_scaled)
13679
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
13680
-
13681
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
13682
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
13683
-
13684
- for (int64_t k = 0; k < ne2_ne3; k++) {
13685
- // TODO: k*nb2 or k*nb3
13686
- float m_k;
13687
-
13688
- if (k < n_heads_log2_floor) {
13689
- m_k = powf(m0, k + 1);
13690
- } else {
13691
- m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
13692
- }
13693
-
13694
- for (int64_t i = 0; i < ne0; i++) {
13695
- for (int64_t j = 0; j < ne1; j++) {
13696
- float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
13697
- float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
13698
- pdst[0] = i * m_k + src[0];
13699
- }
13700
- }
13701
- }
13702
- }
13703
-
13704
- static void ggml_compute_forward_alibi_f16(
13705
- const struct ggml_compute_params * params,
13706
- struct ggml_tensor * dst) {
13707
-
13708
- const struct ggml_tensor * src0 = dst->src[0];
13709
-
13710
- assert(params->ith == 0);
13711
-
13712
- if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
13713
- return;
13714
- }
13715
-
13716
- //const int n_past = ((int32_t *) dst->op_params)[0];
13717
- const int n_head = ((int32_t *) dst->op_params)[1];
13718
- float max_bias;
13719
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
13720
-
13721
- const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
13722
- const int ne1 = src0->ne[1]; // seq_len_without_past
13723
- const int ne2 = src0->ne[2]; // n_head -> this is k
13724
- //const int ne3 = src0->ne[3]; // 1 -> bsz
13725
-
13726
- const int n = ggml_nrows(src0);
13727
- const int ne2_ne3 = n/ne1; // ne2*ne3
13728
-
13729
- const int nb0 = src0->nb[0];
13730
- const int nb1 = src0->nb[1];
13731
- const int nb2 = src0->nb[2];
13732
- //const int nb3 = src0->nb[3];
13733
-
13734
- GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
13735
- //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
13736
- GGML_ASSERT(n_head == ne2);
13737
-
13738
- // add alibi to src0 (KQ_scaled)
13739
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
13740
-
13741
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
13742
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
13743
-
13744
- for (int k = 0; k < ne2_ne3; k++) {
13745
- // TODO: k*nb2 or k*nb3
13746
- float m_k;
13747
-
13748
- if (k < n_heads_log2_floor) {
13749
- m_k = powf(m0, k + 1);
13750
- } else {
13751
- m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
13752
- }
13753
-
13754
- for (int i = 0; i < ne0; i++) {
13755
- for (int j = 0; j < ne1; j++) {
13756
- ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
13757
- float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
13758
-
13759
- // we return F32
13760
- pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
13761
- }
13762
- }
13763
- }
13764
- }
13765
-
13766
- static void ggml_compute_forward_alibi(
13767
- const struct ggml_compute_params * params,
13768
- struct ggml_tensor * dst) {
13769
-
13770
- const struct ggml_tensor * src0 = dst->src[0];
13771
-
13772
- switch (src0->type) {
13773
- case GGML_TYPE_F16:
13774
- {
13775
- ggml_compute_forward_alibi_f16(params, dst);
13776
- } break;
13777
- case GGML_TYPE_F32:
13778
- {
13779
- ggml_compute_forward_alibi_f32(params, dst);
13780
- } break;
13781
- case GGML_TYPE_BF16:
13782
- case GGML_TYPE_Q4_0:
13783
- case GGML_TYPE_Q4_1:
13784
- case GGML_TYPE_Q5_0:
13785
- case GGML_TYPE_Q5_1:
13786
- case GGML_TYPE_Q8_0:
13787
- case GGML_TYPE_Q8_1:
13788
- case GGML_TYPE_Q2_K:
13789
- case GGML_TYPE_Q3_K:
13790
- case GGML_TYPE_Q4_K:
13791
- case GGML_TYPE_Q5_K:
13792
- case GGML_TYPE_Q6_K:
13793
- case GGML_TYPE_IQ2_XXS:
13794
- case GGML_TYPE_IQ2_XS:
13795
- case GGML_TYPE_IQ3_XXS:
13796
- case GGML_TYPE_IQ1_S:
13797
- case GGML_TYPE_IQ1_M:
13798
- case GGML_TYPE_IQ4_NL:
13799
- case GGML_TYPE_IQ4_XS:
13800
- case GGML_TYPE_IQ3_S:
13801
- case GGML_TYPE_IQ2_S:
13802
- case GGML_TYPE_Q8_K:
13803
- case GGML_TYPE_I8:
13804
- case GGML_TYPE_I16:
13805
- case GGML_TYPE_I32:
13806
- case GGML_TYPE_I64:
13807
- case GGML_TYPE_F64:
13808
- case GGML_TYPE_COUNT:
13809
- {
13810
- GGML_ASSERT(false);
13811
- } break;
13812
- }
13813
- }
13814
-
13815
  // ggml_compute_forward_clamp
13816
 
13817
  static void ggml_compute_forward_clamp_f32(
@@ -15825,8 +15596,17 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15825
  const int ir0 = dr*ith;
15826
  const int ir1 = MIN(ir0 + dr, nr);
15827
 
15828
- float scale = 1.0f;
15829
- memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
15830
 
15831
  // loop over n_batch and n_head
15832
  for (int ir = ir0; ir < ir1; ++ir) {
@@ -15835,6 +15615,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15835
  const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15836
  const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15837
 
 
 
 
15838
  float S = 0.0f;
15839
  float M = -INFINITY;
15840
 
@@ -15858,7 +15641,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15858
  // loop over n_kv and n_head_kv
15859
  // ref: https://arxiv.org/pdf/2112.05682.pdf
15860
  for (int64_t ic = 0; ic < nek1; ++ic) {
15861
- const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
15862
  if (mv == -INFINITY) {
15863
  continue;
15864
  }
@@ -15929,7 +15712,7 @@ static void ggml_compute_forward_flash_attn_ext(
15929
  const struct ggml_tensor * v,
15930
  const struct ggml_tensor * mask,
15931
  struct ggml_tensor * dst) {
15932
- switch (dst->op_params[1]) {
15933
  case GGML_PREC_DEFAULT:
15934
  case GGML_PREC_F32:
15935
  {
@@ -17696,10 +17479,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
17696
  {
17697
  ggml_compute_forward_rope_back(params, tensor);
17698
  } break;
17699
- case GGML_OP_ALIBI:
17700
- {
17701
- ggml_compute_forward_alibi(params, tensor);
17702
- } break;
17703
  case GGML_OP_CLAMP:
17704
  {
17705
  ggml_compute_forward_clamp(params, tensor);
@@ -18718,10 +18497,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18718
  zero_table);
18719
  }
18720
  } break;
18721
- case GGML_OP_ALIBI:
18722
- {
18723
- GGML_ASSERT(false); // TODO: not implemented
18724
- } break;
18725
  case GGML_OP_CLAMP:
18726
  {
18727
  GGML_ASSERT(false); // TODO: not implemented
@@ -19499,10 +19274,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
19499
  {
19500
  n_tasks = n_threads;
19501
  } break;
19502
- case GGML_OP_ALIBI:
19503
- {
19504
- n_tasks = 1; //TODO
19505
- } break;
19506
  case GGML_OP_CLAMP:
19507
  {
19508
  n_tasks = 1; //TODO
 
2186
  "SOFT_MAX_BACK",
2187
  "ROPE",
2188
  "ROPE_BACK",
 
2189
  "CLAMP",
2190
  "CONV_TRANSPOSE_1D",
2191
  "IM2COL",
 
2227
  "CROSS_ENTROPY_LOSS_BACK",
2228
  };
2229
 
2230
+ static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2231
 
2232
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2233
  "none",
 
2276
  "soft_max_back(x)",
2277
  "rope(x)",
2278
  "rope_back(x)",
 
2279
  "clamp(x)",
2280
  "conv_transpose_1d(x)",
2281
  "im2col(x)",
 
2317
  "cross_entropy_loss_back(x,y)",
2318
  };
2319
 
2320
+ static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2321
 
2322
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
2323
 
 
5660
  struct ggml_context * ctx,
5661
  struct ggml_tensor * a,
5662
  struct ggml_tensor * mask,
 
5663
  float scale,
5664
  float max_bias,
5665
  bool inplace) {
 
5673
  GGML_ASSERT(mask->ne[1] >= a->ne[1]);
5674
  }
5675
 
5676
  if (max_bias > 0.0f) {
5677
+ GGML_ASSERT(mask);
5678
  }
5679
 
5680
  bool is_node = false;
 
5692
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5693
  result->src[0] = a;
5694
  result->src[1] = mask;
 
5695
 
5696
  return result;
5697
  }
 
5699
  struct ggml_tensor * ggml_soft_max(
5700
  struct ggml_context * ctx,
5701
  struct ggml_tensor * a) {
5702
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false);
5703
  }
5704
 
5705
  struct ggml_tensor * ggml_soft_max_inplace(
5706
  struct ggml_context * ctx,
5707
  struct ggml_tensor * a) {
5708
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true);
5709
  }
5710
 
5711
  struct ggml_tensor * ggml_soft_max_ext(
5712
  struct ggml_context * ctx,
5713
  struct ggml_tensor * a,
5714
  struct ggml_tensor * mask,
 
5715
  float scale,
5716
  float max_bias) {
5717
+ return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
5718
  }
5719
 
5720
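A hedged usage sketch of the simplified API (the tensor names, n_embd_head and max_bias are illustrative placeholders, not taken from this commit):

    // Hypothetical call site: the ALiBi bias is derived from max_bias and the mask
    // inside the op, so no separate pos tensor is passed anymore.
    struct ggml_tensor * kq_soft = ggml_soft_max_ext(ctx, kq, kq_mask,
            1.0f/sqrtf((float) n_embd_head),   // scale applied to kq
            max_bias);                         // 0.0f disables ALiBi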
  // ggml_soft_max_back
 
5929
  return result;
5930
  }
5931
 
5932
  // ggml_clamp
5933
 
5934
  struct ggml_tensor * ggml_clamp(
 
6456
  struct ggml_tensor * k,
6457
  struct ggml_tensor * v,
6458
  struct ggml_tensor * mask,
6459
+ float scale,
6460
+ float max_bias) {
6461
  GGML_ASSERT(ggml_can_mul_mat(k, q));
6462
  // TODO: check if vT can be multiplied by (k*qT)
6463
+
6464
  if (mask) {
6465
  GGML_ASSERT(ggml_is_contiguous(mask));
6466
  GGML_ASSERT(mask->ne[2] == 1);
 
6470
  //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
6471
  }
6472
 
6473
+ if (max_bias > 0.0f) {
6474
+ GGML_ASSERT(mask);
6475
+ }
6476
+
6477
  bool is_node = false;
6478
 
6479
  if (q->grad || k->grad || v->grad) {
 
6484
  int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
6485
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6486
 
6487
+ float params[] = { scale, max_bias };
6488
  ggml_set_op_params(result, params, sizeof(params));
6489
 
6490
  result->op = GGML_OP_FLASH_ATTN_EXT;
 
6504
 
6505
  const int32_t prec_i32 = (int32_t) prec;
6506
 
6507
+ ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
6508
  }
6509
 
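For backends that read the parameters back, the op_params layout implied by the code above is scale, then max_bias, then precision. A small read-back fragment mirroring the CPU path (assuming the usual dst tensor inside a compute function) would be:

    // Sketch of the GGML_OP_FLASH_ATTN_EXT op_params layout after this change:
    //   op_params[0] = scale (float), op_params[1] = max_bias (float), op_params[2] = precision (int32)
    float scale    = 1.0f;
    float max_bias = 0.0f;
    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
    const int32_t prec = dst->op_params[2];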
6510
  // ggml_flash_ff
 
13355
 
13356
  const struct ggml_tensor * src0 = dst->src[0];
13357
  const struct ggml_tensor * src1 = dst->src[1];
 
13358
 
13359
  assert(ggml_is_contiguous(dst));
13360
  assert(ggml_are_same_shape(src0, dst));
 
13380
 
13381
  // TODO: is this supposed to be ceil instead of floor?
13382
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
13383
+ const uint32_t n_head = ne02;
13384
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
13385
 
13386
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
13387
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
13398
 
13399
  float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
13400
 
13401
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
13402
 
13403
  for (int i1 = ir0; i1 < ir1; i1++) {
13404
+ // ALiBi
13405
+ const uint32_t h = (i1/ne01)%ne02; // head
13406
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
13407
+
13408
  float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
13409
  float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
13410
 
 
13417
  if (mp_f32) {
13418
  if (use_f16) {
13419
  for (int i = 0; i < nc; ++i) {
13420
+ wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
13421
  }
13422
  } else {
13423
  for (int i = 0; i < nc; ++i) {
13424
+ wp[i] += slope*mp_f32[i];
13425
  }
13426
  }
13427
  }
 
13583
  }
13584
  }
13585
 
13586
  // ggml_compute_forward_clamp
13587
 
13588
  static void ggml_compute_forward_clamp_f32(
 
15596
  const int ir0 = dr*ith;
15597
  const int ir1 = MIN(ir0 + dr, nr);
15598
 
15599
+ float scale = 1.0f;
15600
+ float max_bias = 0.0f;
15601
+
15602
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
15603
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
15604
+
15605
+ const uint32_t n_head = neq2;
15606
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
15607
+
15608
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
15609
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
15610
 
15611
  // loop over n_batch and n_head
15612
  for (int ir = ir0; ir < ir1; ++ir) {
 
15615
  const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15616
  const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15617
 
15618
+ const uint32_t h = iq2; // head
15619
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
15620
+
15621
  float S = 0.0f;
15622
  float M = -INFINITY;
15623
 
 
15641
  // loop over n_kv and n_head_kv
15642
  // ref: https://arxiv.org/pdf/2112.05682.pdf
15643
  for (int64_t ic = 0; ic < nek1; ++ic) {
15644
+ const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
15645
  if (mv == -INFINITY) {
15646
  continue;
15647
  }
 
15712
  const struct ggml_tensor * v,
15713
  const struct ggml_tensor * mask,
15714
  struct ggml_tensor * dst) {
15715
+ switch (dst->op_params[2]) {
15716
  case GGML_PREC_DEFAULT:
15717
  case GGML_PREC_F32:
15718
  {
 
17479
  {
17480
  ggml_compute_forward_rope_back(params, tensor);
17481
  } break;
17482
  case GGML_OP_CLAMP:
17483
  {
17484
  ggml_compute_forward_clamp(params, tensor);
 
18497
  zero_table);
18498
  }
18499
  } break;
18500
  case GGML_OP_CLAMP:
18501
  {
18502
  GGML_ASSERT(false); // TODO: not implemented
 
19274
  {
19275
  n_tasks = n_threads;
19276
  } break;
19277
  case GGML_OP_CLAMP:
19278
  {
19279
  n_tasks = 1; //TODO
ggml.h CHANGED
@@ -468,7 +468,6 @@ extern "C" {
468
  GGML_OP_SOFT_MAX_BACK,
469
  GGML_OP_ROPE,
470
  GGML_OP_ROPE_BACK,
471
- GGML_OP_ALIBI,
472
  GGML_OP_CLAMP,
473
  GGML_OP_CONV_TRANSPOSE_1D,
474
  GGML_OP_IM2COL,
@@ -1437,15 +1436,13 @@ extern "C" {
1437
  struct ggml_context * ctx,
1438
  struct ggml_tensor * a);
1439
 
1440
- // fused soft_max(a*scale + mask + pos[i]*(ALiBi slope))
1441
  // mask is optional
1442
- // pos is required when max_bias > 0.0f
1443
  // max_bias = 0.0f for no ALiBi
1444
  GGML_API struct ggml_tensor * ggml_soft_max_ext(
1445
  struct ggml_context * ctx,
1446
  struct ggml_tensor * a,
1447
  struct ggml_tensor * mask,
1448
- struct ggml_tensor * pos,
1449
  float scale,
1450
  float max_bias);
1451
 
@@ -1547,16 +1544,6 @@ extern "C" {
1547
  float xpos_base,
1548
  bool xpos_down);
1549
 
1550
- // alibi position embedding
1551
- // in-place, returns view(a)
1552
- GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_alibi(
1553
- struct ggml_context * ctx,
1554
- struct ggml_tensor * a,
1555
- int n_past,
1556
- int n_head,
1557
- float bias_max),
1558
- "use ggml_soft_max_ext instead (will be removed in Mar 2024)");
1559
-
1560
  // clamp
1561
  // in-place, returns view(a)
1562
  GGML_API struct ggml_tensor * ggml_clamp(
@@ -1753,7 +1740,8 @@ extern "C" {
1753
  struct ggml_tensor * k,
1754
  struct ggml_tensor * v,
1755
  struct ggml_tensor * mask,
1756
- float scale);
 
1757
 
1758
  GGML_API void ggml_flash_attn_ext_set_prec(
1759
  struct ggml_tensor * a,
 
468
  GGML_OP_SOFT_MAX_BACK,
469
  GGML_OP_ROPE,
470
  GGML_OP_ROPE_BACK,
 
471
  GGML_OP_CLAMP,
472
  GGML_OP_CONV_TRANSPOSE_1D,
473
  GGML_OP_IM2COL,
 
1436
  struct ggml_context * ctx,
1437
  struct ggml_tensor * a);
1438
 
1439
+ // fused soft_max(a*scale + mask*(ALiBi slope))
1440
  // mask is optional
 
1441
  // max_bias = 0.0f for no ALiBi
1442
  GGML_API struct ggml_tensor * ggml_soft_max_ext(
1443
  struct ggml_context * ctx,
1444
  struct ggml_tensor * a,
1445
  struct ggml_tensor * mask,
 
1446
  float scale,
1447
  float max_bias);
1448
 
 
1544
  float xpos_base,
1545
  bool xpos_down);
1546
 
1547
  // clamp
1548
  // in-place, returns view(a)
1549
  GGML_API struct ggml_tensor * ggml_clamp(
 
1740
  struct ggml_tensor * k,
1741
  struct ggml_tensor * v,
1742
  struct ggml_tensor * mask,
1743
+ float scale,
1744
+ float max_bias);
1745
 
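Finally, a hedged example of calling the extended flash-attention API (q, k, v, kq_mask, n_embd_head and max_bias are placeholder names for the caller's tensors and hyperparameters):

    // Hypothetical call: ALiBi is enabled by passing max_bias > 0.0f together with a
    // mask, mirroring ggml_soft_max_ext().
    struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask,
            1.0f/sqrtf((float) n_embd_head),   // scale
            max_bias);                         // 0.0f for no ALiBi
    ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); // optionally request F32 precision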
1746
  GGML_API void ggml_flash_attn_ext_set_prec(
1747
  struct ggml_tensor * a,