vmobilis committed on
Commit
c9a49f9
·
1 Parent(s): 4c17fa1

ggml : ggml_compute_forward_concat() for arbitrary tensor type (ggml/1118)

Browse files

* ggml_compute_forward_concat() for arbitrary tensor type

* Check that tensors' types match

* ggml-cpu.c: check type of source tensors

* ggml-cpu.c: move tensor type check to ggml_compute_forward_concat()

* ggml.c: check concatenated tensor type

* Remove tensor type check from ggml_compute_forward_concat() in ggml-cpu.c

..., as it was moved to ggml.c.

Files changed (2) hide show
  1. ggml/src/ggml-cpu/ggml-cpu.c +141 -2
  2. ggml/src/ggml.c +1 -0
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -6648,6 +6648,135 @@ static void ggml_compute_forward_repeat_back(
6648
 
6649
  // ggml_compute_forward_concat
6650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6651
  static void ggml_compute_forward_concat_f32(
6652
  const struct ggml_compute_params * params,
6653
  struct ggml_tensor * dst) {
@@ -6655,7 +6784,7 @@ static void ggml_compute_forward_concat_f32(
6655
  const struct ggml_tensor * src0 = dst->src[0];
6656
  const struct ggml_tensor * src1 = dst->src[1];
6657
 
6658
- GGML_ASSERT(src0->nb[0] == sizeof(float));
6659
 
6660
  const int ith = params->ith;
6661
  const int nth = params->nth;
@@ -6698,6 +6827,16 @@ static void ggml_compute_forward_concat(
6698
  const struct ggml_tensor * src0 = dst->src[0];
6699
 
6700
  switch (src0->type) {
 
 
 
 
 
 
 
 
 
 
6701
  case GGML_TYPE_F32:
6702
  case GGML_TYPE_I32:
6703
  {
@@ -6705,7 +6844,7 @@ static void ggml_compute_forward_concat(
6705
  } break;
6706
  default:
6707
  {
6708
- GGML_ABORT("fatal error");
6709
  }
6710
  }
6711
  }
 
6648
 
6649
  // ggml_compute_forward_concat
6650
 
6651
+ static void ggml_compute_forward_concat_any(
6652
+ const struct ggml_compute_params * params,
6653
+ struct ggml_tensor * dst) {
6654
+
6655
+ const struct ggml_tensor * src0 = dst->src[0];
6656
+ const struct ggml_tensor * src1 = dst->src[1];
6657
+
6658
+ const size_t len = ggml_type_size(src0->type);
6659
+
6660
+ const int ith = params->ith;
6661
+ const int nth = params->nth;
6662
+
6663
+ GGML_TENSOR_BINARY_OP_LOCALS
6664
+
6665
+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
6666
+
6667
+ GGML_ASSERT(dim >= 0 && dim < 4);
6668
+
6669
+ int64_t o[4] = {0, 0, 0, 0};
6670
+ o[dim] = src0->ne[dim];
6671
+
6672
+ const char * x;
6673
+
6674
+ // TODO: smarter multi-theading
6675
+ for (int i3 = 0; i3 < ne3; i3++) {
6676
+ for (int i2 = ith; i2 < ne2; i2 += nth) {
6677
+ for (int i1 = 0; i1 < ne1; i1++) {
6678
+ for (int i0 = 0; i0 < ne0; i0++) {
6679
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
6680
+ x = (const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03;
6681
+ } else {
6682
+ x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
6683
+ }
6684
+
6685
+ char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
6686
+
6687
+ memcpy(y, x, len);
6688
+ }
6689
+ }
6690
+ }
6691
+ }
6692
+ }
6693
+
6694
+ static void ggml_compute_forward_concat_i8(
6695
+ const struct ggml_compute_params * params,
6696
+ struct ggml_tensor * dst) {
6697
+
6698
+ const struct ggml_tensor * src0 = dst->src[0];
6699
+ const struct ggml_tensor * src1 = dst->src[1];
6700
+
6701
+ GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
6702
+
6703
+ const int ith = params->ith;
6704
+ const int nth = params->nth;
6705
+
6706
+ GGML_TENSOR_BINARY_OP_LOCALS
6707
+
6708
+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
6709
+
6710
+ GGML_ASSERT(dim >= 0 && dim < 4);
6711
+
6712
+ int64_t o[4] = {0, 0, 0, 0};
6713
+ o[dim] = src0->ne[dim];
6714
+
6715
+ const int8_t * x;
6716
+
6717
+ // TODO: smarter multi-theading
6718
+ for (int i3 = 0; i3 < ne3; i3++) {
6719
+ for (int i2 = ith; i2 < ne2; i2 += nth) {
6720
+ for (int i1 = 0; i1 < ne1; i1++) {
6721
+ for (int i0 = 0; i0 < ne0; i0++) {
6722
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
6723
+ x = (const int8_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
6724
+ } else {
6725
+ x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
6726
+ }
6727
+
6728
+ int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
6729
+
6730
+ *y = *x;
6731
+ }
6732
+ }
6733
+ }
6734
+ }
6735
+ }
6736
+
6737
+ static void ggml_compute_forward_concat_f16(
6738
+ const struct ggml_compute_params * params,
6739
+ struct ggml_tensor * dst) {
6740
+
6741
+ const struct ggml_tensor * src0 = dst->src[0];
6742
+ const struct ggml_tensor * src1 = dst->src[1];
6743
+
6744
+ GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
6745
+
6746
+ const int ith = params->ith;
6747
+ const int nth = params->nth;
6748
+
6749
+ GGML_TENSOR_BINARY_OP_LOCALS
6750
+
6751
+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
6752
+
6753
+ GGML_ASSERT(dim >= 0 && dim < 4);
6754
+
6755
+ int64_t o[4] = {0, 0, 0, 0};
6756
+ o[dim] = src0->ne[dim];
6757
+
6758
+ const ggml_fp16_t * x;
6759
+
6760
+ // TODO: smarter multi-theading
6761
+ for (int i3 = 0; i3 < ne3; i3++) {
6762
+ for (int i2 = ith; i2 < ne2; i2 += nth) {
6763
+ for (int i1 = 0; i1 < ne1; i1++) {
6764
+ for (int i0 = 0; i0 < ne0; i0++) {
6765
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
6766
+ x = (const ggml_fp16_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
6767
+ } else {
6768
+ x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
6769
+ }
6770
+
6771
+ ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
6772
+
6773
+ *y = *x;
6774
+ }
6775
+ }
6776
+ }
6777
+ }
6778
+ }
6779
+
6780
  static void ggml_compute_forward_concat_f32(
6781
  const struct ggml_compute_params * params,
6782
  struct ggml_tensor * dst) {
 
6784
  const struct ggml_tensor * src0 = dst->src[0];
6785
  const struct ggml_tensor * src1 = dst->src[1];
6786
 
6787
+ GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float));
6788
 
6789
  const int ith = params->ith;
6790
  const int nth = params->nth;
 
6827
  const struct ggml_tensor * src0 = dst->src[0];
6828
 
6829
  switch (src0->type) {
6830
+ case GGML_TYPE_F16:
6831
+ case GGML_TYPE_BF16:
6832
+ case GGML_TYPE_I16:
6833
+ {
6834
+ ggml_compute_forward_concat_f16(params, dst);
6835
+ } break;
6836
+ case GGML_TYPE_I8:
6837
+ {
6838
+ ggml_compute_forward_concat_i8(params, dst);
6839
+ } break;
6840
  case GGML_TYPE_F32:
6841
  case GGML_TYPE_I32:
6842
  {
 
6844
  } break;
6845
  default:
6846
  {
6847
+ ggml_compute_forward_concat_any(params, dst);
6848
  }
6849
  }
6850
  }
ggml/src/ggml.c CHANGED
@@ -2332,6 +2332,7 @@ struct ggml_tensor * ggml_concat(
2332
  struct ggml_tensor * b,
2333
  int dim) {
2334
  GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
 
2335
 
2336
  int64_t ne[GGML_MAX_DIMS];
2337
  for (int d = 0; d < GGML_MAX_DIMS; ++d) {
 
2332
  struct ggml_tensor * b,
2333
  int dim) {
2334
  GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
2335
+ GGML_ASSERT(a->type == b->type);
2336
 
2337
  int64_t ne[GGML_MAX_DIMS];
2338
  for (int d = 0; d < GGML_MAX_DIMS; ++d) {