compilade and ggerganov committed
Commit f1abcb4
Parent(s): 4b1fda0

llama : simplify Mamba with advanced batch splits (llama/8526)


* llama : advanced batch splits

This includes equal-sequence-length batch splits, which are useful
for simplifying recurrent model operators.

* llama : always make recurrent state slots contiguous

* ggml : simplify mamba operators

* llama : fix integer signedness mixing

* llama : logits_all has priority over batch->logits

Otherwise, the server embeddings tests failed.
This was likely an existing problem but was only detected here
because of an additional assertion.

* llama : apply suggestions

Co-authored-by: Georgi Gerganov <[email protected]>

* llama : fix t5 segfault

* llama : fix Mamba session save and restore

* llama : minor cosmetic changes

* llama : rename llama_reorder_outputs to llama_output_reorder

Also move it closer to llama_output_reserve.

* llama : fix pooled embeddings when using batches with equal_seqs

* minor : add struct members for clarity

ggml-ci

* llama : fix T5 segfault again

* llama : fix Mamba pooled embeddings with multiple sequences

Until the pooled embeddings are refactored to allow splitting
across ubatches for causal embeddings,
recurrent models can only process a single sequence per ubatch
when calculating pooled embeddings.

* llama : add llama_model_is_recurrent to simplify figuring that out

This will make it easier to support RWKV-v6 and Mamba-2 more cleanly (a usage sketch follows below).

* llama : fix simple splits when the batch contains embeddings

---------

Co-authored-by: Georgi Gerganov <[email protected]>
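
The commit message above mentions the new llama_model_is_recurrent helper. It lives in llama.cpp's public llama.h rather than in the two ggml files synced here; the following is a minimal usage sketch under that assumption, using the llama_model_default_params / llama_load_model_from_file entry points and a placeholder model path:

#include <stdio.h>
#include "llama.h"

int main(void) {
    // placeholder path; point this at a real GGUF model
    const char * path = "mamba-130m.Q8_0.gguf";

    struct llama_model_params mparams = llama_model_default_params();
    struct llama_model * model = llama_load_model_from_file(path, mparams);
    if (model == NULL) {
        return 1;
    }

    // recurrent models (e.g. Mamba) keep fixed-size per-sequence states instead of a
    // growing KV cache, so callers may want to branch on this before sizing caches
    // or deciding how to split batches
    if (llama_model_is_recurrent(model)) {
        printf("recurrent model\n");
    } else {
        printf("attention-based model\n");
    }

    llama_free_model(model);
    return 0;
}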

Files changed (2)
  1. ggml/include/ggml.h +3 -6
  2. ggml/src/ggml.c +92 -185
ggml/include/ggml.h CHANGED
@@ -1824,10 +1824,8 @@ extern "C" {
 
     GGML_API struct ggml_tensor * ggml_ssm_conv(
             struct ggml_context * ctx,
-            struct ggml_tensor  * s,
-            struct ggml_tensor  * x,
-            struct ggml_tensor  * c,
-            struct ggml_tensor  * sq);
+            struct ggml_tensor  * sx,
+            struct ggml_tensor  * c);
 
     GGML_API struct ggml_tensor * ggml_ssm_scan(
             struct ggml_context * ctx,
@@ -1836,8 +1834,7 @@ extern "C" {
             struct ggml_tensor  * dt,
             struct ggml_tensor  * A,
             struct ggml_tensor  * B,
-            struct ggml_tensor  * C,
-            struct ggml_tensor  * sq);
+            struct ggml_tensor  * C);
 
     // partition into non-overlapping windows with padding if needed
     // example:
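
To make the new signatures concrete, here is a sketch of how a caller might build the two ops. The expected shapes follow the asserts added in the ggml.c hunks below; the dimensions, buffer size, and variable names are illustrative only and not from the commit:

#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // example dimensions (arbitrary)
    const int64_t d_conv = 4, d_inner = 8, d_state = 16, n_seq_tokens = 5, n_seqs = 2;

    // ggml_ssm_conv: sx holds the last (d_conv - 1) conv-state columns concatenated
    // with the new tokens, per sequence; c is the conv1d weight
    struct ggml_tensor * sx = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_conv - 1 + n_seq_tokens, d_inner, n_seqs);
    struct ggml_tensor * c  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_conv, d_inner);
    struct ggml_tensor * conv_out = ggml_ssm_conv(ctx, sx, c); // -> {d_inner, n_seq_tokens, n_seqs}

    // ggml_ssm_scan: per-sequence states plus token-major inputs
    struct ggml_tensor * s  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs);
    struct ggml_tensor * x  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
    struct ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
    struct ggml_tensor * A  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, d_inner);
    struct ggml_tensor * B  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
    struct ggml_tensor * C  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
    struct ggml_tensor * y_and_states = ggml_ssm_scan(ctx, s, x, dt, A, B, C); // concatenated y + updated states

    (void) conv_out; (void) y_and_states;
    ggml_free(ctx);
    return 0;
}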
ggml/src/ggml.c CHANGED
@@ -7384,43 +7384,34 @@ struct ggml_tensor * ggml_flash_attn_back(
 
 struct ggml_tensor * ggml_ssm_conv(
         struct ggml_context * ctx,
-        struct ggml_tensor  * s,
-        struct ggml_tensor  * x,
-        struct ggml_tensor  * c,
-        struct ggml_tensor  * sq) {
-    GGML_ASSERT(ggml_is_3d(s));
-    GGML_ASSERT(ggml_is_matrix(x));
+        struct ggml_tensor  * sx,
+        struct ggml_tensor  * c) {
+    GGML_ASSERT(ggml_is_3d(sx));
     GGML_ASSERT(ggml_is_matrix(c));
-    GGML_ASSERT(ggml_is_matrix(sq));
-    GGML_ASSERT(sq->type == GGML_TYPE_I32);
 
-    const int64_t d_conv   = c->ne[0];
-    const int64_t d_inner  = c->ne[1];
-    const int64_t n_tokens = x->ne[1];
-    const int64_t n_kv     = s->ne[2];
+    const int64_t d_conv  = c->ne[0];
+    const int64_t d_inner = c->ne[1];
+    const int64_t n_t     = sx->ne[0] - d_conv + 1; // tokens per sequence
+    const int64_t n_s     = sx->ne[2];
 
-    GGML_ASSERT( s->ne[0] == d_conv - 1);
-    GGML_ASSERT( s->ne[1] == d_inner);
-    GGML_ASSERT( x->ne[0] == d_inner);
-    GGML_ASSERT(sq->ne[0] == n_kv);
-    GGML_ASSERT(sq->ne[1] == n_tokens);
+    // TODO: maybe support other strides than 1?
+    GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
+    GGML_ASSERT(sx->ne[1] == d_inner);
+    GGML_ASSERT(n_t >= 0);
 
     bool is_node = false;
 
-    if (s->grad || x->grad || c->grad || sq->grad) {
+    if (sx->grad || c->grad) {
         GGML_ABORT("fatal error"); // TODO: implement
         is_node = true;
     }
 
-    // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
-    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
+    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);
 
     result->op   = GGML_OP_SSM_CONV;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = s;
-    result->src[1] = x;
-    result->src[2] = c;
-    result->src[3] = sq;
+    result->src[0] = sx;
+    result->src[1] = c;
 
     return result;
 }
@@ -7434,39 +7425,42 @@ struct ggml_tensor * ggml_ssm_scan(
         struct ggml_tensor  * dt,
         struct ggml_tensor  * A,
         struct ggml_tensor  * B,
-        struct ggml_tensor  * C,
-        struct ggml_tensor  * sq) {
+        struct ggml_tensor  * C) {
     GGML_ASSERT(ggml_is_contiguous(s));
     GGML_ASSERT(ggml_is_contiguous(x));
     GGML_ASSERT(ggml_is_contiguous(dt));
     GGML_ASSERT(ggml_is_contiguous(A));
-    GGML_ASSERT(sq->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_matrix(A));
+    GGML_ASSERT(ggml_is_3d(B));
+    GGML_ASSERT(ggml_is_3d(s));
     GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
     GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
     GGML_ASSERT(ggml_are_same_shape(x, dt));
+    GGML_ASSERT(ggml_are_same_shape(B, C));
 
     {
-        const int64_t d_state  = s->ne[0];
-        const int64_t d_inner  = s->ne[1];
-        const int64_t n_tokens = x->ne[1];
+        const int64_t d_state      = s->ne[0];
+        const int64_t d_inner      = s->ne[1];
+        const int64_t n_seq_tokens = x->ne[1];
+        const int64_t n_seqs       = x->ne[2];
 
+        GGML_ASSERT(s->ne[2] == n_seqs);
         GGML_ASSERT(x->ne[0] == d_inner);
         GGML_ASSERT(A->ne[0] == d_state);
         GGML_ASSERT(A->ne[1] == d_inner);
         GGML_ASSERT(B->ne[0] == d_state);
-        GGML_ASSERT(B->ne[1] == n_tokens);
-        GGML_ASSERT(C->ne[0] == d_state);
-        GGML_ASSERT(C->ne[1] == n_tokens);
+        GGML_ASSERT(B->ne[1] == n_seq_tokens);
+        GGML_ASSERT(B->ne[2] == n_seqs);
     }
 
     bool is_node = false;
 
-    if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) {
+    if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) {
         GGML_ABORT("fatal error"); // TODO: implement
         is_node = true;
     }
 
-    // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
+    // concatenated y + ssm_states
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
 
     result->op = GGML_OP_SSM_SCAN;
@@ -7477,7 +7471,6 @@ struct ggml_tensor * ggml_ssm_scan(
     result->src[3] = A;
     result->src[4] = B;
     result->src[5] = C;
-    result->src[6] = sq;
 
     return result;
 }
@@ -11254,11 +11247,6 @@ static void ggml_compute_forward_concat_f32(
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    // TODO: support for transposed / permuted tensors
-    GGML_ASSERT(nb0  == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-    GGML_ASSERT(nb10 == sizeof(float));
-
     const int32_t dim = ggml_get_op_params_i32(dst, 0);
 
     GGML_ASSERT(dim >= 0 && dim < 4);
@@ -16256,27 +16244,22 @@ static void ggml_compute_forward_flash_attn_back(
 static void ggml_compute_forward_ssm_conv_f32(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0]; // conv_state
-    const struct ggml_tensor * src1 = dst->src[1]; // x
-    const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
-    const struct ggml_tensor * src3 = dst->src[3]; // state_seq
+    const struct ggml_tensor * src0 = dst->src[0]; // conv_x
+    const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc   = src2->ne[0]; // d_conv
-    const int nr   = src0->ne[1]; // d_inner
-    const int n_t  = src1->ne[1]; // n_tokens
-    const int n_kv = src0->ne[2]; // max number of sequences in the batch
+    const int nc  = src1->ne[0]; // d_conv
+    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
+    const int nr  = src0->ne[1]; // d_inner
+    const int n_t =  dst->ne[1]; // tokens per sequence
+    const int n_s =  dst->ne[2]; // number of sequences in the batch
 
-    GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
+    GGML_ASSERT( dst->ne[0] == nr);
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
-    GGML_ASSERT(src2->nb[0] == sizeof(float));
-    GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    // for use with the destination state offset between sequences
-    GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -16286,74 +16269,27 @@ static void ggml_compute_forward_ssm_conv_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;
 
-    if (n_kv > 1) {
-        // multiple sequences means it's hard to know when it's the first time a state is read,
-        // so copy them all over to the destination, just to be sure.
-        for (int i3 = 0; i3 < n_kv; ++i3) {
-            float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
-            float * s  = (float *) ((char *)  dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
-            // can't use memcpy because of d_conv vs d_conv - 1
-            for (int i1 = 0; i1 < ir; ++i1) {
-                for (int i0 = 0; i0 < nc - 1; ++i0) {
-                    // copy s0 to last (d_conv - 1) columns of s
-                    s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
-                }
-            }
-        }
-    }
-
-    for (int i2 = 0; i2 < n_t; ++i2) {
-        int32_t * sq = (int32_t *) ((char *) src3->data +  i2*(src3->nb[1])); // {n_kv, n_tokens}
-        float *    x = (float *) ((char *)  dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
-        float *    s = (float *) ((char *)  dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
-        float *   s0; // {d_conv - 1, d_inner, n_kv}
-        float *   x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
-        float *    c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
-        int ne0s0;
-
-        GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
-
-        // avoid needing to copy the state for the first token
-        if (i2 == 0) {
-            s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
-            ne0s0 = src0->ne[0];
-        } else {
-            // the source is the last (d_conv - 1) columns of the destination
-            s0 = s + 1;
-            ne0s0 = nc;
-        }
-
-        // d_inner
-        for (int i1 = 0; i1 < ir; ++i1) {
-            // shift state left
-            for (int i0 = 0; i0 < nc - 1; ++i0) {
-                s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
-            }
-            // insert x on the last column
-            s[(nc - 1) + i1*nc] = x0[i1];
-        }
-
-        // handle copies when there are multiple output states
-        for (int i3 = 1; i3 < n_kv; ++i3) {
-            int32_t seq = sq[i3];
-            if (0 <= seq && seq < n_kv) {
-                float * s1 = s + (seq - sq[0])*nc*nr;
-                memcpy(s1, s, nc*ir*sizeof(float));
-            } else {
-                // stop at negative or too big seq_ids
-                break;
-            }
-        }
-
-        // it seems a little faster when this is separate from the state shift
-        for (int i1 = 0; i1 < ir; ++i1) {
-            // rowwise dot product
-            float sumf = 0.0f;
-            for (int i0 = 0; i0 < nc; ++i0) {
-                int i = i0 + i1*nc;
-                sumf += s[i] * c[i];
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            // {d_conv - 1 + n_t, d_inner, n_seqs}
+            // sliding window
+            const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
+            const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
+            float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
+
+            // TODO: transpose the output for smaller strides for big batches?
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // rowwise dot product
+                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
+                float sumf = 0.0f;
+
+                // d_conv
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
+                }
+                x[i1] = sumf;
             }
-            x[i1] = sumf;
         }
     }
 }
@@ -16384,15 +16320,14 @@ static void ggml_compute_forward_ssm_scan_f32(
     const struct ggml_tensor * src3 = dst->src[3]; // A
     const struct ggml_tensor * src4 = dst->src[4]; // B
     const struct ggml_tensor * src5 = dst->src[5]; // C
-    const struct ggml_tensor * src6 = dst->src[6]; // sq
 
    const int ith = params->ith;
    const int nth = params->nth;
 
-    const int64_t nc   = src0->ne[0]; // d_state
-    const int64_t nr   = src0->ne[1]; // d_inner
-    const int64_t n_t  = src1->ne[1]; // number of tokens in the batch
-    const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
+    const int64_t nc  = src0->ne[0]; // d_state
+    const int64_t nr  = src0->ne[1]; // d_inner
+    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
+    const int64_t n_s = src0->ne[2]; // number of sequences in the batch
 
     GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -16401,12 +16336,12 @@
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
    GGML_ASSERT(src5->nb[0] == sizeof(float));
-    // required for the dot product between s and C, and when copying the states
+    // required for the dot product between s and C
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
     // required for per-sequence offsets for states
     GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[2])
-    GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
+    // required to get correct offset for state destination (i.e. src1->nb[3])
+    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -16416,64 +16351,36 @@
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;
 
-    if (n_kv > 1) {
-        // it's hard to know if the source states have already been copied
-        // when there are multiple, so copy them already.
-        for (int i3 = 0; i3 < n_kv; ++i3) {
-            float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
-            float * s  = (float *) ((char *)  dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
-            memcpy(s, s0, nc*ir*sizeof(float));
-        }
-    }
-
-    for (int i2 = 0; i2 < n_t; ++i2) {
-        int32_t * sq = (int32_t *) ((char *) src6->data +  i2*(src6->nb[1])); // {n_kv, n_tokens}
-        float *    y = (float *) ((char *)  dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
-        float *    s = (float *) ((char *)  dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
-        float *   s0;
-        float *    x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
-        float *   dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
-        float *    A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float *    B = (float *) ((char *) src4->data +  i2*(src4->nb[1])); // {d_state, n_tokens}
-        float *    C = (float *) ((char *) src5->data +  i2*(src5->nb[1])); // {d_state, n_tokens}
-
-        GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
-
-        // avoid needing to copy the state for the first token
-        if (i2 == 0) {
-            s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
-        } else {
-            // otherwise the source is the same as the destination
-            s0 = s;
-        }
-
-        // d_inner
-        for (int i1 = 0; i1 < ir; ++i1) {
-            // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
-            float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-            float x_dt = x[i1] * dt_soft_plus;
-            float sumf = 0.0f;
-            // d_state
-            for (int i0 = 0; i0 < nc; ++i0) {
-                int i = i0 + i1*nc;
-                // state = prev_state * dA + dB * x
-                float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-                // y = rowwise_dotprod(state, C)
-                sumf += state * C[i0];
-                s[i] = state;
-            }
-            y[i1] = sumf;
-        }
-
-        // handle copies when there are multiple output states
-        for (int i3 = 1; i3 < n_kv; ++i3) {
-            int32_t seq = sq[i3];
-            if (0 <= seq && seq < n_kv) {
-                float * s1 = s + (seq - sq[0])*nc*nr;
-                memcpy(s1, s, nc*ir*sizeof(float));
-            } else {
-                // stop at negative or too big seq_ids
-                break;
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data +  i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data +  i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                float sumf = 0.0f;
+                // d_state
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    int i = i0 + i1*nc;
+                    // state = prev_state * dA + dB * x
+                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                    // y = rowwise_dotprod(state, C)
+                    sumf += state * C[i0];
+                    s[i] = state;
+                }
+                y[i1] = sumf;
             }
         }
     }
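
For reference, the per-row update that the rewritten ssm_scan loop performs for each token can be written as a small standalone function; the name and signature below are mine (not part of ggml), but the arithmetic mirrors the inner loop above:

#include <math.h>

// One selective-state-space update step for a single d_inner row:
//   state[i] = state[i] * exp(softplus(dt) * A[i]) + B[i] * x * softplus(dt)
//   y        = dot(state, C)
static float ssm_scan_row_step(float * state, const float * A,
                               const float * B, const float * C,
                               float x, float dt, int d_state) {
    // softplus with the same cutoff used in the kernel above
    const float dt_sp = dt <= 20.0f ? log1pf(expf(dt)) : dt;
    const float x_dt  = x * dt_sp;

    float y = 0.0f;
    for (int i = 0; i < d_state; ++i) {
        const float s = state[i] * expf(dt_sp * A[i]) + B[i] * x_dt;
        y += s * C[i];
        state[i] = s;
    }
    return y;
}

Because each token reads the state written by the previous one (s0 = s when i2 > 0), the scan is sequential over the tokens of a sequence but independent across sequences and across d_inner rows, which is exactly what the i3 loop and the row-wise threading exploit.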