ggerganov commited on
Commit
6cb3028
·
1 Parent(s): ae75124

metal : fix flash attention kernel requirements (llama/7169)

Browse files

* metal : fix flash attention kernel requirements

ggml-ci

* metal : fix ggml_metal_supports_op

ggml-ci

Files changed (1) hide show
  1. ggml-metal.m +10 -9
ggml-metal.m CHANGED
@@ -635,14 +635,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
635
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
636
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
637
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
638
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
639
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
640
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
641
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
642
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
643
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
644
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
645
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
646
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
647
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
648
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -775,8 +775,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
775
  case GGML_OP_TIMESTEP_EMBEDDING:
776
  case GGML_OP_ARGSORT:
777
  case GGML_OP_LEAKY_RELU:
778
- case GGML_OP_FLASH_ATTN_EXT:
779
  return true;
 
 
780
  case GGML_OP_MUL_MAT:
781
  case GGML_OP_MUL_MAT_ID:
782
  return ctx->support_simdgroup_reduction &&
 
635
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
636
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
637
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
638
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
639
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
640
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
641
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
642
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
643
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
644
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
645
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
646
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
647
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
648
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
 
775
  case GGML_OP_TIMESTEP_EMBEDDING:
776
  case GGML_OP_ARGSORT:
777
  case GGML_OP_LEAKY_RELU:
 
778
  return true;
779
+ case GGML_OP_FLASH_ATTN_EXT:
780
+ return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
781
  case GGML_OP_MUL_MAT:
782
  case GGML_OP_MUL_MAT_ID:
783
  return ctx->support_simdgroup_reduction &&