llama : simplify Mamba with advanced batch splits (llama/8526)
* llama : advanced batch splits
This includes equal-sequence-length batch splits, which are useful
for simplifying recurrent model operators.
* llama : always make recurrent state slots contiguous
* ggml : simplify mamba operators
* llama : fix integer signedness mixing
* llama : logits_all has priority over batch->logits
Otherwise, the server embeddings tests failed.
This was likely an existing problem but was only detected here
because of an additional assertion.
* llama : apply suggestions
Co-authored-by: Georgi Gerganov <[email protected]>
* llama : fix t5 segfault
* llama : fix Mamba session save and restore
* llama : minor cosmetic changes
* llama : rename llama_reorder_outputs to llama_output_reorder
Also move it closer to llama_output_reserve.
* llama : fix pooled embeddings when using batches with equal_seqs
* minor : add struct members for clarity
ggml-ci
* llama : fix T5 segfault again
* llama : fix Mamba pooled embeddings with multiple sequences
Until the pooled embeddings are refactored to allow splitting
across ubatches for causal embeddings,
recurrent models can only process a single sequence per ubatch
when calculating pooled embeddings.
* llama : add llama_model_is_recurrent to simplify figuring out whether a model is recurrent
This will make it easier to cleanly support RWKV-v6 and Mamba-2
(a usage sketch follows after these notes).
* llama : fix simple splits when the batch contains embeddings
---------
Co-authored-by: Georgi Gerganov <[email protected]>
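
The new llama_model_is_recurrent API gives callers a direct way to check for this class of models. A minimal usage sketch, assuming only the public llama.h declaration added in this commit; the helper and the pooled-embeddings rule it encodes are illustrative, taken from the notes above rather than from the code:

    #include "llama.h"
    #include <stdbool.h>

    // Illustrative helper, not part of this commit: recurrent models currently
    // handle only a single sequence per ubatch when pooled embeddings are
    // being calculated, so the caller may want to split batches accordingly.
    static bool needs_single_seq_ubatches(const struct llama_model * model, bool pooled_embeddings) {
        return pooled_embeddings && llama_model_is_recurrent(model);
    }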
- ggml/include/ggml.h +3 -6
- ggml/src/ggml.c +92 -185
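
For orientation before reading the diff: ggml_ssm_conv now takes a single 3-d input sx per sequence (the previous conv state concatenated with the new tokens along dim 0) together with the conv weight c, and both SSM operators lose the state_seq (sq) tensor, since each ubatch now contains whole sequences of equal length. A hedged sketch of the new call shapes, inferred from the asserts in the diff below (tensor names and dimension variables are illustrative, not code from this commit):

    // illustrative shapes only; see the asserts in ggml_ssm_conv/ggml_ssm_scan below
    // sx: {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}, c: {d_conv, d_inner}
    struct ggml_tensor * sx = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_conv - 1 + n_seq_tokens, d_inner, n_seqs);
    struct ggml_tensor * c  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_conv, d_inner);
    struct ggml_tensor * xc = ggml_ssm_conv(ctx, sx, c); // -> {d_inner, n_seq_tokens, n_seqs}

    // ggml_ssm_scan(ctx, s, x, dt, A, B, C) expects:
    //   s     : {d_state, d_inner, n_seqs}           previous ssm states
    //   x, dt : {d_inner, n_seq_tokens, n_seqs}
    //   A     : {d_state, d_inner}
    //   B, C  : {d_state, n_seq_tokens, n_seqs}
    // and returns a 1-d tensor holding y (same size as x) followed by the updated states.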
ggml/include/ggml.h

@@ -1824,10 +1824,8 @@ extern "C" {
 
     GGML_API struct ggml_tensor * ggml_ssm_conv(
             struct ggml_context * ctx,
-            struct ggml_tensor  * s,
-            struct ggml_tensor  * x,
-            struct ggml_tensor  * c,
-            struct ggml_tensor  * sq);
+            struct ggml_tensor  * sx,
+            struct ggml_tensor  * c);
 
     GGML_API struct ggml_tensor * ggml_ssm_scan(
             struct ggml_context * ctx,
@@ -1836,8 +1834,7 @@ extern "C" {
             struct ggml_tensor  * dt,
             struct ggml_tensor  * A,
             struct ggml_tensor  * B,
-            struct ggml_tensor  * C,
-            struct ggml_tensor  * sq);
+            struct ggml_tensor  * C);
 
     // partition into non-overlapping windows with padding if needed
     // example:

ggml/src/ggml.c

@@ -7384,43 +7384,34 @@ struct ggml_tensor * ggml_flash_attn_back(
 
 struct ggml_tensor * ggml_ssm_conv(
         struct ggml_context * ctx,
-        struct ggml_tensor  * s,
-        struct ggml_tensor  * x,
-        struct ggml_tensor  * c,
-        struct ggml_tensor  * sq) {
-    GGML_ASSERT(ggml_is_3d(s));
-    GGML_ASSERT(ggml_is_matrix(x));
+        struct ggml_tensor  * sx,
+        struct ggml_tensor  * c) {
+    GGML_ASSERT(ggml_is_3d(sx));
     GGML_ASSERT(ggml_is_matrix(c));
-    GGML_ASSERT(ggml_is_matrix(sq));
-    GGML_ASSERT(sq->type == GGML_TYPE_I32);
 
-    const int64_t d_conv   = c->ne[0];
-    const int64_t d_inner  = c->ne[1];
-    const int64_t n_tokens = x->ne[1];
-    const int64_t n_kv     = s->ne[2];
+    const int64_t d_conv  = c->ne[0];
+    const int64_t d_inner = c->ne[1];
+    const int64_t n_t     = sx->ne[0] - d_conv + 1; // tokens per sequence
+    const int64_t n_s     = sx->ne[2];
 
-    GGML_ASSERT( s->ne[0] == d_conv - 1);
-    GGML_ASSERT( s->ne[1] == d_inner);
-    GGML_ASSERT( x->ne[0] == d_inner);
-    GGML_ASSERT(sq->ne[0] == n_kv);
-    GGML_ASSERT(sq->ne[1] == n_tokens);
+    // TODO: maybe support other strides than 1?
+    GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
+    GGML_ASSERT(sx->ne[1] == d_inner);
+    GGML_ASSERT(n_t >= 0);
 
     bool is_node = false;
 
-    if (s->grad || x->grad || c->grad || sq->grad) {
+    if (sx->grad || c->grad) {
        GGML_ABORT("fatal error"); // TODO: implement
        is_node = true;
     }
 
-    // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
-    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
+    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);
 
     result->op   = GGML_OP_SSM_CONV;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = s;
-    result->src[1] = x;
-    result->src[2] = c;
-    result->src[3] = sq;
+    result->src[0] = sx;
+    result->src[1] = c;
 
     return result;
 }
@@ -7434,39 +7425,42 @@ struct ggml_tensor * ggml_ssm_scan(
         struct ggml_tensor  * dt,
         struct ggml_tensor  * A,
         struct ggml_tensor  * B,
-        struct ggml_tensor  * C,
-        struct ggml_tensor  * sq) {
+        struct ggml_tensor  * C) {
     GGML_ASSERT(ggml_is_contiguous(s));
     GGML_ASSERT(ggml_is_contiguous(x));
     GGML_ASSERT(ggml_is_contiguous(dt));
     GGML_ASSERT(ggml_is_contiguous(A));
-    GGML_ASSERT(sq->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_matrix(A));
+    GGML_ASSERT(ggml_is_3d(B));
+    GGML_ASSERT(ggml_is_3d(s));
     GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
     GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
     GGML_ASSERT(ggml_are_same_shape(x, dt));
+    GGML_ASSERT(ggml_are_same_shape(B, C));
 
     {
-        const int64_t d_state  = s->ne[0];
-        const int64_t d_inner  = s->ne[1];
-        const int64_t n_tokens = x->ne[1];
+        const int64_t d_state      = s->ne[0];
+        const int64_t d_inner      = s->ne[1];
+        const int64_t n_seq_tokens = x->ne[1];
+        const int64_t n_seqs       = x->ne[2];
 
+        GGML_ASSERT(s->ne[2] == n_seqs);
         GGML_ASSERT(x->ne[0] == d_inner);
         GGML_ASSERT(A->ne[0] == d_state);
         GGML_ASSERT(A->ne[1] == d_inner);
         GGML_ASSERT(B->ne[0] == d_state);
-        GGML_ASSERT(B->ne[1] == n_tokens);
-        GGML_ASSERT(C->ne[0] == d_state);
-        GGML_ASSERT(C->ne[1] == n_tokens);
+        GGML_ASSERT(B->ne[1] == n_seq_tokens);
+        GGML_ASSERT(B->ne[2] == n_seqs);
     }
 
     bool is_node = false;
 
-    if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) {
+    if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) {
        GGML_ABORT("fatal error"); // TODO: implement
        is_node = true;
     }
 
-    // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
+    // concatenated y + ssm_states
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
 
     result->op = GGML_OP_SSM_SCAN;
@@ -7477,7 +7471,6 @@ struct ggml_tensor * ggml_ssm_scan(
     result->src[3] = A;
     result->src[4] = B;
     result->src[5] = C;
-    result->src[6] = sq;
 
     return result;
 }
@@ -11254,11 +11247,6 @@ static void ggml_compute_forward_concat_f32(
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    // TODO: support for transposed / permuted tensors
-    GGML_ASSERT(nb0  == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-    GGML_ASSERT(nb10 == sizeof(float));
-
     const int32_t dim = ggml_get_op_params_i32(dst, 0);
 
     GGML_ASSERT(dim >= 0 && dim < 4);
@@ -16256,27 +16244,22 @@ static void ggml_compute_forward_flash_attn_back(
 static void ggml_compute_forward_ssm_conv_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0]; // conv_state
-    const struct ggml_tensor * src1 = dst->src[1]; // x
-    const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
-    const struct ggml_tensor * src3 = dst->src[3]; // state_seq
+    const struct ggml_tensor * src0 = dst->src[0]; // conv_x
+    const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc   = src2->ne[0]; // d_conv
-    const int nr   = src0->ne[1]; // d_inner
-    const int n_t  = src1->ne[1]; // n_tokens
-    const int n_kv = src0->ne[2]; // max number of sequences in the batch
+    const int nc  = src1->ne[0]; // d_conv
+    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
+    const int nr  = src0->ne[1]; // d_inner
+    const int n_t = dst->ne[1];  // tokens per sequence
+    const int n_s = dst->ne[2];  // number of sequences in the batch
 
-    GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
+    GGML_ASSERT( dst->ne[0] == nr);
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
-    GGML_ASSERT(src2->nb[0] == sizeof(float));
-    GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    // for use with the destination state offset between sequences
-    GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -16286,74 +16269,27 @@ static void ggml_compute_forward_ssm_conv_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;
 
-    if (n_kv > 1) {
-        // multiple sequences means it's hard to know when it's the first time a state is read,
-        // so copy them all over to the destination, just to be sure.
-        for (int i3 = 0; i3 < n_kv; ++i3) {
-            float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
-            float * s  = (float *) ((char *)  dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
-            // can't use memcpy because of d_conv vs d_conv - 1
-            for (int i1 = 0; i1 < ir; ++i1) {
-                for (int i0 = 0; i0 < nc - 1; ++i0) {
-                    // copy s0 to last (d_conv - 1) columns of s
-                    s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
-                }
-            }
-        }
-    }
-
-    for (int i2 = 0; i2 < n_t; ++i2) {
-        int32_t * sq = (int32_t *) ((char *) src3->data +  i2*(src3->nb[1])); // {n_kv, n_tokens}
-        float *   x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
-        float *   s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
-        float *  s0; // {d_conv - 1, d_inner, n_kv}
-        float *  x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
-        float *   c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
-        int ne0s0;
-
-        GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
-
-        // avoid needing to copy the state for the first token
-        if (i2 == 0) {
-            s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
-            ne0s0 = src0->ne[0];
-        } else {
-            // the source is the last (d_conv - 1) columns of the destination
-            s0 = s + 1;
-            ne0s0 = nc;
-        }
-
-        // d_inner
-        for (int i1 = 0; i1 < ir; ++i1) {
-            // shift state left
-            for (int i0 = 0; i0 < nc - 1; ++i0) {
-                s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
-            }
-            // insert x on the last column
-            s[(nc - 1) + i1*nc] = x0[i1];
-        }
-
-        // handle copies when there are multiple output states
-        for (int i3 = 1; i3 < n_kv; ++i3) {
-            int32_t seq = sq[i3];
-            if (0 <= seq && seq < n_kv) {
-                float * s1 = s + (seq - sq[0])*nc*nr;
-                memcpy(s1, s, nc*ir*sizeof(float));
-            } else {
-                // stop at negative or too big seq_ids
-                break;
-            }
-        }
-
-        // it seems a little faster when this is separate from the state shift
-        for (int i1 = 0; i1 < ir; ++i1) {
-            // rowwise dot product
-            float sumf = 0.0f;
-            for (int i0 = 0; i0 < nc; ++i0) {
-                int i = i0 + i1*nc;
-                sumf += s[i] * c[i];
-            }
-            x[i1] = sumf;
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            // {d_conv - 1 + n_t, d_inner, n_seqs}
+            // sliding window
+            const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
+            const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
+            float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
+
+            // TODO: transpose the output for smaller strides for big batches?
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // rowwise dot product
+                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
+                float sumf = 0.0f;
+
+                // d_conv
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
+                }
+                x[i1] = sumf;
+            }
         }
     }
 }
@@ -16384,15 +16320,14 @@ static void ggml_compute_forward_ssm_scan_f32(
     const struct ggml_tensor * src3 = dst->src[3]; // A
     const struct ggml_tensor * src4 = dst->src[4]; // B
     const struct ggml_tensor * src5 = dst->src[5]; // C
-    const struct ggml_tensor * src6 = dst->src[6]; // sq
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t nc   = src0->ne[0]; // d_state
-    const int64_t nr   = src0->ne[1]; // d_inner
-    const int64_t n_t  = src1->ne[1]; // number of tokens in the batch
-    const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
+    const int64_t nc  = src0->ne[0]; // d_state
+    const int64_t nr  = src0->ne[1]; // d_inner
+    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
+    const int64_t n_s = src0->ne[2]; // number of sequences in the batch
 
     GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -16401,12 +16336,12 @@ static void ggml_compute_forward_ssm_scan_f32(
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
-    // required for the dot product between s and C, and when copying the states
+    // required for the dot product between s and C
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
     // required for per-sequence offsets for states
     GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[2])
-    GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
+    // required to get correct offset for state destination (i.e. src1->nb[3])
+    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -16416,64 +16351,36 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;
 
-    if (n_kv > 1) {
-        // it's hard to know if the source states have already been copied
-        // when there are multiple, so copy them already.
-        for (int i3 = 0; i3 < n_kv; ++i3) {
-            float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
-            float * s  = (float *) ((char *)  dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
-            memcpy(s, s0, nc*ir*sizeof(float));
-        }
-    }
-
-    for (int i2 = 0; i2 < n_t; ++i2) {
-        int32_t * sq = (int32_t *) ((char *) src6->data +  i2*(src6->nb[1])); // {n_kv, n_tokens}
-        float *   y = (float *) ((char *)  dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
-        float *   s = (float *) ((char *)  dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
-        float *  s0;
-        float *   x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
-        float *  dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
-        float *   A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float *   B = (float *) ((char *) src4->data +  i2*(src4->nb[1])); // {d_state, n_tokens}
-        float *   C = (float *) ((char *) src5->data +  i2*(src5->nb[1])); // {d_state, n_tokens}
-
-        GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
-
-        // avoid needing to copy the state for the first token
-        if (i2 == 0) {
-            s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
-        } else {
-            // otherwise the source is the same as the destination
-            s0 = s;
-        }
-
-        // d_inner
-        for (int i1 = 0; i1 < ir; ++i1) {
-            // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
-            float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-            float x_dt = x[i1] * dt_soft_plus;
-            float sumf = 0.0f;
-            // d_state
-            for (int i0 = 0; i0 < nc; ++i0) {
-                int i = i0 + i1*nc;
-                // state = prev_state * dA + dB * x
-                float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-                // y = rowwise_dotprod(state, C)
-                sumf += state * C[i0];
-                s[i] = state;
-            }
-            y[i1] = sumf;
-        }
-
-        // handle copies when there are multiple output states
-        for (int i3 = 1; i3 < n_kv; ++i3) {
-            int32_t seq = sq[i3];
-            if (0 <= seq && seq < n_kv) {
-                float * s1 = s + (seq - sq[0])*nc*nr;
-                memcpy(s1, s, nc*ir*sizeof(float));
-            } else {
-                // stop at negative or too big seq_ids
-                break;
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                float sumf = 0.0f;
+                // d_state
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    int i = i0 + i1*nc;
+                    // state = prev_state * dA + dB * x
+                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                    // y = rowwise_dotprod(state, C)
+                    sumf += state * C[i0];
+                    s[i] = state;
+                }
+                y[i1] = sumf;
             }
         }
     }
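
For reference, the rewritten ggml_compute_forward_ssm_scan_f32 above applies, independently for every inner channel i and state dimension j of each sequence, the selective-state-update recurrence from the Mamba reference linked in its comments (the softplus is bypassed for dt > 20 purely as a numerical shortcut):

    \Delta_{t,i} = \log\left(1 + e^{dt_{t,i}}\right)
    h_{t,ij} = e^{\Delta_{t,i} A_{ij}} \, h_{t-1,ij} + \Delta_{t,i} \, B_{t,j} \, x_{t,i}
    y_{t,i} = \sum_j C_{t,j} \, h_{t,ij}

In the kernel, dt_soft_plus is \Delta_{t,i}, s0/s hold h_{t-1}/h_t, and y is accumulated as the row-wise dot product of the updated state with C.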