ggerganov committed on
Commit
98ce302
·
unverified ·
1 Parent(s): 5838a14

metal : support FA without mask + add asserts (llama/7278)

Browse files

* ggml : fa without mask + add asserts

ggml-ci

* metal : support non-contiguous KV

ggml-ci

Files changed (4) hide show
  1. ggml-metal.m +38 -31
  2. ggml-metal.metal +20 -33
  3. ggml.c +10 -0
  4. ggml.h +2 -1
ggml-metal.m CHANGED
@@ -2512,13 +2512,14 @@ static enum ggml_status ggml_metal_graph_compute(
2512
  } break;
2513
  case GGML_OP_FLASH_ATTN_EXT:
2514
  {
2515
- GGML_ASSERT(ne00 % 4 == 0);
 
 
2516
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
2517
 
2518
- struct ggml_tensor * src3 = gf->nodes[i]->src[3];
2519
 
2520
- GGML_ASSERT(ggml_are_same_shape(src1, src2));
2521
- GGML_ASSERT(src3);
2522
 
2523
  size_t offs_src3 = 0;
2524
 
@@ -2528,6 +2529,11 @@ static enum ggml_status ggml_metal_graph_compute(
2528
  GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
2529
  "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2530
 
 
 
 
 
 
2531
  const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
2532
  //const int64_t ne31 = src3 ? src3->ne[1] : 0;
2533
  const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
@@ -2590,34 +2596,35 @@ static enum ggml_status ggml_metal_graph_compute(
2590
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2591
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2592
  [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2593
- [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
 
 
 
 
2594
  [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2595
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
2596
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
2597
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
2598
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
2599
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
2600
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
2601
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
2602
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
2603
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
2604
- [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
2605
- [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
2606
- [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
2607
- [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
2608
- [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
2609
- [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
2610
- [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
2611
- [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:21];
2612
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:22];
2613
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:23];
2614
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:24];
2615
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:25];
2616
- [encoder setBytes:&scale length:sizeof( float) atIndex:26];
2617
- [encoder setBytes:&max_bias length:sizeof( float) atIndex:27];
2618
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:28];
2619
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:29];
2620
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:30];
2621
 
2622
  if (!use_vec_kernel) {
2623
  // half8x8 kernel
 
2512
  } break;
2513
  case GGML_OP_FLASH_ATTN_EXT:
2514
  {
2515
+ GGML_ASSERT(ne00 % 4 == 0);
2516
+ GGML_ASSERT(ne11 % 32 == 0);
2517
+
2518
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
2519
 
2520
+ GGML_ASSERT(ggml_are_same_shape (src1, src2));
2521
 
2522
+ struct ggml_tensor * src3 = gf->nodes[i]->src[3];
 
2523
 
2524
  size_t offs_src3 = 0;
2525
 
 
2529
  GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
2530
  "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2531
 
2532
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
2533
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
2534
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
2535
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0;
2536
+
2537
  const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
2538
  //const int64_t ne31 = src3 ? src3->ne[1] : 0;
2539
  const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
 
2596
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2597
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2598
  [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2599
+ if (id_src3) {
2600
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2601
+ } else {
2602
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
2603
+ }
2604
  [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2605
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2606
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2607
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2608
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2609
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2610
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2611
+ [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
2612
+ [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
2613
+ [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
2614
+ [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
2615
+ [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
2616
+ [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
2617
+ [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
2618
+ [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
2619
+ [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
2620
+ [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
2621
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
2622
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
2623
+ [encoder setBytes:&scale length:sizeof( float) atIndex:23];
2624
+ [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
2625
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
2626
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
2627
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
 
 
 
2628
 
2629
  if (!use_vec_kernel) {
2630
  // half8x8 kernel
ggml-metal.metal CHANGED
@@ -2049,27 +2049,24 @@ typedef void (flash_attn_ext_f16_t)(
2049
  device const char * v,
2050
  device const char * mask,
2051
  device float * dst,
2052
- constant int64_t & ne00,
2053
  constant int64_t & ne01,
2054
  constant int64_t & ne02,
2055
  constant int64_t & ne03,
2056
- constant uint64_t & nb00,
2057
  constant uint64_t & nb01,
2058
  constant uint64_t & nb02,
2059
  constant uint64_t & nb03,
2060
- constant int64_t & ne10,
2061
  constant int64_t & ne11,
2062
  constant int64_t & ne12,
2063
  constant int64_t & ne13,
2064
- constant uint64_t & nb10,
2065
  constant uint64_t & nb11,
2066
  constant uint64_t & nb12,
2067
  constant uint64_t & nb13,
 
 
 
2068
  constant uint64_t & nb31,
2069
- constant int64_t & ne0,
2070
  constant int64_t & ne1,
2071
  constant int64_t & ne2,
2072
- constant int64_t & ne3,
2073
  constant float & scale,
2074
  constant float & max_bias,
2075
  constant float & m0,
@@ -2090,27 +2087,24 @@ kernel void kernel_flash_attn_ext_f16(
2090
  device const char * v,
2091
  device const char * mask,
2092
  device float * dst,
2093
- constant int64_t & ne00,
2094
  constant int64_t & ne01,
2095
  constant int64_t & ne02,
2096
  constant int64_t & ne03,
2097
- constant uint64_t & nb00,
2098
  constant uint64_t & nb01,
2099
  constant uint64_t & nb02,
2100
  constant uint64_t & nb03,
2101
- constant int64_t & ne10,
2102
  constant int64_t & ne11,
2103
  constant int64_t & ne12,
2104
  constant int64_t & ne13,
2105
- constant uint64_t & nb10,
2106
  constant uint64_t & nb11,
2107
  constant uint64_t & nb12,
2108
  constant uint64_t & nb13,
 
 
 
2109
  constant uint64_t & nb31,
2110
- constant int64_t & ne0,
2111
  constant int64_t & ne1,
2112
  constant int64_t & ne2,
2113
- constant int64_t & ne3,
2114
  constant float & scale,
2115
  constant float & max_bias,
2116
  constant float & m0,
@@ -2180,10 +2174,6 @@ kernel void kernel_flash_attn_ext_f16(
2180
  const short ne22 = ne12;
2181
  const short ne23 = ne13;
2182
 
2183
- const uint nb21 = nb11;
2184
- const uint nb22 = nb12;
2185
- const uint nb23 = nb13;
2186
-
2187
  // broadcast
2188
  const short rk2 = ne02/ne12;
2189
  const short rk3 = ne03/ne13;
@@ -2247,11 +2237,16 @@ kernel void kernel_flash_attn_ext_f16(
2247
  simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
2248
  }
2249
 
2250
- // mqk = mqk*scale + mask*slope
2251
- simdgroup_half8x8 mm;
2252
- simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
2253
- simdgroup_multiply(mm, mslope, mm);
2254
- simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
 
 
 
 
 
2255
 
2256
  simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2257
  }
@@ -2425,27 +2420,24 @@ kernel void kernel_flash_attn_ext_vec_f16(
2425
  device const char * v,
2426
  device const char * mask,
2427
  device float * dst,
2428
- constant int64_t & ne00,
2429
  constant int64_t & ne01,
2430
  constant int64_t & ne02,
2431
  constant int64_t & ne03,
2432
- constant uint64_t & nb00,
2433
  constant uint64_t & nb01,
2434
  constant uint64_t & nb02,
2435
  constant uint64_t & nb03,
2436
- constant int64_t & ne10,
2437
  constant int64_t & ne11,
2438
  constant int64_t & ne12,
2439
  constant int64_t & ne13,
2440
- constant uint64_t & nb10,
2441
  constant uint64_t & nb11,
2442
  constant uint64_t & nb12,
2443
  constant uint64_t & nb13,
 
 
 
2444
  constant uint64_t & nb31,
2445
- constant int64_t & ne0,
2446
  constant int64_t & ne1,
2447
  constant int64_t & ne2,
2448
- constant int64_t & ne3,
2449
  constant float & scale,
2450
  constant float & max_bias,
2451
  constant float & m0,
@@ -2521,10 +2513,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
2521
  const short ne22 = ne12;
2522
  const short ne23 = ne13;
2523
 
2524
- const uint nb21 = nb11;
2525
- const uint nb22 = nb12;
2526
- const uint nb23 = nb13;
2527
-
2528
  // broadcast
2529
  const short rk2 = ne02/ne12;
2530
  const short rk3 = ne03/ne13;
@@ -2589,8 +2577,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
2589
 
2590
  // mqk = mqk*scale + mask*slope
2591
  if (tiisg == 0) {
2592
- float4 mm = (float4) mp4[ic/4 + cc];
2593
- mqk = mqk*scale + mm*slope;
2594
 
2595
  ss4[cc] = mqk;
2596
  }
 
2049
  device const char * v,
2050
  device const char * mask,
2051
  device float * dst,
 
2052
  constant int64_t & ne01,
2053
  constant int64_t & ne02,
2054
  constant int64_t & ne03,
 
2055
  constant uint64_t & nb01,
2056
  constant uint64_t & nb02,
2057
  constant uint64_t & nb03,
 
2058
  constant int64_t & ne11,
2059
  constant int64_t & ne12,
2060
  constant int64_t & ne13,
 
2061
  constant uint64_t & nb11,
2062
  constant uint64_t & nb12,
2063
  constant uint64_t & nb13,
2064
+ constant uint64_t & nb21,
2065
+ constant uint64_t & nb22,
2066
+ constant uint64_t & nb23,
2067
  constant uint64_t & nb31,
 
2068
  constant int64_t & ne1,
2069
  constant int64_t & ne2,
 
2070
  constant float & scale,
2071
  constant float & max_bias,
2072
  constant float & m0,
 
2087
  device const char * v,
2088
  device const char * mask,
2089
  device float * dst,
 
2090
  constant int64_t & ne01,
2091
  constant int64_t & ne02,
2092
  constant int64_t & ne03,
 
2093
  constant uint64_t & nb01,
2094
  constant uint64_t & nb02,
2095
  constant uint64_t & nb03,
 
2096
  constant int64_t & ne11,
2097
  constant int64_t & ne12,
2098
  constant int64_t & ne13,
 
2099
  constant uint64_t & nb11,
2100
  constant uint64_t & nb12,
2101
  constant uint64_t & nb13,
2102
+ constant uint64_t & nb21,
2103
+ constant uint64_t & nb22,
2104
+ constant uint64_t & nb23,
2105
  constant uint64_t & nb31,
 
2106
  constant int64_t & ne1,
2107
  constant int64_t & ne2,
 
2108
  constant float & scale,
2109
  constant float & max_bias,
2110
  constant float & m0,
 
2174
  const short ne22 = ne12;
2175
  const short ne23 = ne13;
2176
 
 
 
 
 
2177
  // broadcast
2178
  const short rk2 = ne02/ne12;
2179
  const short rk3 = ne03/ne13;
 
2237
  simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
2238
  }
2239
 
2240
+ if (mask != q) {
2241
+ // mqk = mqk*scale + mask*slope
2242
+ simdgroup_half8x8 mm;
2243
+ simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
2244
+ simdgroup_multiply(mm, mslope, mm);
2245
+ simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
2246
+ } else {
2247
+ // mqk = mqk*scale
2248
+ simdgroup_multiply(mqk, mscale, mqk);
2249
+ }
2250
 
2251
  simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2252
  }
 
2420
  device const char * v,
2421
  device const char * mask,
2422
  device float * dst,
 
2423
  constant int64_t & ne01,
2424
  constant int64_t & ne02,
2425
  constant int64_t & ne03,
 
2426
  constant uint64_t & nb01,
2427
  constant uint64_t & nb02,
2428
  constant uint64_t & nb03,
 
2429
  constant int64_t & ne11,
2430
  constant int64_t & ne12,
2431
  constant int64_t & ne13,
 
2432
  constant uint64_t & nb11,
2433
  constant uint64_t & nb12,
2434
  constant uint64_t & nb13,
2435
+ constant uint64_t & nb21,
2436
+ constant uint64_t & nb22,
2437
+ constant uint64_t & nb23,
2438
  constant uint64_t & nb31,
 
2439
  constant int64_t & ne1,
2440
  constant int64_t & ne2,
 
2441
  constant float & scale,
2442
  constant float & max_bias,
2443
  constant float & m0,
 
2513
  const short ne22 = ne12;
2514
  const short ne23 = ne13;
2515
 
 
 
 
 
2516
  // broadcast
2517
  const short rk2 = ne02/ne12;
2518
  const short rk3 = ne03/ne13;
 
2577
 
2578
  // mqk = mqk*scale + mask*slope
2579
  if (tiisg == 0) {
2580
+ mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
 
2581
 
2582
  ss4[cc] = mqk;
2583
  }
ggml.c CHANGED
@@ -2824,6 +2824,16 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor
2824
  (t0->ne[3] == t1->ne[3] );
2825
  }
2826
 
 
 
 
 
 
 
 
 
 
 
2827
  // check if t1 can be represented as a repeatition of t0
2828
  static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
2829
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
2824
  (t0->ne[3] == t1->ne[3] );
2825
  }
2826
 
2827
+ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
2828
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
2829
+
2830
+ return
2831
+ (t0->nb[0] == t1->nb[0] ) &&
2832
+ (t0->nb[1] == t1->nb[1] ) &&
2833
+ (t0->nb[2] == t1->nb[2] ) &&
2834
+ (t0->nb[3] == t1->nb[3] );
2835
+ }
2836
+
2837
  // check if t1 can be represented as a repeatition of t0
2838
  static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
2839
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
ggml.h CHANGED
@@ -766,7 +766,8 @@ extern "C" {
766
  GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor);
767
  GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars
768
 
769
- GGML_API bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
 
770
 
771
  // use this to compute the memory overhead of a tensor
772
  GGML_API size_t ggml_tensor_overhead(void);
 
766
  GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor);
767
  GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars
768
 
769
+ GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
770
+ GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
771
 
772
  // use this to compute the memory overhead of a tensor
773
  GGML_API size_t ggml_tensor_overhead(void);