ggml : full ALiBi support (llama/7192)
* ggml : full ALiBi support
* ggml : update ggml_soft_max_ext() CUDA, SYCL
* ggml : ggml_flash_attn_ext() support ALiBi (CPU)
* ggml : ggml_flash_attn_ext() support ALiBi (Metal)
* ggml : fix warning
* ggml : ggml_flash_attn_ext() support ALiBi (CUDA)
ggml-ci
* ggml : fix assert message
* vulkan : add dev notes
* ggml : require mask when using ALiBi
ggml-ci
* convert : fix convert for refact models
- ggml-cuda.cu +0 -5
- ggml-cuda/fattn.cu +62 -10
- ggml-cuda/softmax.cu +21 -34
- ggml-kompute.cpp +9 -3
- ggml-metal.m +54 -94
- ggml-metal.metal +49 -71
- ggml-sycl.cpp +19 -119
- ggml-vulkan.cpp +4 -2
- ggml.c +40 -269
- ggml.h +3 -15
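
Every backend hunk below re-derives the same per-head ALiBi slope, so it is worth seeing once in isolation. The sketch below is ours (a helper named alibi_slope does not exist in ggml); it collects the arithmetic that the CUDA, Metal and SYCL changes inline: heads below the largest power of two not exceeding n_head follow the m0 sequence, the remaining heads interleave on m1.

#include <math.h>
#include <stdint.h>

// Per-head ALiBi slope as computed inline by the kernels in the diffs below.
// n_head_log2 is the largest power of two not exceeding n_head.
static float alibi_slope(uint32_t h, uint32_t n_head, float max_bias) {
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

    return powf(base, exph);
}

With max_bias == 0 both bases are 1.0f and the slope degenerates to 1, which is why the kernels can guard the whole block behind if (max_bias > 0.0f). The bias added to an attention score is slope times the mask value; per the "require mask when using ALiBi" change, the mask tensor now carries the position terms that the removed GGML_OP_ALIBI operator used to add.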
ggml-cuda.cu
CHANGED

@@ -4,7 +4,6 @@
 
 #include "ggml-cuda/common.cuh"
 #include "ggml-cuda/acc.cuh"
-#include "ggml-cuda/alibi.cuh"
 #include "ggml-cuda/arange.cuh"
 #include "ggml-cuda/argsort.cuh"
 #include "ggml-cuda/binbcast.cuh"

@@ -2280,9 +2279,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ROPE:
             ggml_cuda_op_rope(ctx, dst);
             break;
-        case GGML_OP_ALIBI:
-            ggml_cuda_op_alibi(ctx, dst);
-            break;
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;

@@ -2833,7 +2829,6 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_ROPE:
-        case GGML_OP_ALIBI:
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
ggml-cuda/fattn.cu
CHANGED

@@ -23,6 +23,10 @@ static __global__ void flash_attn_vec_ext_f16(
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
         const int ne02,

@@ -58,6 +62,18 @@ static __global__ void flash_attn_vec_ext_f16(
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);
 
+    half slopeh = __float2half(1.0f);
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = blockIdx.y;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slopeh = __float2half(powf(base, exph));
+    }
+
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;

@@ -141,7 +157,7 @@ static __global__ void flash_attn_vec_ext_f16(
     for (int j = 0; j < ncols; ++j) {
         sum2[j] = warp_reduce_sum(sum2[j]);
         half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
-        sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+        sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
 
         if (ncols == 1) {
             kqmax_new = ggml_cuda_hmax(kqmax_new, sum);

@@ -249,6 +265,10 @@ static __global__ void flash_attn_ext_f16(
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
         const int ne02,

@@ -305,6 +325,20 @@ static __global__ void flash_attn_ext_f16(
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
 
+    half  slopeh = __float2half(1.0f);
+    half2 slope2 = make_half2(1.0f, 1.0f);
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = blockIdx.y;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slopeh = __float2half(powf(base, exph));
+        slope2 = make_half2(slopeh, slopeh);
+    }
+
     frag_b Q_b[D/16][ncols/frag_n];
 
     // A single buffer for temporarily holding tiles of KQ and VKQ parts:

@@ -421,7 +455,7 @@ static __global__ void flash_attn_ext_f16(
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
                 const int k = k0 + threadIdx.x;
 
-                KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
                 KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
             }
             KQ_max_new = warp_reduce_max(KQ_max_new);

@@ -464,7 +498,7 @@ static __global__ void flash_attn_ext_f16(
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
                 const int k = k0 + threadIdx.x;
 
-                KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
                 KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
             }
             KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));

@@ -710,8 +744,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
     const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
     const int  shmem = 0;
 
-    float scale;
-    memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
         <<<blocks_num, block_dim, shmem, main_stream>>> (

@@ -720,7 +763,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
             (const char *) V->data,
             mask ? ((const char *) mask->data) : nullptr,
             parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-            scale,
+            scale, max_bias, m0, m1, n_head_log2,
             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,

@@ -761,8 +804,17 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
     const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
     const int  shmem = 0;
 
-    float scale;
-    memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
         <<<blocks_num, block_dim, shmem, main_stream>>> (

@@ -771,7 +823,7 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
             (const char *) V->data,
             mask ? ((const char *) mask->data) : nullptr,
             (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-            scale,
+            scale, max_bias, m0, m1, n_head_log2,
             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,

@@ -837,7 +889,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const int cc  = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
-    const int32_t precision = KQV->op_params[1];
+    const int32_t precision = KQV->op_params[2];
 
     if (!fp16_mma_available(cc)) {
        GGML_ASSERT(precision == GGML_PREC_DEFAULT);
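
The last hunk above is a knock-on effect of the new parameter packing: with max_bias stored as the float bits of op_params[1], the precision flag moves from op_params[1] to op_params[2]. A minimal sketch of the host-side unpacking, using a hypothetical stand-in struct rather than the real ggml_tensor:

#include <stdint.h>
#include <string.h>

// Hypothetical stand-in for the relevant slice of a ggml tensor's op_params.
struct fattn_op_params { int32_t op_params[3]; };

// Mirror of the unpacking the launch_fattn_*_f16 helpers above perform:
// float bits are memcpy'd out of the int32 op_params slots.
static void unpack_fattn(const struct fattn_op_params * t,
                         float * scale, float * max_bias, int32_t * precision) {
    memcpy(scale,    (const float *) t->op_params + 0, sizeof(float));
    memcpy(max_bias, (const float *) t->op_params + 1, sizeof(float));
    *precision = t->op_params[2];
}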
ggml-cuda/softmax.cu
CHANGED

@@ -11,7 +11,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 }
 
 template <bool vals_smem, int ncols_template, int block_size_template, typename T>
-static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
     const int tid = threadIdx.x;

@@ -23,16 +23,16 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    float slope = 0.0f;
+    float slope = 1.0f;
 
     // ALiBi
     if (max_bias > 0.0f) {
         const int h = rowx/nrows_y; // head index
 
         const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 
-        slope = powf(base, exp);
+        slope = powf(base, exph);
     }
 
     extern __shared__ float data_soft_max_f32[];

@@ -53,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
         const int64_t ix = (int64_t)rowx*ncols + col;
         const int64_t iy = (int64_t)rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);
+        const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);

@@ -125,7 +125,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
 }
 
 template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);

@@ -133,8 +133,8 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl
     const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
-    const uint32_t n_head_kv   = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+    const uint32_t n_head      = nrows_x/nrows_y;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

@@ -142,43 +142,42 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl
     if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
         }
     } else {
         const size_t shmem_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
     }
 }
 
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
-    const ggml_tensor * src2 = dst->src[2];
 
     const float * src0_d = (const float *)src0->data;
     const void  * src1_d = src1 ? (const void *)src1->data : nullptr;

@@ -190,7 +189,6 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
-    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
 
     const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);

@@ -202,26 +200,15 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
-
-    void * src2_d = nullptr;
-
-    const bool use_src2 = src2 != nullptr;
-
-    if (use_src2) {
-        src2_d = (void *)src2->data;
-    }
-
-    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
     if (use_f16) {
         const half * src1_dd = (const half *)src1_d;
-        const half * src2_dd = (const half *)src2_d;
 
-        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
     } else {
         const float * src1_dd = (const float *)src1_d;
-        const float * src2_dd = (const float *)src2_d;
 
-        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
     }
 }
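
With the positions tensor gone, each soft-max element reduces to one fused expression, val = x*scale + slope*mask, followed by a standard max-shifted exponentiation. A plain-C reference of the per-row semantics (names are ours; the real kernels parallelize this across a thread block and also support F16 masks):

#include <math.h>
#include <stddef.h>

// Reference semantics of the fused soft-max row above:
// dst = softmax(x*scale + slope*mask). mask may be NULL, matching the
// optional src1 tensor, and slope is the per-head ALiBi slope (1.0f if unused).
static void soft_max_row_ref(const float * x, const float * mask, float * dst,
                             int ncols, float scale, float slope) {
    float max_val = -INFINITY;
    for (int i = 0; i < ncols; ++i) {
        const float val = x[i]*scale + (mask ? slope*mask[i] : 0.0f);
        dst[i] = val;
        if (val > max_val) max_val = val;
    }

    float sum = 0.0f;
    for (int i = 0; i < ncols; ++i) {
        dst[i] = expf(dst[i] - max_val); // shift by max for numerical stability
        sum += dst[i];
    }
    for (int i = 0; i < ncols; ++i) {
        dst[i] /= sum;
    }
}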
ggml-kompute.cpp
CHANGED

@@ -1559,12 +1559,18 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
             case GGML_OP_SOFT_MAX:
                 {
                     float scale;
-                    memcpy(&scale, dst->op_params, sizeof(float));
+                    float max_bias;
 
-#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
+                    memcpy(&scale,    (float *)dst->op_params + 0, sizeof(float));
+                    memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
+
+#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
 #pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
                     GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
-
+
+#pragma message("TODO: add ALiBi support")
+#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/7192")
+                    GGML_ASSERT(max_bias == 0.0f);
 
                     ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
                 } break;
ggml-metal.m
CHANGED

@@ -170,7 +170,6 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
-    GGML_METAL_KERNEL_TYPE_ALIBI_F32,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,

@@ -625,7 +624,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);

@@ -762,7 +760,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_GROUP_NORM:
             return ctx->support_simdgroup_reduction;
         case GGML_OP_NORM:
-        case GGML_OP_ALIBI:
         case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
             return true;

@@ -1373,13 +1370,12 @@ static enum ggml_status ggml_metal_graph_compute(
             case GGML_OP_SOFT_MAX:
                 {
                     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
-                    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
 
                     int nth = 32; // SIMD width
 
                     id<MTLComputePipelineState> pipeline = nil;
 
-                    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+                    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
                     if (ne00%4 == 0) {
                         while (nth < ne00/4 && nth < 256) {

@@ -1410,8 +1406,8 @@ static enum ggml_status ggml_metal_graph_compute(
                     const int64_t nrows_x = ggml_nrows(src0);
                     const int64_t nrows_y = src0->ne[1];
 
-                    const uint32_t n_head_kv   = nrows_x/nrows_y;
-                    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+                    const uint32_t n_head      = nrows_x/nrows_y;
+                    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
                     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
                     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

@@ -1423,20 +1419,15 @@ static enum ggml_status ggml_metal_graph_compute(
                     } else {
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                     }
-                    if (id_src2) {
-                        [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
-                    } else {
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
-                    }
-                    [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
-                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
-                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
-                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
-                    [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
-                    [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
-                    [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
-                    [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
-                    [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                    [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+                    [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
+                    [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
+                    [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
+                    [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
                     [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];

@@ -2241,49 +2232,6 @@ static enum ggml_status ggml_metal_graph_compute(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
-            case GGML_OP_ALIBI:
-                {
-                    GGML_ASSERT((src0t == GGML_TYPE_F32));
-
-                    const int nth = MIN(1024, ne00);
-
-                    //const int n_past = ((int32_t *) dst->op_params)[0];
-                    const int n_head = ((int32_t *) dst->op_params)[1];
-
-                    float max_bias;
-                    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
-                    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-                    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
-                    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
-                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
-
-                    [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                    [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
-                    [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
-                    [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
-                    [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
-                    [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
-                    [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
-                    [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
-                    [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
-                    [encoder setBytes:&m0   length:sizeof(   float) atIndex:18];
-                    [encoder setBytes:&m1   length:sizeof(   float) atIndex:19];
-                    [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                } break;
             case GGML_OP_ROPE:
                 {
                     GGML_ASSERT(ne10 == ne02);

@@ -2581,7 +2529,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
 
                     const int64_t  ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
-                    const int64_t  ne31 = src3 ? src3->ne[1] : 0;
+                    //const int64_t ne31 = src3 ? src3->ne[1] : 0;
                     const int64_t  ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
                     const int64_t  ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);

@@ -2593,7 +2541,16 @@ static enum ggml_status ggml_metal_graph_compute(
                     const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
 
                     float scale;
-                    memcpy(&scale, dst->op_params, sizeof(float));
+                    float max_bias;
+
+                    memcpy(&scale,    ((int32_t *) dst->op_params) + 0, sizeof(scale));
+                    memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+
+                    const uint32_t n_head      = src0->ne[2];
+                    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+                    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+                    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
                     id<MTLComputePipelineState> pipeline = nil;

@@ -2630,34 +2587,37 @@ static enum ggml_status ggml_metal_graph_compute(
                     }
 
                     [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
-                    [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:4];
-                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
-                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
-                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
-                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
-                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
-                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
-                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
-                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
-                    [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
-                    [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
-                    [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
-                    [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
-                    [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
-                    [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
-                    [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
-                    [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
-                    [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
-                    [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
-                    [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:23];
-                    [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:24];
-                    [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:25];
-                    [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:26];
-                    [encoder setBytes:&scale length:sizeof( float) atIndex:27];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                    [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:4];
+                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
+                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
+                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
+                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
+                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
+                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
+                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
+                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
+                    [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
+                    [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
+                    [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
+                    [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
+                    [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
+                    [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
+                    [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
+                    [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
+                    [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:21];
+                    [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:22];
+                    [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:23];
+                    [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:24];
+                    [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:25];
+                    [encoder setBytes:&scale length:sizeof( float) atIndex:26];
+                    [encoder setBytes:&max_bias length:sizeof( float) atIndex:27];
+                    [encoder setBytes:&m0 length:sizeof(m0) atIndex:28];
+                    [encoder setBytes:&m1 length:sizeof(m1) atIndex:29];
+                    [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:30];
 
                     if (!use_vec_kernel) {
                         // half8x8 kernel
ggml-metal.metal
CHANGED
|
@@ -363,7 +363,6 @@ template<typename T>
|
|
| 363 |
kernel void kernel_soft_max(
|
| 364 |
device const char * src0,
|
| 365 |
device const char * src1,
|
| 366 |
-
device const char * src2,
|
| 367 |
device char * dst,
|
| 368 |
constant int64_t & ne00,
|
| 369 |
constant int64_t & ne01,
|
|
@@ -385,10 +384,9 @@ kernel void kernel_soft_max(
|
|
| 385 |
|
| 386 |
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
| 387 |
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
|
| 388 |
-
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
| 389 |
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
| 390 |
|
| 391 |
-
float slope =
|
| 392 |
|
| 393 |
// ALiBi
|
| 394 |
if (max_bias > 0.0f) {
|
|
@@ -404,7 +402,7 @@ kernel void kernel_soft_max(
|
|
| 404 |
float lmax = -INFINITY;
|
| 405 |
|
| 406 |
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
| 407 |
-
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ?
|
| 408 |
}
|
| 409 |
|
| 410 |
// find the max value in the block
|
|
@@ -429,7 +427,7 @@ kernel void kernel_soft_max(
|
|
| 429 |
// parallel sum
|
| 430 |
float lsum = 0.0f;
|
| 431 |
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
| 432 |
-
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ?
|
| 433 |
lsum += exp_psrc0;
|
| 434 |
pdst[i00] = exp_psrc0;
|
| 435 |
}
|
|
@@ -468,7 +466,6 @@ template<typename T>
|
|
| 468 |
kernel void kernel_soft_max_4(
|
| 469 |
device const char * src0,
|
| 470 |
device const char * src1,
|
| 471 |
-
device const char * src2,
|
| 472 |
device char * dst,
|
| 473 |
constant int64_t & ne00,
|
| 474 |
constant int64_t & ne01,
|
|
@@ -490,10 +487,9 @@ kernel void kernel_soft_max_4(
|
|
| 490 |
|
| 491 |
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
| 492 |
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
|
| 493 |
-
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
| 494 |
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
| 495 |
|
| 496 |
-
float slope =
|
| 497 |
|
| 498 |
if (max_bias > 0.0f) {
|
| 499 |
const int64_t h = i02;
|
|
@@ -508,7 +504,7 @@ kernel void kernel_soft_max_4(
|
|
| 508 |
float4 lmax4 = -INFINITY;
|
| 509 |
|
| 510 |
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
| 511 |
-
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ?
|
| 512 |
}
|
| 513 |
|
| 514 |
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
@@ -534,7 +530,7 @@ kernel void kernel_soft_max_4(
|
|
| 534 |
// parallel sum
|
| 535 |
float4 lsum4 = 0.0f;
|
| 536 |
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
| 537 |
-
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ?
|
| 538 |
lsum4 += exp_psrc4;
|
| 539 |
pdst4[i00] = exp_psrc4;
|
| 540 |
}
|
|
@@ -1602,60 +1598,6 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
| 1602 |
}
|
| 1603 |
}
|
| 1604 |
|
| 1605 |
-
kernel void kernel_alibi_f32(
|
| 1606 |
-
device const float * src0,
|
| 1607 |
-
device float * dst,
|
| 1608 |
-
constant int64_t & ne00,
|
| 1609 |
-
constant int64_t & ne01,
|
| 1610 |
-
constant int64_t & ne02,
|
| 1611 |
-
constant int64_t & ne03,
|
| 1612 |
-
constant uint64_t & nb00,
|
| 1613 |
-
constant uint64_t & nb01,
|
| 1614 |
-
constant uint64_t & nb02,
|
| 1615 |
-
constant uint64_t & nb03,
|
| 1616 |
-
constant int64_t & ne0,
|
| 1617 |
-
constant int64_t & ne1,
|
| 1618 |
-
constant int64_t & ne2,
|
| 1619 |
-
constant int64_t & ne3,
|
| 1620 |
-
constant uint64_t & nb0,
|
| 1621 |
-
constant uint64_t & nb1,
|
| 1622 |
-
constant uint64_t & nb2,
|
| 1623 |
-
constant uint64_t & nb3,
|
| 1624 |
-
constant float & m0,
|
| 1625 |
-
constant float & m1,
|
| 1626 |
-
constant int & n_heads_log2_floor,
|
| 1627 |
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1628 |
-
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 1629 |
-
uint3 ntg[[threads_per_threadgroup]]) {
|
| 1630 |
-
const int64_t i03 = tgpig[2];
|
| 1631 |
-
const int64_t i02 = tgpig[1];
|
| 1632 |
-
const int64_t i01 = tgpig[0];
|
| 1633 |
-
|
| 1634 |
-
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
| 1635 |
-
|
| 1636 |
-
const int64_t i3 = n / (ne2*ne1*ne0);
|
| 1637 |
-
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
| 1638 |
-
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
| 1639 |
-
//const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
| 1640 |
-
|
| 1641 |
-
const int64_t k = i3*ne3 + i2;
|
| 1642 |
-
|
| 1643 |
-
float m_k;
|
| 1644 |
-
if (k < n_heads_log2_floor) {
|
| 1645 |
-
m_k = pow(m0, k + 1);
|
| 1646 |
-
} else {
|
| 1647 |
-
m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
|
| 1648 |
-
}
|
| 1649 |
-
|
| 1650 |
-
device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
|
| 1651 |
-
device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
| 1652 |
-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
| 1653 |
-
const float src_v = *(device float *)(src_row + i00*nb00);
|
| 1654 |
-
device float * dst_v = (device float *)(dst_row + i00*nb0);
|
| 1655 |
-
*dst_v = i00 * m_k + src_v;
|
| 1656 |
-
}
|
| 1657 |
-
}
|
| 1658 |
-
|
| 1659 |
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
| 1660 |
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
| 1661 |
return 1.0f - min(1.0f, max(0.0f, y));
|
|
@@ -2123,13 +2065,16 @@ typedef void (flash_attn_ext_f16_t)(
|
|
| 2123 |
constant uint64_t & nb11,
|
| 2124 |
constant uint64_t & nb12,
|
| 2125 |
constant uint64_t & nb13,
|
| 2126 |
-
constant int64_t & ne31,
|
| 2127 |
constant uint64_t & nb31,
|
| 2128 |
constant int64_t & ne0,
|
| 2129 |
constant int64_t & ne1,
|
| 2130 |
constant int64_t & ne2,
|
| 2131 |
constant int64_t & ne3,
|
| 2132 |
constant float & scale,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2133 |
threadgroup half * shared,
|
| 2134 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2135 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
@@ -2161,13 +2106,16 @@ kernel void kernel_flash_attn_ext_f16(
|
|
| 2161 |
constant uint64_t & nb11,
|
| 2162 |
constant uint64_t & nb12,
|
| 2163 |
constant uint64_t & nb13,
|
| 2164 |
-
constant int64_t & ne31,
|
| 2165 |
constant uint64_t & nb31,
|
| 2166 |
constant int64_t & ne0,
|
| 2167 |
constant int64_t & ne1,
|
| 2168 |
constant int64_t & ne2,
|
| 2169 |
constant int64_t & ne3,
|
| 2170 |
constant float & scale,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2171 |
threadgroup half * shared [[threadgroup(0)]],
|
| 2172 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2173 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
@@ -2264,6 +2212,19 @@ kernel void kernel_flash_attn_ext_f16(
|
|
| 2264 |
// prepare diagonal scale matrix
|
| 2265 |
simdgroup_float8x8 mscale(scale);
|
| 2266 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2267 |
// loop over the KV cache
|
| 2268 |
// each simdgroup handles blocks of Q rows and C columns
|
| 2269 |
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
|
|
@@ -2286,9 +2247,10 @@ kernel void kernel_flash_attn_ext_f16(
|
|
| 2286 |
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
| 2287 |
}
|
| 2288 |
|
| 2289 |
-
// mqk = mqk*scale + mask
|
| 2290 |
simdgroup_half8x8 mm;
|
| 2291 |
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
|
|
|
|
| 2292 |
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
| 2293 |
|
| 2294 |
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
|
@@ -2479,13 +2441,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
| 2479 |
constant uint64_t & nb11,
|
| 2480 |
constant uint64_t & nb12,
|
| 2481 |
constant uint64_t & nb13,
|
| 2482 |
-
constant int64_t & ne31,
|
| 2483 |
constant uint64_t & nb31,
|
| 2484 |
constant int64_t & ne0,
|
| 2485 |
constant int64_t & ne1,
|
| 2486 |
constant int64_t & ne2,
|
| 2487 |
constant int64_t & ne3,
|
| 2488 |
constant float & scale,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2489 |
threadgroup half * shared [[threadgroup(0)]],
|
| 2490 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2491 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
@@ -2504,6 +2469,18 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
| 2504 |
|
| 2505 |
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
| 2506 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2507 |
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
| 2508 |
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
| 2509 |
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
|
@@ -2610,10 +2587,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
| 2610 |
mqk += simd_shuffle_down(mqk, 2);
|
| 2611 |
mqk += simd_shuffle_down(mqk, 1);
|
| 2612 |
|
| 2613 |
-
// mqk = mqk*scale + mask
|
| 2614 |
if (tiisg == 0) {
|
| 2615 |
float4 mm = (float4) mp4[ic/4 + cc];
|
| 2616 |
-
mqk = mqk*scale + mm;
|
| 2617 |
|
| 2618 |
ss4[cc] = mqk;
|
| 2619 |
}
|
|
@@ -2847,7 +2824,8 @@ kernel void kernel_cpy_f32_f16(
|
|
| 2847 |
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
| 2848 |
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
| 2849 |
|
| 2850 |
-
|
|
|
|
| 2851 |
}
|
| 2852 |
}
|
| 2853 |
|
|
|
|
| 363 |
kernel void kernel_soft_max(
|
| 364 |
device const char * src0,
|
| 365 |
device const char * src1,
|
|
|
|
| 366 |
device char * dst,
|
| 367 |
constant int64_t & ne00,
|
| 368 |
constant int64_t & ne01,
|
|
|
|
| 384 |
|
| 385 |
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
| 386 |
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
|
|
|
|
| 387 |
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
| 388 |
|
| 389 |
+
float slope = 1.0f;
|
| 390 |
|
| 391 |
// ALiBi
|
| 392 |
if (max_bias > 0.0f) {
|
|
|
|
| 402 |
float lmax = -INFINITY;
|
| 403 |
|
| 404 |
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
| 405 |
+
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
| 406 |
}
|
| 407 |
|
| 408 |
// find the max value in the block
|
|
|
|
| 427 |
// parallel sum
|
| 428 |
float lsum = 0.0f;
|
| 429 |
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
| 430 |
+
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
| 431 |
lsum += exp_psrc0;
|
| 432 |
pdst[i00] = exp_psrc0;
|
| 433 |
}
|
|
|
|
| 466 |
kernel void kernel_soft_max_4(
|
| 467 |
device const char * src0,
|
| 468 |
device const char * src1,
|
|
|
|
| 469 |
device char * dst,
|
| 470 |
constant int64_t & ne00,
|
| 471 |
constant int64_t & ne01,
|
|
|
|
| 487 |
|
| 488 |
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
| 489 |
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
|
|
|
|
| 490 |
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
| 491 |
|
| 492 |
+
float slope = 1.0f;
|
| 493 |
|
| 494 |
if (max_bias > 0.0f) {
|
| 495 |
const int64_t h = i02;
|
|
|
|
| 504 |
float4 lmax4 = -INFINITY;
|
| 505 |
|
| 506 |
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
| 507 |
+
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
| 508 |
}
|
| 509 |
|
| 510 |
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
|
|
| 530 |
// parallel sum
|
| 531 |
float4 lsum4 = 0.0f;
|
| 532 |
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
| 533 |
+
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
| 534 |
lsum4 += exp_psrc4;
|
| 535 |
pdst4[i00] = exp_psrc4;
|
| 536 |
}
|
|
|
|
| 1598 |
}
|
| 1599 |
}
|
| 1600 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1601 |
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
| 1602 |
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
| 1603 |
return 1.0f - min(1.0f, max(0.0f, y));
|
|
|
|
| 2065 |
constant uint64_t & nb11,
|
| 2066 |
constant uint64_t & nb12,
|
| 2067 |
constant uint64_t & nb13,
|
|
|
|
| 2068 |
constant uint64_t & nb31,
|
| 2069 |
constant int64_t & ne0,
|
| 2070 |
constant int64_t & ne1,
|
| 2071 |
constant int64_t & ne2,
|
| 2072 |
constant int64_t & ne3,
|
| 2073 |
constant float & scale,
|
| 2074 |
+
constant float & max_bias,
|
| 2075 |
+
constant float & m0,
|
| 2076 |
+
constant float & m1,
|
| 2077 |
+
constant uint32_t & n_head_log2,
|
| 2078 |
threadgroup half * shared,
|
| 2079 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2080 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
|
| 2106 |
constant uint64_t & nb11,
|
| 2107 |
constant uint64_t & nb12,
|
| 2108 |
constant uint64_t & nb13,
|
|
|
|
| 2109 |
constant uint64_t & nb31,
|
| 2110 |
constant int64_t & ne0,
|
| 2111 |
constant int64_t & ne1,
|
| 2112 |
constant int64_t & ne2,
|
| 2113 |
constant int64_t & ne3,
|
| 2114 |
constant float & scale,
|
| 2115 |
+
constant float & max_bias,
|
| 2116 |
+
constant float & m0,
|
| 2117 |
+
constant float & m1,
|
| 2118 |
+
constant uint32_t & n_head_log2,
|
| 2119 |
threadgroup half * shared [[threadgroup(0)]],
|
| 2120 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2121 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
|
| 2212 |
// prepare diagonal scale matrix
|
| 2213 |
simdgroup_float8x8 mscale(scale);
|
| 2214 |
|
| 2215 |
+
// prepare diagonal slope matrix
|
| 2216 |
+
simdgroup_float8x8 mslope(1.0f);
|
| 2217 |
+
|
| 2218 |
+
// ALiBi
|
| 2219 |
+
if (max_bias > 0.0f) {
|
| 2220 |
+
const short h = iq2;
|
| 2221 |
+
|
| 2222 |
+
const float base = h < n_head_log2 ? m0 : m1;
|
| 2223 |
+
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 2224 |
+
|
| 2225 |
+
mslope = simdgroup_float8x8(pow(base, exph));
|
| 2226 |
+
}
|
| 2227 |
+
|
| 2228 |
// loop over the KV cache
|
| 2229 |
// each simdgroup handles blocks of Q rows and C columns
|
| 2230 |
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
|
|
|
|
| 2247 |
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
| 2248 |
}
|
| 2249 |
|
| 2250 |
+
// mqk = mqk*scale + mask*slope
|
| 2251 |
simdgroup_half8x8 mm;
|
| 2252 |
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
|
| 2253 |
+
simdgroup_multiply(mm, mslope, mm);
|
| 2254 |
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
| 2255 |
|
| 2256 |
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
|
|
|
| 2441 |
constant uint64_t & nb11,
|
| 2442 |
constant uint64_t & nb12,
|
| 2443 |
constant uint64_t & nb13,
|
|
|
|
| 2444 |
constant uint64_t & nb31,
|
| 2445 |
constant int64_t & ne0,
|
| 2446 |
constant int64_t & ne1,
|
| 2447 |
constant int64_t & ne2,
|
| 2448 |
constant int64_t & ne3,
|
| 2449 |
constant float & scale,
|
| 2450 |
+
constant float & max_bias,
|
| 2451 |
+
constant float & m0,
|
| 2452 |
+
constant float & m1,
|
| 2453 |
+
constant uint32_t & n_head_log2,
|
| 2454 |
threadgroup half * shared [[threadgroup(0)]],
|
| 2455 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 2456 |
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
|
| 2469 |
|
| 2470 |
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
| 2471 |
|
| 2472 |
+
float slope = 1.0f;
|
| 2473 |
+
|
| 2474 |
+
// ALiBi
|
| 2475 |
+
if (max_bias > 0.0f) {
|
| 2476 |
+
const short h = iq2;
|
| 2477 |
+
|
| 2478 |
+
const float base = h < n_head_log2 ? m0 : m1;
|
| 2479 |
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 2480 |
+
|
| 2481 |
+
slope = pow(base, exp);
|
| 2482 |
+
}
|
| 2483 |
+
|
| 2484 |
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
| 2485 |
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
| 2486 |
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
|
|
|
| 2587 |
mqk += simd_shuffle_down(mqk, 2);
|
| 2588 |
mqk += simd_shuffle_down(mqk, 1);
|
| 2589 |
|
| 2590 |
+
// mqk = mqk*scale + mask*slope
|
| 2591 |
if (tiisg == 0) {
|
| 2592 |
float4 mm = (float4) mp4[ic/4 + cc];
|
| 2593 |
+
mqk = mqk*scale + mm*slope;
|
| 2594 |
|
| 2595 |
ss4[cc] = mqk;
|
| 2596 |
}
|
|
|
|
| 2824 |
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
| 2825 |
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
| 2826 |
|
| 2827 |
+
// TODO: is there a better way to handle -INFINITY?
|
| 2828 |
+
dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
|
| 2829 |
}
|
| 2830 |
}
|
| 2831 |
|
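Note: every backend touched by this change derives the per-head ALiBi slope from the same closed form, built from max_bias and the largest power of two not exceeding n_head. A minimal host-side sketch of that shared formula (illustrative only, not part of the patch; the helper name alibi_slope is made up):

    #include <math.h>

    // Per-head ALiBi slope, matching the m0/m1/n_head_log2 scheme used by the
    // CPU, CUDA, Metal and SYCL kernels in this change.
    static float alibi_slope(int h, int n_head, float max_bias) {
        if (max_bias <= 0.0f) {
            return 1.0f; // ALiBi disabled: the mask is applied unscaled
        }
        const int   n_head_log2 = 1 << (int) floor(log2((double) n_head));
        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        return h < n_head_log2 ? powf(m0, h + 1)
                               : powf(m1, 2*(h - n_head_log2) + 1);
    }

|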
ggml-sycl.cpp
CHANGED
|
@@ -3154,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)(
|
|
| 3154 |
#define SYCL_SCALE_BLOCK_SIZE 256
|
| 3155 |
#define SYCL_CLAMP_BLOCK_SIZE 256
|
| 3156 |
#define SYCL_ROPE_BLOCK_SIZE 256
|
| 3157 |
-
#define SYCL_ALIBI_BLOCK_SIZE 32
|
| 3158 |
#define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
|
| 3159 |
#define SYCL_QUANTIZE_BLOCK_SIZE 256
|
| 3160 |
#define SYCL_DEQUANTIZE_BLOCK_SIZE 256
|
|
@@ -9316,32 +9315,6 @@ static void rope_glm_f32(
|
|
| 9316 |
dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
|
| 9317 |
}
|
| 9318 |
|
| 9319 |
-
static void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
|
| 9320 |
-
const int n_heads_log2_floor, const float m0, const float m1,
|
| 9321 |
-
const sycl::nd_item<3> &item_ct1) {
|
| 9322 |
-
const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 9323 |
-
item_ct1.get_local_id(2);
|
| 9324 |
-
|
| 9325 |
-
if (col >= ncols) {
|
| 9326 |
-
return;
|
| 9327 |
-
}
|
| 9328 |
-
|
| 9329 |
-
const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
| 9330 |
-
item_ct1.get_local_id(1);
|
| 9331 |
-
const int i = row*ncols + col;
|
| 9332 |
-
|
| 9333 |
-
const int k = row/k_rows;
|
| 9334 |
-
|
| 9335 |
-
float m_k;
|
| 9336 |
-
if (k < n_heads_log2_floor) {
|
| 9337 |
-
m_k = dpct::pow(m0, k + 1);
|
| 9338 |
-
} else {
|
| 9339 |
-
m_k = dpct::pow(m1, 2 * (k - n_heads_log2_floor) + 1);
|
| 9340 |
-
}
|
| 9341 |
-
|
| 9342 |
-
dst[i] = col * m_k + x[i];
|
| 9343 |
-
}
|
| 9344 |
-
|
| 9345 |
static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
|
| 9346 |
const sycl::nd_item<3> &item_ct1) {
|
| 9347 |
const int row = item_ct1.get_group(1);
|
|
@@ -9443,7 +9416,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
|
|
| 9443 |
|
| 9444 |
|
| 9445 |
template <bool vals_smem, int ncols_template, int block_size_template>
|
| 9446 |
-
static void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par,
|
| 9447 |
const int nrows_y, const float scale, const float max_bias, const float m0,
|
| 9448 |
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
|
| 9449 |
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
|
@@ -9457,7 +9430,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
|
|
| 9457 |
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
| 9458 |
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
| 9459 |
|
| 9460 |
-
float slope = 0.0f;
|
| 9461 |
|
| 9462 |
// ALiBi
|
| 9463 |
if (max_bias > 0.0f) {
|
|
@@ -9482,7 +9455,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
|
|
| 9482 |
const int ix = rowx*ncols + col;
|
| 9483 |
const int iy = rowy*ncols + col;
|
| 9484 |
|
| 9485 |
-
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
|
| 9486 |
|
| 9487 |
vals[col] = val;
|
| 9488 |
max_val = sycl::max(max_val, val);
|
|
@@ -12964,20 +12937,6 @@ static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
|
|
| 12964 |
});
|
| 12965 |
}
|
| 12966 |
|
| 12967 |
-
static void alibi_f32_sycl(const float *x, float *dst, const int ncols,
|
| 12968 |
-
const int nrows, const int k_rows,
|
| 12969 |
-
const int n_heads_log2_floor, const float m0,
|
| 12970 |
-
const float m1, dpct::queue_ptr stream) {
|
| 12971 |
-
const sycl::range<3> block_dims(1, 1, SYCL_ALIBI_BLOCK_SIZE);
|
| 12972 |
-
const int num_blocks_x = (ncols + SYCL_ALIBI_BLOCK_SIZE - 1) / (SYCL_ALIBI_BLOCK_SIZE);
|
| 12973 |
-
const sycl::range<3> block_nums(1, nrows, num_blocks_x);
|
| 12974 |
-
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
| 12975 |
-
[=](sycl::nd_item<3> item_ct1) {
|
| 12976 |
-
alibi_f32(x, dst, ncols, k_rows,
|
| 12977 |
-
n_heads_log2_floor, m0, m1, item_ct1);
|
| 12978 |
-
});
|
| 12979 |
-
}
|
| 12980 |
-
|
| 12981 |
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
|
| 12982 |
const int nrows, dpct::queue_ptr stream) {
|
| 12983 |
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
|
@@ -13058,7 +13017,7 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
|
|
| 13058 |
}
|
| 13059 |
|
| 13060 |
template <bool vals_smem, int ncols_template, int block_size_template>
|
| 13061 |
-
static void soft_max_f32_submitter(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par,
|
| 13062 |
const int nrows_y, const float scale, const float max_bias, const float m0,
|
| 13063 |
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
|
| 13064 |
const size_t n_local_scratch, dpct::queue_ptr stream) {
|
|
@@ -13068,7 +13027,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl
|
|
| 13068 |
cgh.parallel_for(
|
| 13069 |
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
| 13070 |
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
| 13071 |
-
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, pos, dst, ncols_par,
|
| 13072 |
nrows_y, scale, max_bias, m0,
|
| 13073 |
m1, n_head_log2, item_ct1,
|
| 13074 |
local_buf_acc.get_pointer());
|
|
@@ -13076,7 +13035,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl
|
|
| 13076 |
});
|
| 13077 |
}
|
| 13078 |
|
| 13079 |
-
static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos,
|
| 13080 |
float * dst, const int ncols_x, const int nrows_x,
|
| 13081 |
const int nrows_y, const float scale, const float max_bias,
|
| 13082 |
dpct::queue_ptr stream) {
|
|
@@ -13098,60 +13057,60 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
|
|
| 13098 |
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
|
| 13099 |
if (n_local_scratch*sizeof(float) < local_mem_size) {
|
| 13100 |
if (ncols_x > max_block_size) {
|
| 13101 |
-
soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 13102 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13103 |
block_dims, n_local_scratch, stream);
|
| 13104 |
return;
|
| 13105 |
}
|
| 13106 |
switch (ncols_x) {
|
| 13107 |
case 32:
|
| 13108 |
-
soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 13109 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13110 |
block_dims, n_local_scratch, stream);
|
| 13111 |
break;
|
| 13112 |
case 64:
|
| 13113 |
-
soft_max_f32_submitter<true, 64, 64>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 13114 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13115 |
block_dims, n_local_scratch, stream);
|
| 13116 |
break;
|
| 13117 |
case 128:
|
| 13118 |
-
soft_max_f32_submitter<true, 128, 128>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 13119 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13120 |
block_dims, n_local_scratch, stream);
|
| 13121 |
break;
|
| 13122 |
case 256:
|
| 13123 |
-
soft_max_f32_submitter<true, 256, 256>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 13124 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13125 |
block_dims, n_local_scratch, stream);
|
| 13126 |
break;
|
| 13127 |
case 512:
|
| 13128 |
-
soft_max_f32_submitter<true, 512, 512>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 13129 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13130 |
block_dims, n_local_scratch, stream);
|
| 13131 |
break;
|
| 13132 |
case 1024:
|
| 13133 |
-
soft_max_f32_submitter<true, 1024, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 13134 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13135 |
block_dims, n_local_scratch, stream);
|
| 13136 |
break;
|
| 13137 |
case 2048:
|
| 13138 |
-
soft_max_f32_submitter<true, 2048, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 13139 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13140 |
block_dims, n_local_scratch, stream);
|
| 13141 |
break;
|
| 13142 |
case 4096:
|
| 13143 |
-
soft_max_f32_submitter<true, 4096, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 13144 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13145 |
block_dims, n_local_scratch, stream);
|
| 13146 |
break;
|
| 13147 |
default:
|
| 13148 |
-
soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 13149 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13150 |
block_dims, n_local_scratch, stream);
|
| 13151 |
break;
|
| 13152 |
}
|
| 13153 |
} else {
|
| 13154 |
-
soft_max_f32_submitter<false, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
| 13155 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13156 |
block_dims, WARP_SIZE, stream);
|
| 13157 |
}
|
|
@@ -14562,36 +14521,6 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
|
|
| 14562 |
(void) src1_dd;
|
| 14563 |
}
|
| 14564 |
|
| 14565 |
-
inline void ggml_sycl_op_alibi(const ggml_tensor *src0, const ggml_tensor *src1,
|
| 14566 |
-
ggml_tensor *dst, const float *src0_dd,
|
| 14567 |
-
const float *src1_dd, float *dst_dd,
|
| 14568 |
-
const dpct::queue_ptr &main_stream) {
|
| 14569 |
-
|
| 14570 |
-
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 14571 |
-
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 14572 |
-
|
| 14573 |
-
GGML_TENSOR_LOCALS_3(int64_t, ne0, src0, ne);
|
| 14574 |
-
const int64_t nrows = ggml_nrows(src0);
|
| 14575 |
-
|
| 14576 |
-
//const int n_past = ((int32_t *) dst->op_params)[0];
|
| 14577 |
-
const int n_head = ((int32_t *) dst->op_params)[1];
|
| 14578 |
-
float max_bias;
|
| 14579 |
-
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
| 14580 |
-
|
| 14581 |
-
//GGML_ASSERT(ne01 + n_past == ne00);
|
| 14582 |
-
GGML_ASSERT(n_head == ne02);
|
| 14583 |
-
|
| 14584 |
-
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
| 14585 |
-
|
| 14586 |
-
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
| 14587 |
-
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
| 14588 |
-
|
| 14589 |
-
alibi_f32_sycl(src0_dd, dst_dd, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, main_stream);
|
| 14590 |
-
|
| 14591 |
-
(void) src1;
|
| 14592 |
-
(void) src1_dd;
|
| 14593 |
-
}
|
| 14594 |
-
|
| 14595 |
static void ggml_sycl_op_pool2d(const ggml_tensor *src0,
|
| 14596 |
const ggml_tensor *src1, ggml_tensor *dst,
|
| 14597 |
const float *src0_dd, const float *src1_dd,
|
|
@@ -14746,12 +14675,9 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
|
|
| 14746 |
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 14747 |
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 14748 |
|
| 14749 |
-
|
| 14750 |
-
|
| 14751 |
-
#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
|
| 14752 |
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
|
| 14753 |
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
| 14754 |
-
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
|
| 14755 |
|
| 14756 |
const int64_t ne00 = src0->ne[0];
|
| 14757 |
const int64_t nrows_x = ggml_nrows(src0);
|
|
@@ -14763,25 +14689,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
|
|
| 14763 |
memcpy(&scale, dst->op_params + 0, sizeof(float));
|
| 14764 |
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
|
| 14765 |
|
| 14766 |
-
|
| 14767 |
-
float * src2_dd = nullptr;
|
| 14768 |
-
sycl_pool_alloc<float> src2_f;
|
| 14769 |
-
|
| 14770 |
-
const bool use_src2 = src2 != nullptr;
|
| 14771 |
-
|
| 14772 |
-
if (use_src2) {
|
| 14773 |
-
const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
|
| 14774 |
-
|
| 14775 |
-
if (src2_on_device) {
|
| 14776 |
-
ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
|
| 14777 |
-
src2_dd = (float *) src2_extra->data_device[g_main_device];
|
| 14778 |
-
} else {
|
| 14779 |
-
src2_dd = src2_f.alloc(ggml_nelements(src2));
|
| 14780 |
-
SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
|
| 14781 |
-
}
|
| 14782 |
-
}
|
| 14783 |
-
|
| 14784 |
-
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00,
|
| 14785 |
nrows_x, nrows_y, scale, max_bias, main_stream);
|
| 14786 |
}
|
| 14787 |
|
|
@@ -16232,10 +16140,6 @@ static void ggml_sycl_rope(const ggml_tensor * src0, const ggml_tensor * src1, g
|
|
| 16232 |
ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_rope);
|
| 16233 |
}
|
| 16234 |
|
| 16235 |
-
static void ggml_sycl_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 16236 |
-
ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_alibi);
|
| 16237 |
-
}
|
| 16238 |
-
|
| 16239 |
static void ggml_sycl_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 16240 |
ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_pool2d);
|
| 16241 |
}
|
|
@@ -16612,9 +16516,6 @@ bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_
|
|
| 16612 |
case GGML_OP_ROPE:
|
| 16613 |
func = ggml_sycl_rope;
|
| 16614 |
break;
|
| 16615 |
-
case GGML_OP_ALIBI:
|
| 16616 |
-
func = ggml_sycl_alibi;
|
| 16617 |
-
break;
|
| 16618 |
case GGML_OP_IM2COL:
|
| 16619 |
func = ggml_sycl_im2col;
|
| 16620 |
break;
|
|
@@ -17744,7 +17645,6 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
|
|
| 17744 |
case GGML_OP_DIAG_MASK_INF:
|
| 17745 |
case GGML_OP_SOFT_MAX:
|
| 17746 |
case GGML_OP_ROPE:
|
| 17747 |
-
case GGML_OP_ALIBI:
|
| 17748 |
case GGML_OP_IM2COL:
|
| 17749 |
case GGML_OP_POOL_2D:
|
| 17750 |
case GGML_OP_SUM_ROWS:
|
|
|
|
| 3154 |
#define SYCL_SCALE_BLOCK_SIZE 256
|
| 3155 |
#define SYCL_CLAMP_BLOCK_SIZE 256
|
| 3156 |
#define SYCL_ROPE_BLOCK_SIZE 256
|
|
|
|
| 3157 |
#define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
|
| 3158 |
#define SYCL_QUANTIZE_BLOCK_SIZE 256
|
| 3159 |
#define SYCL_DEQUANTIZE_BLOCK_SIZE 256
|
|
|
|
| 9315 |
dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
|
| 9316 |
}
|
| 9317 |
|
| 9318 |
static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
|
| 9319 |
const sycl::nd_item<3> &item_ct1) {
|
| 9320 |
const int row = item_ct1.get_group(1);
|
|
|
|
| 9416 |
|
| 9417 |
|
| 9418 |
template <bool vals_smem, int ncols_template, int block_size_template>
|
| 9419 |
+
static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
|
| 9420 |
const int nrows_y, const float scale, const float max_bias, const float m0,
|
| 9421 |
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
|
| 9422 |
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
|
|
|
| 9430 |
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
| 9431 |
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
| 9432 |
|
| 9433 |
+
float slope = 1.0f;
|
| 9434 |
|
| 9435 |
// ALiBi
|
| 9436 |
if (max_bias > 0.0f) {
|
|
|
|
| 9455 |
const int ix = rowx*ncols + col;
|
| 9456 |
const int iy = rowy*ncols + col;
|
| 9457 |
|
| 9458 |
+
const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
|
| 9459 |
|
| 9460 |
vals[col] = val;
|
| 9461 |
max_val = sycl::max(max_val, val);
|
|
|
|
| 12937 |
});
|
| 12938 |
}
|
| 12939 |
|
| 12940 |
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
|
| 12941 |
const int nrows, dpct::queue_ptr stream) {
|
| 12942 |
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
|
|
|
| 13017 |
}
|
| 13018 |
|
| 13019 |
template <bool vals_smem, int ncols_template, int block_size_template>
|
| 13020 |
+
static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
|
| 13021 |
const int nrows_y, const float scale, const float max_bias, const float m0,
|
| 13022 |
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
|
| 13023 |
const size_t n_local_scratch, dpct::queue_ptr stream) {
|
|
|
|
| 13027 |
cgh.parallel_for(
|
| 13028 |
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
| 13029 |
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
| 13030 |
+
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
|
| 13031 |
nrows_y, scale, max_bias, m0,
|
| 13032 |
m1, n_head_log2, item_ct1,
|
| 13033 |
local_buf_acc.get_pointer());
|
|
|
|
| 13035 |
});
|
| 13036 |
}
|
| 13037 |
|
| 13038 |
+
static void soft_max_f32_sycl(const float * x, const float * mask,
|
| 13039 |
float * dst, const int ncols_x, const int nrows_x,
|
| 13040 |
const int nrows_y, const float scale, const float max_bias,
|
| 13041 |
dpct::queue_ptr stream) {
|
|
|
|
| 13057 |
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
|
| 13058 |
if (n_local_scratch*sizeof(float) < local_mem_size) {
|
| 13059 |
if (ncols_x > max_block_size) {
|
| 13060 |
+
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
| 13061 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13062 |
block_dims, n_local_scratch, stream);
|
| 13063 |
return;
|
| 13064 |
}
|
| 13065 |
switch (ncols_x) {
|
| 13066 |
case 32:
|
| 13067 |
+
soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
|
| 13068 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13069 |
block_dims, n_local_scratch, stream);
|
| 13070 |
break;
|
| 13071 |
case 64:
|
| 13072 |
+
soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
|
| 13073 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13074 |
block_dims, n_local_scratch, stream);
|
| 13075 |
break;
|
| 13076 |
case 128:
|
| 13077 |
+
soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
|
| 13078 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13079 |
block_dims, n_local_scratch, stream);
|
| 13080 |
break;
|
| 13081 |
case 256:
|
| 13082 |
+
soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
|
| 13083 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13084 |
block_dims, n_local_scratch, stream);
|
| 13085 |
break;
|
| 13086 |
case 512:
|
| 13087 |
+
soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
|
| 13088 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13089 |
block_dims, n_local_scratch, stream);
|
| 13090 |
break;
|
| 13091 |
case 1024:
|
| 13092 |
+
soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
| 13093 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13094 |
block_dims, n_local_scratch, stream);
|
| 13095 |
break;
|
| 13096 |
case 2048:
|
| 13097 |
+
soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
| 13098 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13099 |
block_dims, n_local_scratch, stream);
|
| 13100 |
break;
|
| 13101 |
case 4096:
|
| 13102 |
+
soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
| 13103 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13104 |
block_dims, n_local_scratch, stream);
|
| 13105 |
break;
|
| 13106 |
default:
|
| 13107 |
+
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
| 13108 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13109 |
block_dims, n_local_scratch, stream);
|
| 13110 |
break;
|
| 13111 |
}
|
| 13112 |
} else {
|
| 13113 |
+
soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
| 13114 |
max_bias, m0, m1, n_head_log2, block_nums,
|
| 13115 |
block_dims, WARP_SIZE, stream);
|
| 13116 |
}
|
|
|
|
| 14521 |
(void) src1_dd;
|
| 14522 |
}
|
| 14523 |
|
| 14524 |
static void ggml_sycl_op_pool2d(const ggml_tensor *src0,
|
| 14525 |
const ggml_tensor *src1, ggml_tensor *dst,
|
| 14526 |
const float *src0_dd, const float *src1_dd,
|
|
|
|
| 14675 |
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 14676 |
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 14677 |
|
| 14678 |
+
#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
|
| 14679 |
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
|
| 14680 |
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
|
|
|
| 14681 |
|
| 14682 |
const int64_t ne00 = src0->ne[0];
|
| 14683 |
const int64_t nrows_x = ggml_nrows(src0);
|
|
|
|
| 14689 |
memcpy(&scale, dst->op_params + 0, sizeof(float));
|
| 14690 |
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
|
| 14691 |
|
| 14692 |
+
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
|
| 14693 |
nrows_x, nrows_y, scale, max_bias, main_stream);
|
| 14694 |
}
|
| 14695 |
|
|
|
|
| 16140 |
ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_rope);
|
| 16141 |
}
|
| 16142 |
|
| 16143 |
static void ggml_sycl_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 16144 |
ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_pool2d);
|
| 16145 |
}
|
|
|
|
| 16516 |
case GGML_OP_ROPE:
|
| 16517 |
func = ggml_sycl_rope;
|
| 16518 |
break;
|
| 16519 |
case GGML_OP_IM2COL:
|
| 16520 |
func = ggml_sycl_im2col;
|
| 16521 |
break;
|
|
|
|
| 17645 |
case GGML_OP_DIAG_MASK_INF:
|
| 17646 |
case GGML_OP_SOFT_MAX:
|
| 17647 |
case GGML_OP_ROPE:
|
|
|
|
| 17648 |
case GGML_OP_IM2COL:
|
| 17649 |
case GGML_OP_POOL_2D:
|
| 17650 |
case GGML_OP_SUM_ROWS:
|
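With GGML_OP_ALIBI removed, the soft-max kernels above fold the bias into a single pass: each row becomes softmax(x*scale + slope*mask), where slope comes from the per-head formula sketched earlier. A scalar reference of that row update (a sketch; the row pointers and slope are assumed to be prepared by the caller):

    #include <math.h>

    // Reference semantics of one row of the fused soft_max, as the SYCL and
    // CUDA kernels above implement it in parallel.
    static void soft_max_row_ref(const float * x, const float * mask, float * dst,
                                 int ncols, float scale, float slope) {
        float max_val = -INFINITY;
        for (int i = 0; i < ncols; ++i) {
            dst[i] = x[i]*scale + (mask ? slope*mask[i] : 0.0f);
            max_val = dst[i] > max_val ? dst[i] : max_val;
        }
        float sum = 0.0f;
        for (int i = 0; i < ncols; ++i) {
            dst[i] = expf(dst[i] - max_val); // subtract the row max for stability
            sum += dst[i];
        }
        for (int i = 0; i < ncols; ++i) {
            dst[i] /= sum;
        }
    }

|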
ggml-vulkan.cpp
CHANGED
|
@@ -3830,9 +3830,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
| 3830 |
return nullptr;
|
| 3831 |
case GGML_OP_SOFT_MAX:
|
| 3832 |
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
| 3833 |
-
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16);
|
| 3834 |
|
| 3835 |
-
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
|
| 3836 |
return ctx->device->pipeline_soft_max_f32;
|
| 3837 |
}
|
| 3838 |
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
|
@@ -4286,6 +4285,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx,
|
|
| 4286 |
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 4287 |
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 4288 |
|
| 4289 |
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
|
| 4290 |
ncols,
|
| 4291 |
src1 != nullptr ? nrows_y : (uint32_t)0,
|
|
|
|
| 3830 |
return nullptr;
|
| 3831 |
case GGML_OP_SOFT_MAX:
|
| 3832 |
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
|
|
|
| 3833 |
|
| 3834 |
+
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
|
| 3835 |
return ctx->device->pipeline_soft_max_f32;
|
| 3836 |
}
|
| 3837 |
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
|
|
|
| 4285 |
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 4286 |
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 4287 |
|
| 4288 |
+
#pragma message("TODO: src2 is no longer used in soft_max - should be removed and ALiBi calculation should be updated")
|
| 4289 |
+
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
|
| 4290 |
+
|
| 4291 |
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
|
| 4292 |
ncols,
|
| 4293 |
src1 != nullptr ? nrows_y : (uint32_t)0,
|
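The Vulkan path keeps its src2 push constant for now (see the pragma above), but callers no longer pass positions: ALiBi reaches every backend through the mask plus the (scale, max_bias) op params. A hedged usage sketch of the updated graph call (the names ctx, kq, kq_mask and n_embd_head are assumptions about the caller):

    #include <math.h>
    #include "ggml.h"

    // kq: attention scores [n_kv, n_tokens, n_head]; kq_mask is mandatory
    // whenever max_bias > 0.0f (the new GGML_ASSERT in ggml.c enforces this).
    static struct ggml_tensor * attn_soft_max(
            struct ggml_context * ctx,
            struct ggml_tensor  * kq,
            struct ggml_tensor  * kq_mask,
            int n_embd_head, float max_bias) {
        return ggml_soft_max_ext(ctx, kq, kq_mask,
                1.0f/sqrtf((float) n_embd_head), max_bias);
    }

|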
ggml.c
CHANGED
|
@@ -2186,7 +2186,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
| 2186 |
"SOFT_MAX_BACK",
|
| 2187 |
"ROPE",
|
| 2188 |
"ROPE_BACK",
|
| 2189 |
-
"ALIBI",
|
| 2190 |
"CLAMP",
|
| 2191 |
"CONV_TRANSPOSE_1D",
|
| 2192 |
"IM2COL",
|
|
@@ -2228,7 +2227,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|
| 2228 |
"CROSS_ENTROPY_LOSS_BACK",
|
| 2229 |
};
|
| 2230 |
|
| 2231 |
-
static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
|
| 2232 |
|
| 2233 |
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
| 2234 |
"none",
|
|
@@ -2277,7 +2276,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
| 2277 |
"soft_max_back(x)",
|
| 2278 |
"rope(x)",
|
| 2279 |
"rope_back(x)",
|
| 2280 |
-
"alibi(x)",
|
| 2281 |
"clamp(x)",
|
| 2282 |
"conv_transpose_1d(x)",
|
| 2283 |
"im2col(x)",
|
|
@@ -2319,7 +2317,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
| 2319 |
"cross_entropy_loss_back(x,y)",
|
| 2320 |
};
|
| 2321 |
|
| 2322 |
-
static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
|
| 2323 |
|
| 2324 |
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
| 2325 |
|
|
@@ -5662,7 +5660,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
|
| 5662 |
struct ggml_context * ctx,
|
| 5663 |
struct ggml_tensor * a,
|
| 5664 |
struct ggml_tensor * mask,
|
| 5665 |
-
struct ggml_tensor * pos,
|
| 5666 |
float scale,
|
| 5667 |
float max_bias,
|
| 5668 |
bool inplace) {
|
|
@@ -5676,18 +5673,8 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
|
| 5676 |
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
|
| 5677 |
}
|
| 5678 |
|
| 5679 |
-
if (pos) {
|
| 5680 |
-
GGML_ASSERT(ggml_is_vector(pos));
|
| 5681 |
-
GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
|
| 5682 |
-
GGML_ASSERT(pos->ne[0] == a->ne[0]);
|
| 5683 |
-
}
|
| 5684 |
-
|
| 5685 |
-
if (pos && mask) {
|
| 5686 |
-
GGML_ASSERT(pos->type == mask->type);
|
| 5687 |
-
}
|
| 5688 |
-
|
| 5689 |
if (max_bias > 0.0f) {
|
| 5690 |
-
GGML_ASSERT(pos);
|
| 5691 |
}
|
| 5692 |
|
| 5693 |
bool is_node = false;
|
|
@@ -5705,7 +5692,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
|
| 5705 |
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
| 5706 |
result->src[0] = a;
|
| 5707 |
result->src[1] = mask;
|
| 5708 |
-
result->src[2] = pos;
|
| 5709 |
|
| 5710 |
return result;
|
| 5711 |
}
|
|
@@ -5713,23 +5699,22 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
|
| 5713 |
struct ggml_tensor * ggml_soft_max(
|
| 5714 |
struct ggml_context * ctx,
|
| 5715 |
struct ggml_tensor * a) {
|
| 5716 |
-
return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false);
|
| 5717 |
}
|
| 5718 |
|
| 5719 |
struct ggml_tensor * ggml_soft_max_inplace(
|
| 5720 |
struct ggml_context * ctx,
|
| 5721 |
struct ggml_tensor * a) {
|
| 5722 |
-
return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true);
|
| 5723 |
}
|
| 5724 |
|
| 5725 |
struct ggml_tensor * ggml_soft_max_ext(
|
| 5726 |
struct ggml_context * ctx,
|
| 5727 |
struct ggml_tensor * a,
|
| 5728 |
struct ggml_tensor * mask,
|
| 5729 |
-
struct ggml_tensor * pos,
|
| 5730 |
float scale,
|
| 5731 |
float max_bias) {
|
| 5732 |
-
return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false);
|
| 5733 |
}
|
| 5734 |
|
| 5735 |
// ggml_soft_max_back
|
|
@@ -5944,37 +5929,6 @@ struct ggml_tensor * ggml_rope_back(
|
|
| 5944 |
return result;
|
| 5945 |
}
|
| 5946 |
|
| 5947 |
-
// ggml_alibi
|
| 5948 |
-
|
| 5949 |
-
struct ggml_tensor * ggml_alibi(
|
| 5950 |
-
struct ggml_context * ctx,
|
| 5951 |
-
struct ggml_tensor * a,
|
| 5952 |
-
int n_past,
|
| 5953 |
-
int n_head,
|
| 5954 |
-
float bias_max) {
|
| 5955 |
-
GGML_ASSERT(n_past >= 0);
|
| 5956 |
-
bool is_node = false;
|
| 5957 |
-
|
| 5958 |
-
if (a->grad) {
|
| 5959 |
-
GGML_ASSERT(false); // TODO: implement backward
|
| 5960 |
-
is_node = true;
|
| 5961 |
-
}
|
| 5962 |
-
|
| 5963 |
-
// TODO: when implement backward, fix this:
|
| 5964 |
-
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
| 5965 |
-
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
| 5966 |
-
|
| 5967 |
-
int32_t op_params[3] = { n_past, n_head };
|
| 5968 |
-
memcpy(op_params + 2, &bias_max, sizeof(float));
|
| 5969 |
-
ggml_set_op_params(result, op_params, sizeof(op_params));
|
| 5970 |
-
|
| 5971 |
-
result->op = GGML_OP_ALIBI;
|
| 5972 |
-
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
| 5973 |
-
result->src[0] = a;
|
| 5974 |
-
|
| 5975 |
-
return result;
|
| 5976 |
-
}
|
| 5977 |
-
|
| 5978 |
// ggml_clamp
|
| 5979 |
|
| 5980 |
struct ggml_tensor * ggml_clamp(
|
|
@@ -6502,9 +6456,11 @@ struct ggml_tensor * ggml_flash_attn_ext(
|
|
| 6502 |
struct ggml_tensor * k,
|
| 6503 |
struct ggml_tensor * v,
|
| 6504 |
struct ggml_tensor * mask,
|
| 6505 |
-
float scale) {
|
| 6508 |
if (mask) {
|
| 6509 |
GGML_ASSERT(ggml_is_contiguous(mask));
|
| 6510 |
GGML_ASSERT(mask->ne[2] == 1);
|
|
@@ -6514,6 +6470,10 @@ struct ggml_tensor * ggml_flash_attn_ext(
|
|
| 6514 |
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
|
| 6515 |
}
|
| 6516 |
|
|
| 6517 |
bool is_node = false;
|
| 6518 |
|
| 6519 |
if (q->grad || k->grad || v->grad) {
|
|
@@ -6524,7 +6484,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
|
|
| 6524 |
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
|
| 6525 |
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
| 6526 |
|
| 6527 |
-
float params[] = { scale };
|
| 6528 |
ggml_set_op_params(result, params, sizeof(params));
|
| 6529 |
|
| 6530 |
result->op = GGML_OP_FLASH_ATTN_EXT;
|
|
@@ -6544,7 +6504,7 @@ void ggml_flash_attn_ext_set_prec(
|
|
| 6544 |
|
| 6545 |
const int32_t prec_i32 = (int32_t) prec;
|
| 6546 |
|
| 6547 |
-
ggml_set_op_params_i32(a, 1, prec_i32);
|
| 6548 |
}
|
| 6549 |
|
| 6550 |
// ggml_flash_ff
|
|
@@ -13395,7 +13355,6 @@ static void ggml_compute_forward_soft_max_f32(
|
|
| 13395 |
|
| 13396 |
const struct ggml_tensor * src0 = dst->src[0];
|
| 13397 |
const struct ggml_tensor * src1 = dst->src[1];
|
| 13398 |
-
const struct ggml_tensor * src2 = dst->src[2];
|
| 13399 |
|
| 13400 |
assert(ggml_is_contiguous(dst));
|
| 13401 |
assert(ggml_are_same_shape(src0, dst));
|
|
@@ -13421,8 +13380,8 @@ static void ggml_compute_forward_soft_max_f32(
|
|
| 13421 |
|
| 13422 |
// TODO: is this supposed to be ceil instead of floor?
|
| 13423 |
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
|
| 13424 |
-
const uint32_t n_head_kv   = ne02;
|
| 13425 |
-
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
|
| 13426 |
|
| 13427 |
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 13428 |
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
@@ -13439,13 +13398,13 @@ static void ggml_compute_forward_soft_max_f32(
|
|
| 13439 |
|
| 13440 |
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
| 13441 |
|
| 13442 |
-
|
| 13443 |
-
ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
|
| 13444 |
-
float * pos_f32 = src2 ? (float *) src2->data : src0->data;
|
| 13445 |
-
|
| 13446 |
-
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
|
| 13447 |
|
| 13448 |
for (int i1 = ir0; i1 < ir1; i1++) {
|
| 13449 |
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
| 13450 |
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
| 13451 |
|
|
@@ -13458,27 +13417,11 @@ static void ggml_compute_forward_soft_max_f32(
|
|
| 13458 |
if (mp_f32) {
|
| 13459 |
if (use_f16) {
|
| 13460 |
for (int i = 0; i < nc; ++i) {
|
| 13461 |
-
wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
|
| 13462 |
}
|
| 13463 |
} else {
|
| 13464 |
for (int i = 0; i < nc; ++i) {
|
| 13465 |
-
wp[i] += mp_f32[i];
|
| 13466 |
-
}
|
| 13467 |
-
}
|
| 13468 |
-
}
|
| 13469 |
-
|
| 13470 |
-
// ALiBi bias
|
| 13471 |
-
if (max_bias > 0.0f) {
|
| 13472 |
-
const uint32_t h = (i1/ne01)%ne02; // head
|
| 13473 |
-
const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
|
| 13474 |
-
|
| 13475 |
-
if (use_f16) {
|
| 13476 |
-
for (int i = 0; i < nc; ++i) {
|
| 13477 |
-
wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
|
| 13478 |
-
}
|
| 13479 |
-
} else {
|
| 13480 |
-
for (int i = 0; i < nc; ++i) {
|
| 13481 |
-
wp[i] += slope*pos_f32[i];
|
| 13482 |
}
|
| 13483 |
}
|
| 13484 |
}
|
|
@@ -13640,178 +13583,6 @@ static void ggml_compute_forward_soft_max_back(
|
|
| 13640 |
}
|
| 13641 |
}
|
| 13642 |
|
| 13643 |
-
// ggml_compute_forward_alibi
|
| 13644 |
-
|
| 13645 |
-
static void ggml_compute_forward_alibi_f32(
|
| 13646 |
-
const struct ggml_compute_params * params,
|
| 13647 |
-
struct ggml_tensor * dst) {
|
| 13648 |
-
|
| 13649 |
-
const struct ggml_tensor * src0 = dst->src[0];
|
| 13650 |
-
|
| 13651 |
-
assert(params->ith == 0);
|
| 13652 |
-
|
| 13653 |
-
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
| 13654 |
-
return;
|
| 13655 |
-
}
|
| 13656 |
-
|
| 13657 |
-
//const int n_past = ((int32_t *) dst->op_params)[0];
|
| 13658 |
-
const int n_head = ((int32_t *) dst->op_params)[1];
|
| 13659 |
-
float max_bias;
|
| 13660 |
-
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
| 13661 |
-
|
| 13662 |
-
const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
|
| 13663 |
-
const int64_t ne1 = src0->ne[1]; // seq_len_without_past
|
| 13664 |
-
const int64_t ne2 = src0->ne[2]; // n_head -> this is k
|
| 13665 |
-
//const int64_t ne3 = src0->ne[3]; // 1 -> bsz
|
| 13666 |
-
|
| 13667 |
-
const int64_t n = ggml_nrows(src0);
|
| 13668 |
-
const int64_t ne2_ne3 = n/ne1; // ne2*ne3
|
| 13669 |
-
|
| 13670 |
-
const size_t nb0 = src0->nb[0];
|
| 13671 |
-
const size_t nb1 = src0->nb[1];
|
| 13672 |
-
const size_t nb2 = src0->nb[2];
|
| 13673 |
-
//const int nb3 = src0->nb[3];
|
| 13674 |
-
|
| 13675 |
-
GGML_ASSERT(nb0 == sizeof(float));
|
| 13676 |
-
GGML_ASSERT(n_head == ne2);
|
| 13677 |
-
|
| 13678 |
-
// add alibi to src0 (KQ_scaled)
|
| 13679 |
-
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
| 13680 |
-
|
| 13681 |
-
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
| 13682 |
-
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
| 13683 |
-
|
| 13684 |
-
for (int64_t k = 0; k < ne2_ne3; k++) {
|
| 13685 |
-
// TODO: k*nb2 or k*nb3
|
| 13686 |
-
float m_k;
|
| 13687 |
-
|
| 13688 |
-
if (k < n_heads_log2_floor) {
|
| 13689 |
-
m_k = powf(m0, k + 1);
|
| 13690 |
-
} else {
|
| 13691 |
-
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
| 13692 |
-
}
|
| 13693 |
-
|
| 13694 |
-
for (int64_t i = 0; i < ne0; i++) {
|
| 13695 |
-
for (int64_t j = 0; j < ne1; j++) {
|
| 13696 |
-
float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
| 13697 |
-
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
| 13698 |
-
pdst[0] = i * m_k + src[0];
|
| 13699 |
-
}
|
| 13700 |
-
}
|
| 13701 |
-
}
|
| 13702 |
-
}
|
| 13703 |
-
|
| 13704 |
-
static void ggml_compute_forward_alibi_f16(
|
| 13705 |
-
const struct ggml_compute_params * params,
|
| 13706 |
-
struct ggml_tensor * dst) {
|
| 13707 |
-
|
| 13708 |
-
const struct ggml_tensor * src0 = dst->src[0];
|
| 13709 |
-
|
| 13710 |
-
assert(params->ith == 0);
|
| 13711 |
-
|
| 13712 |
-
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
| 13713 |
-
return;
|
| 13714 |
-
}
|
| 13715 |
-
|
| 13716 |
-
//const int n_past = ((int32_t *) dst->op_params)[0];
|
| 13717 |
-
const int n_head = ((int32_t *) dst->op_params)[1];
|
| 13718 |
-
float max_bias;
|
| 13719 |
-
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
| 13720 |
-
|
| 13721 |
-
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
|
| 13722 |
-
const int ne1 = src0->ne[1]; // seq_len_without_past
|
| 13723 |
-
const int ne2 = src0->ne[2]; // n_head -> this is k
|
| 13724 |
-
//const int ne3 = src0->ne[3]; // 1 -> bsz
|
| 13725 |
-
|
| 13726 |
-
const int n = ggml_nrows(src0);
|
| 13727 |
-
const int ne2_ne3 = n/ne1; // ne2*ne3
|
| 13728 |
-
|
| 13729 |
-
const int nb0 = src0->nb[0];
|
| 13730 |
-
const int nb1 = src0->nb[1];
|
| 13731 |
-
const int nb2 = src0->nb[2];
|
| 13732 |
-
//const int nb3 = src0->nb[3];
|
| 13733 |
-
|
| 13734 |
-
GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
|
| 13735 |
-
//GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
|
| 13736 |
-
GGML_ASSERT(n_head == ne2);
|
| 13737 |
-
|
| 13738 |
-
// add alibi to src0 (KQ_scaled)
|
| 13739 |
-
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
| 13740 |
-
|
| 13741 |
-
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
| 13742 |
-
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
| 13743 |
-
|
| 13744 |
-
for (int k = 0; k < ne2_ne3; k++) {
|
| 13745 |
-
// TODO: k*nb2 or k*nb3
|
| 13746 |
-
float m_k;
|
| 13747 |
-
|
| 13748 |
-
if (k < n_heads_log2_floor) {
|
| 13749 |
-
m_k = powf(m0, k + 1);
|
| 13750 |
-
} else {
|
| 13751 |
-
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
| 13752 |
-
}
|
| 13753 |
-
|
| 13754 |
-
for (int i = 0; i < ne0; i++) {
|
| 13755 |
-
for (int j = 0; j < ne1; j++) {
|
| 13756 |
-
ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
| 13757 |
-
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
| 13758 |
-
|
| 13759 |
-
// we return F32
|
| 13760 |
-
pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
|
| 13761 |
-
}
|
| 13762 |
-
}
|
| 13763 |
-
}
|
| 13764 |
-
}
|
| 13765 |
-
|
| 13766 |
-
static void ggml_compute_forward_alibi(
|
| 13767 |
-
const struct ggml_compute_params * params,
|
| 13768 |
-
struct ggml_tensor * dst) {
|
| 13769 |
-
|
| 13770 |
-
const struct ggml_tensor * src0 = dst->src[0];
|
| 13771 |
-
|
| 13772 |
-
switch (src0->type) {
|
| 13773 |
-
case GGML_TYPE_F16:
|
| 13774 |
-
{
|
| 13775 |
-
ggml_compute_forward_alibi_f16(params, dst);
|
| 13776 |
-
} break;
|
| 13777 |
-
case GGML_TYPE_F32:
|
| 13778 |
-
{
|
| 13779 |
-
ggml_compute_forward_alibi_f32(params, dst);
|
| 13780 |
-
} break;
|
| 13781 |
-
case GGML_TYPE_BF16:
|
| 13782 |
-
case GGML_TYPE_Q4_0:
|
| 13783 |
-
case GGML_TYPE_Q4_1:
|
| 13784 |
-
case GGML_TYPE_Q5_0:
|
| 13785 |
-
case GGML_TYPE_Q5_1:
|
| 13786 |
-
case GGML_TYPE_Q8_0:
|
| 13787 |
-
case GGML_TYPE_Q8_1:
|
| 13788 |
-
case GGML_TYPE_Q2_K:
|
| 13789 |
-
case GGML_TYPE_Q3_K:
|
| 13790 |
-
case GGML_TYPE_Q4_K:
|
| 13791 |
-
case GGML_TYPE_Q5_K:
|
| 13792 |
-
case GGML_TYPE_Q6_K:
|
| 13793 |
-
case GGML_TYPE_IQ2_XXS:
|
| 13794 |
-
case GGML_TYPE_IQ2_XS:
|
| 13795 |
-
case GGML_TYPE_IQ3_XXS:
|
| 13796 |
-
case GGML_TYPE_IQ1_S:
|
| 13797 |
-
case GGML_TYPE_IQ1_M:
|
| 13798 |
-
case GGML_TYPE_IQ4_NL:
|
| 13799 |
-
case GGML_TYPE_IQ4_XS:
|
| 13800 |
-
case GGML_TYPE_IQ3_S:
|
| 13801 |
-
case GGML_TYPE_IQ2_S:
|
| 13802 |
-
case GGML_TYPE_Q8_K:
|
| 13803 |
-
case GGML_TYPE_I8:
|
| 13804 |
-
case GGML_TYPE_I16:
|
| 13805 |
-
case GGML_TYPE_I32:
|
| 13806 |
-
case GGML_TYPE_I64:
|
| 13807 |
-
case GGML_TYPE_F64:
|
| 13808 |
-
case GGML_TYPE_COUNT:
|
| 13809 |
-
{
|
| 13810 |
-
GGML_ASSERT(false);
|
| 13811 |
-
} break;
|
| 13812 |
-
}
|
| 13813 |
-
}
|
| 13814 |
-
|
| 13815 |
// ggml_compute_forward_clamp
|
| 13816 |
|
| 13817 |
static void ggml_compute_forward_clamp_f32(
|
|
@@ -15825,8 +15596,17 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
| 15825 |
const int ir0 = dr*ith;
|
| 15826 |
const int ir1 = MIN(ir0 + dr, nr);
|
| 15827 |
|
| 15828 |
-
float scale = 1.0f;
|
| 15829 |
-
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
| 15830 |
|
| 15831 |
// loop over n_batch and n_head
|
| 15832 |
for (int ir = ir0; ir < ir1; ++ir) {
|
|
@@ -15835,6 +15615,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
| 15835 |
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
|
| 15836 |
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
|
| 15837 |
|
|
|
| 15838 |
float S = 0.0f;
|
| 15839 |
float M = -INFINITY;
|
| 15840 |
|
|
@@ -15858,7 +15641,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
| 15858 |
// loop over n_kv and n_head_kv
|
| 15859 |
// ref: https://arxiv.org/pdf/2112.05682.pdf
|
| 15860 |
for (int64_t ic = 0; ic < nek1; ++ic) {
|
| 15861 |
-
const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
|
| 15862 |
if (mv == -INFINITY) {
|
| 15863 |
continue;
|
| 15864 |
}
|
|
@@ -15929,7 +15712,7 @@ static void ggml_compute_forward_flash_attn_ext(
|
|
| 15929 |
const struct ggml_tensor * v,
|
| 15930 |
const struct ggml_tensor * mask,
|
| 15931 |
struct ggml_tensor * dst) {
|
| 15932 |
-
switch (dst->op_params[
|
| 15933 |
case GGML_PREC_DEFAULT:
|
| 15934 |
case GGML_PREC_F32:
|
| 15935 |
{
|
|
@@ -17696,10 +17479,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
| 17696 |
{
|
| 17697 |
ggml_compute_forward_rope_back(params, tensor);
|
| 17698 |
} break;
|
| 17699 |
-
case GGML_OP_ALIBI:
|
| 17700 |
-
{
|
| 17701 |
-
ggml_compute_forward_alibi(params, tensor);
|
| 17702 |
-
} break;
|
| 17703 |
case GGML_OP_CLAMP:
|
| 17704 |
{
|
| 17705 |
ggml_compute_forward_clamp(params, tensor);
|
|
@@ -18718,10 +18497,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
| 18718 |
zero_table);
|
| 18719 |
}
|
| 18720 |
} break;
|
| 18721 |
-
case GGML_OP_ALIBI:
|
| 18722 |
-
{
|
| 18723 |
-
GGML_ASSERT(false); // TODO: not implemented
|
| 18724 |
-
} break;
|
| 18725 |
case GGML_OP_CLAMP:
|
| 18726 |
{
|
| 18727 |
GGML_ASSERT(false); // TODO: not implemented
|
|
@@ -19499,10 +19274,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
|
|
| 19499 |
{
|
| 19500 |
n_tasks = n_threads;
|
| 19501 |
} break;
|
| 19502 |
-
case GGML_OP_ALIBI:
|
| 19503 |
-
{
|
| 19504 |
-
n_tasks = 1; //TODO
|
| 19505 |
-
} break;
|
| 19506 |
case GGML_OP_CLAMP:
|
| 19507 |
{
|
| 19508 |
n_tasks = 1; //TODO
|
|
|
|
| 2186 |
"SOFT_MAX_BACK",
|
| 2187 |
"ROPE",
|
| 2188 |
"ROPE_BACK",
|
|
|
|
| 2189 |
"CLAMP",
|
| 2190 |
"CONV_TRANSPOSE_1D",
|
| 2191 |
"IM2COL",
|
|
|
|
| 2227 |
"CROSS_ENTROPY_LOSS_BACK",
|
| 2228 |
};
|
| 2229 |
|
| 2230 |
+
static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
|
| 2231 |
|
| 2232 |
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
| 2233 |
"none",
|
|
|
|
| 2276 |
"soft_max_back(x)",
|
| 2277 |
"rope(x)",
|
| 2278 |
"rope_back(x)",
|
|
|
|
| 2279 |
"clamp(x)",
|
| 2280 |
"conv_transpose_1d(x)",
|
| 2281 |
"im2col(x)",
|
|
|
|
| 2317 |
"cross_entropy_loss_back(x,y)",
|
| 2318 |
};
|
| 2319 |
|
| 2320 |
+
static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
|
| 2321 |
|
| 2322 |
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
| 2323 |
|
|
|
|
| 5660 |
struct ggml_context * ctx,
|
| 5661 |
struct ggml_tensor * a,
|
| 5662 |
struct ggml_tensor * mask,
|
|
|
|
| 5663 |
float scale,
|
| 5664 |
float max_bias,
|
| 5665 |
bool inplace) {
|
|
|
|
| 5673 |
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
|
| 5674 |
}
|
| 5675 |
|
|
| 5676 |
if (max_bias > 0.0f) {
|
| 5677 |
+
GGML_ASSERT(mask);
|
| 5678 |
}
|
| 5679 |
|
| 5680 |
bool is_node = false;
|
|
|
|
| 5692 |
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
| 5693 |
result->src[0] = a;
|
| 5694 |
result->src[1] = mask;
|
|
|
|
| 5695 |
|
| 5696 |
return result;
|
| 5697 |
}
|
|
|
|
| 5699 |
struct ggml_tensor * ggml_soft_max(
|
| 5700 |
struct ggml_context * ctx,
|
| 5701 |
struct ggml_tensor * a) {
|
| 5702 |
+
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false);
|
| 5703 |
}
|
| 5704 |
|
| 5705 |
struct ggml_tensor * ggml_soft_max_inplace(
|
| 5706 |
struct ggml_context * ctx,
|
| 5707 |
struct ggml_tensor * a) {
|
| 5708 |
+
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true);
|
| 5709 |
}
|
| 5710 |
|
| 5711 |
struct ggml_tensor * ggml_soft_max_ext(
|
| 5712 |
struct ggml_context * ctx,
|
| 5713 |
struct ggml_tensor * a,
|
| 5714 |
struct ggml_tensor * mask,
|
|
|
|
| 5715 |
float scale,
|
| 5716 |
float max_bias) {
|
| 5717 |
+
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
|
| 5718 |
}
|
| 5719 |
|
| 5720 |
// ggml_soft_max_back
|
|
|
|
| 5929 |
return result;
|
| 5930 |
}
|
| 5931 |
|
| 5932 |
// ggml_clamp
|
| 5933 |
|
| 5934 |
struct ggml_tensor * ggml_clamp(
|
|
|
|
| 6456 |
struct ggml_tensor * k,
|
| 6457 |
struct ggml_tensor * v,
|
| 6458 |
struct ggml_tensor * mask,
|
| 6459 |
+
float scale,
|
| 6460 |
+
float max_bias) {
|
| 6461 |
GGML_ASSERT(ggml_can_mul_mat(k, q));
|
| 6462 |
// TODO: check if vT can be multiplied by (k*qT)
|
| 6463 |
+
|
| 6464 |
if (mask) {
|
| 6465 |
GGML_ASSERT(ggml_is_contiguous(mask));
|
| 6466 |
GGML_ASSERT(mask->ne[2] == 1);
|
|
|
|
| 6470 |
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
|
| 6471 |
}
|
| 6472 |
|
| 6473 |
+
if (max_bias > 0.0f) {
|
| 6474 |
+
GGML_ASSERT(mask);
|
| 6475 |
+
}
|
| 6476 |
+
|
| 6477 |
bool is_node = false;
|
| 6478 |
|
| 6479 |
if (q->grad || k->grad || v->grad) {
|
|
|
|
| 6484 |
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
|
| 6485 |
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
| 6486 |
|
| 6487 |
+
float params[] = { scale, max_bias };
|
| 6488 |
ggml_set_op_params(result, params, sizeof(params));
|
| 6489 |
|
| 6490 |
result->op = GGML_OP_FLASH_ATTN_EXT;
|
|
|
|
| 6504 |
|
| 6505 |
const int32_t prec_i32 = (int32_t) prec;
|
| 6506 |
|
| 6507 |
+
ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
|
| 6508 |
}
|
| 6509 |
|
| 6510 |
// ggml_flash_ff
|
|
|
|
| 13355 |
|
| 13356 |
const struct ggml_tensor * src0 = dst->src[0];
|
| 13357 |
const struct ggml_tensor * src1 = dst->src[1];
|
|
|
|
| 13358 |
|
| 13359 |
assert(ggml_is_contiguous(dst));
|
| 13360 |
assert(ggml_are_same_shape(src0, dst));
|
|
|
|
| 13380 |
|
| 13381 |
// TODO: is this supposed to be ceil instead of floor?
|
| 13382 |
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
|
| 13383 |
+
const uint32_t n_head = ne02;
|
| 13384 |
+
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
| 13385 |
|
| 13386 |
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 13387 |
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
|
|
| 13398 |
|
| 13399 |
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
| 13400 |
|
| 13401 |
+
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
| 13402 |
|
| 13403 |
for (int i1 = ir0; i1 < ir1; i1++) {
|
| 13404 |
+
// ALiBi
|
| 13405 |
+
const uint32_t h = (i1/ne01)%ne02; // head
|
| 13406 |
+
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
| 13407 |
+
|
| 13408 |
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
| 13409 |
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
| 13410 |
|
|
|
|
| 13417 |
if (mp_f32) {
|
| 13418 |
if (use_f16) {
|
| 13419 |
for (int i = 0; i < nc; ++i) {
|
| 13420 |
+
wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
|
| 13421 |
}
|
| 13422 |
} else {
|
| 13423 |
for (int i = 0; i < nc; ++i) {
|
| 13424 |
+
wp[i] += slope*mp_f32[i];
|
| 13425 |
}
|
| 13426 |
}
|
| 13427 |
}
|
|
|
|
| 13583 |
}
|
| 13584 |
}
|
| 13585 |
|
| 13586 |
// ggml_compute_forward_clamp
|
| 13587 |
|
| 13588 |
static void ggml_compute_forward_clamp_f32(
|
|
|
|
| 15596 |
const int ir0 = dr*ith;
|
| 15597 |
const int ir1 = MIN(ir0 + dr, nr);
|
| 15598 |
|
| 15599 |
+
float scale = 1.0f;
|
| 15600 |
+
float max_bias = 0.0f;
|
| 15601 |
+
|
| 15602 |
+
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
| 15603 |
+
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
| 15604 |
+
|
| 15605 |
+
const uint32_t n_head = neq2;
|
| 15606 |
+
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
| 15607 |
+
|
| 15608 |
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 15609 |
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 15610 |
|
| 15611 |
// loop over n_batch and n_head
|
| 15612 |
for (int ir = ir0; ir < ir1; ++ir) {
|
|
|
|
| 15615 |
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
|
| 15616 |
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
|
| 15617 |
|
| 15618 |
+
const uint32_t h = iq2; // head
|
| 15619 |
+
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
| 15620 |
+
|
| 15621 |
float S = 0.0f;
|
| 15622 |
float M = -INFINITY;
|
| 15623 |
|
|
|
|
| 15641 |
// loop over n_kv and n_head_kv
|
| 15642 |
// ref: https://arxiv.org/pdf/2112.05682.pdf
|
| 15643 |
for (int64_t ic = 0; ic < nek1; ++ic) {
|
| 15644 |
+
const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
|
| 15645 |
if (mv == -INFINITY) {
|
| 15646 |
continue;
|
| 15647 |
}
|
|
|
|
| 15712 |
const struct ggml_tensor * v,
|
| 15713 |
const struct ggml_tensor * mask,
|
| 15714 |
struct ggml_tensor * dst) {
|
| 15715 |
+
switch (dst->op_params[2]) {
|
| 15716 |
case GGML_PREC_DEFAULT:
|
| 15717 |
case GGML_PREC_F32:
|
| 15718 |
{
|
|
|
|
| 17479 |
{
|
| 17480 |
ggml_compute_forward_rope_back(params, tensor);
|
| 17481 |
} break;
|
| 17482 |
case GGML_OP_CLAMP:
|
| 17483 |
{
|
| 17484 |
ggml_compute_forward_clamp(params, tensor);
|
|
|
|
| 18497 |
zero_table);
|
| 18498 |
}
|
| 18499 |
} break;
|
| 18500 |
case GGML_OP_CLAMP:
|
| 18501 |
{
|
| 18502 |
GGML_ASSERT(false); // TODO: not implemented
|
|
|
|
| 19274 |
{
|
| 19275 |
n_tasks = n_threads;
|
| 19276 |
} break;
|
| 19277 |
case GGML_OP_CLAMP:
|
| 19278 |
{
|
| 19279 |
n_tasks = 1; //TODO
|
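One detail worth calling out in the ggml.c changes above: GGML_OP_FLASH_ATTN_EXT now carries two floats in op_params (scale at slot 0, max_bias at slot 1), which is why the precision selector moved to slot 2. A small sketch of how a backend unpacks them, mirroring the CPU path above (the helper name is hypothetical):

    #include <string.h>
    #include "ggml.h"

    // dst is a GGML_OP_FLASH_ATTN_EXT tensor; the two floats live in
    // op_params[0..1], the GGML_PREC_* value in op_params[2].
    static void fattn_get_params(const struct ggml_tensor * dst,
                                 float * scale, float * max_bias, int32_t * prec) {
        memcpy(scale,    (const float *) dst->op_params + 0, sizeof(float));
        memcpy(max_bias, (const float *) dst->op_params + 1, sizeof(float));
        *prec = ((const int32_t *) dst->op_params)[2];
    }

|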
ggml.h
CHANGED
|
@@ -468,7 +468,6 @@ extern "C" {
|
|
| 468 |
GGML_OP_SOFT_MAX_BACK,
|
| 469 |
GGML_OP_ROPE,
|
| 470 |
GGML_OP_ROPE_BACK,
|
| 471 |
-
GGML_OP_ALIBI,
|
| 472 |
GGML_OP_CLAMP,
|
| 473 |
GGML_OP_CONV_TRANSPOSE_1D,
|
| 474 |
GGML_OP_IM2COL,
|
|
@@ -1437,15 +1436,13 @@ extern "C" {
|
|
| 1437 |
struct ggml_context * ctx,
|
| 1438 |
struct ggml_tensor * a);
|
| 1439 |
|
| 1440 |
-
// fused soft_max(a*scale + mask + pos[i]*(ALiBi slope))
|
| 1441 |
// mask is optional
|
| 1442 |
-
// pos is required when max_bias > 0.0f
|
| 1443 |
// max_bias = 0.0f for no ALiBi
|
| 1444 |
GGML_API struct ggml_tensor * ggml_soft_max_ext(
|
| 1445 |
struct ggml_context * ctx,
|
| 1446 |
struct ggml_tensor * a,
|
| 1447 |
struct ggml_tensor * mask,
|
| 1448 |
-
struct ggml_tensor * pos,
|
| 1449 |
float scale,
|
| 1450 |
float max_bias);
|
| 1451 |
|
|
@@ -1547,16 +1544,6 @@ extern "C" {
|
|
| 1547 |
float xpos_base,
|
| 1548 |
bool xpos_down);
|
| 1549 |
|
| 1550 |
-
// alibi position embedding
|
| 1551 |
-
// in-place, returns view(a)
|
| 1552 |
-
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_alibi(
|
| 1553 |
-
struct ggml_context * ctx,
|
| 1554 |
-
struct ggml_tensor * a,
|
| 1555 |
-
int n_past,
|
| 1556 |
-
int n_head,
|
| 1557 |
-
float bias_max),
|
| 1558 |
-
"use ggml_soft_max_ext instead (will be removed in Mar 2024)");
|
| 1559 |
-
|
| 1560 |
// clamp
|
| 1561 |
// in-place, returns view(a)
|
| 1562 |
GGML_API struct ggml_tensor * ggml_clamp(
|
|
@@ -1753,7 +1740,8 @@ extern "C" {
|
|
| 1753 |
struct ggml_tensor * k,
|
| 1754 |
struct ggml_tensor * v,
|
| 1755 |
struct ggml_tensor * mask,
|
| 1756 |
-
float scale);
|
|
|
|
| 1757 |
|
| 1758 |
GGML_API void ggml_flash_attn_ext_set_prec(
|
| 1759 |
struct ggml_tensor * a,
|
|
|
|
| 468 |
GGML_OP_SOFT_MAX_BACK,
|
| 469 |
GGML_OP_ROPE,
|
| 470 |
GGML_OP_ROPE_BACK,
|
|
|
|
| 471 |
GGML_OP_CLAMP,
|
| 472 |
GGML_OP_CONV_TRANSPOSE_1D,
|
| 473 |
GGML_OP_IM2COL,
|
|
|
|
| 1436 |
struct ggml_context * ctx,
|
| 1437 |
struct ggml_tensor * a);
|
| 1438 |
|
| 1439 |
+
// fused soft_max(a*scale + mask*(ALiBi slope))
|
| 1440 |
// mask is optional
|
|
|
|
| 1441 |
// max_bias = 0.0f for no ALiBi
|
| 1442 |
GGML_API struct ggml_tensor * ggml_soft_max_ext(
|
| 1443 |
struct ggml_context * ctx,
|
| 1444 |
struct ggml_tensor * a,
|
| 1445 |
struct ggml_tensor * mask,
|
|
|
|
| 1446 |
float scale,
|
| 1447 |
float max_bias);
|
| 1448 |
|
|
|
|
| 1544 |
float xpos_base,
|
| 1545 |
bool xpos_down);
|
| 1546 |
|
| 1547 |
// clamp
|
| 1548 |
// in-place, returns view(a)
|
| 1549 |
GGML_API struct ggml_tensor * ggml_clamp(
|
|
|
|
| 1740 |
struct ggml_tensor * k,
|
| 1741 |
struct ggml_tensor * v,
|
| 1742 |
struct ggml_tensor * mask,
|
| 1743 |
+
float scale,
|
| 1744 |
+
float max_bias);
|
| 1745 |
|
| 1746 |
GGML_API void ggml_flash_attn_ext_set_prec(
|
| 1747 |
struct ggml_tensor * a,
|
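Taken together, the header changes give ggml_flash_attn_ext() the same (scale, max_bias) pair as ggml_soft_max_ext(). A hedged end-to-end usage sketch (q/k/v/mask construction is elided; the names here are assumptions, not part of the patch):

    #include <math.h>
    #include "ggml.h"

    static struct ggml_tensor * build_fattn(
            struct ggml_context * ctx,
            struct ggml_tensor * q, struct ggml_tensor * k,
            struct ggml_tensor * v, struct ggml_tensor * mask,
            int n_embd_head, float max_bias) {
        // mask must be non-NULL when max_bias > 0.0f (ALiBi enabled)
        struct ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask,
                1.0f/sqrtf((float) n_embd_head), max_bias);
        ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32); // optional: F32 accumulation
        return out;
    }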