Chenguang Li committed
Commit 6aecea5 · Parent: 8a74c6b

CANN: Support more ops (llama/12841)


* [CANN] Support ops LOG && MEAN && PAD_REFLECT_1D

* [CANN] Support COUNT_EQUAL && STEP && SGN

* [CANN] Code style adjustment

* [CANN] Code style adjustment

---------

Signed-off-by: noemotiovon <[email protected]>

ggml/src/ggml-cann/acl_tensor.cpp CHANGED
@@ -41,6 +41,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
             return ACL_INT4;
         case GGML_TYPE_Q8_0:
             return ACL_INT8;
+        case GGML_TYPE_I64:
+            return ACL_INT64;
         default:
             return ACL_DT_UNDEFINED;
     }
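A note on why this mapping is needed: GGML_OP_COUNT_EQUAL (added below) produces a GGML_TYPE_I64 result in ggml, so building its destination aclTensor requires I64 to map to a defined ACL type. A minimal illustrative check, assuming the declarations from acl_tensor.h (this snippet is not part of the commit):

    #include "acl_tensor.h"

    // Before this commit GGML_TYPE_I64 fell through to ACL_DT_UNDEFINED,
    // so ops with an I64 destination could not create their aclTensor.
    aclDataType dt = ggml_cann_type_mapping(GGML_TYPE_I64); // == ACL_INT64 after this commit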
ggml/src/ggml-cann/aclnn_ops.cpp CHANGED
 
@@ -59,6 +59,11 @@
 #include <aclnnop/aclnn_div.h>
 #include <aclnnop/aclnn_convolution.h>
 #include <aclnnop/aclnn_elu.h>
+#include <aclnnop/aclnn_log.h>
+#include <aclnnop/aclnn_mean.h>
+#include <aclnnop/aclnn_reflection_pad1d.h>
+#include <aclnnop/aclnn_eq_tensor.h>
+#include <aclnnop/aclnn_gt_scalar.h>
 #include <float.h>
 
 #include <cmath>
 
@@ -2598,6 +2603,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3);
 
     GGML_CANN_CALL_ACLNN_OP(ArgMax, acl_src, 3, false, acl_dst);
+
     ACL_CHECK(aclDestroyTensor(acl_src));
     ACL_CHECK(aclDestroyTensor(acl_dst));
 }
 
@@ -2629,6 +2635,9 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
 
     ACL_CHECK(aclDestroyTensor(acl_weight));
     ACL_CHECK(aclDestroyTensor(acl_dst));
+    ACL_CHECK(aclDestroyIntArray(stride));
+    ACL_CHECK(aclDestroyIntArray(padding));
+    ACL_CHECK(aclDestroyIntArray(dilation));
 }
 
 void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst){
 
@@ -2646,4 +2655,79 @@ void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst){
 
     ACL_CHECK(aclDestroyTensor(acl_input));
     ACL_CHECK(aclDestroyTensor(acl_dst));
+    ACL_CHECK(aclDestroyScalar(alpha));
+}
+
+void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+    ggml_tensor * src0 = dst->src[0];
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src0);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    // ACL dim 3 corresponds to ggml dim 0, i.e. the row dimension
+    int64_t reduceDimValue[] = {3};
+    aclIntArray* reduceDim = aclCreateIntArray(reduceDimValue, 1);
+    bool keepDim = true;
+
+    GGML_CANN_CALL_ACLNN_OP(Mean, acl_src, reduceDim, keepDim, ACL_FLOAT, acl_dst);
+
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ACL_CHECK(aclDestroyIntArray(reduceDim));
+}
+
+void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+    ggml_tensor * src0 = dst->src[0];
+    int32_t *opts = (int32_t *) dst->op_params;
+    int64_t paddingsArray[2] = {opts[0], opts[1]};
+    aclIntArray* paddings = aclCreateIntArray(paddingsArray, 2);
+
+    for (int64_t i = 0; i < src0->ne[3]; i++) {
+        // step through the batch dimension by its byte stride (nb[3]),
+        // not its element count (ne[3])
+        aclTensor* acl_src = ggml_cann_create_tensor(
+            (char*)src0->data + i * src0->nb[3],
+            ggml_cann_type_mapping(src0->type), ggml_element_size(src0),
+            src0->ne, src0->nb, 3);
+
+        aclTensor* acl_dst = ggml_cann_create_tensor(
+            (char*)dst->data + i * dst->nb[3],
+            ggml_cann_type_mapping(dst->type), ggml_element_size(dst),
+            dst->ne, dst->nb, 3);
+
+        GGML_CANN_CALL_ACLNN_OP(ReflectionPad1d, acl_src, paddings, acl_dst);
+
+        ACL_CHECK(aclDestroyTensor(acl_src));
+        ACL_CHECK(aclDestroyTensor(acl_dst));
+    }
+    ACL_CHECK(aclDestroyIntArray(paddings));
+}
+
+void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+    ggml_tensor * src0 = dst->src[0];
+    ggml_tensor * src1 = dst->src[1];
+
+    aclTensor* acl_self = ggml_cann_create_tensor(src0);
+    aclTensor* acl_other = ggml_cann_create_tensor(src1);
+
+    // turn src0 into a 0/1 equality mask in place, then sum it into dst
+    GGML_CANN_CALL_ACLNN_OP(InplaceEqTensor, acl_self, acl_other);
+
+    ggml_cann_sum(ctx, dst);
+
+    ACL_CHECK(aclDestroyTensor(acl_self));
+    ACL_CHECK(aclDestroyTensor(acl_other));
+}
+
+void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+    ggml_tensor * src0 = dst->src[0];
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src0);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    // step(x) = 1 if x > 0 else 0, expressed as a greater-than-scalar compare
+    float alphaValue = 0.0f;
+    aclScalar* alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
+
+    GGML_CANN_CALL_ACLNN_OP(GtScalar, acl_src, alpha, acl_dst);
+
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ACL_CHECK(aclDestroyScalar(alpha));
+}
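Two of the new kernels are compositions rather than single ACL calls: COUNT_EQUAL rewrites src0 in place into a 0/1 equality mask (InplaceEqTensor) and then reuses the existing ggml_cann_sum to reduce it, and STEP is a compare-against-zero (GtScalar). A plain C++ sketch of the intended semantics, for reference only (no ACL calls):

    #include <cstdint>
    #include <vector>

    // count_equal = sum(eq(a, b)): count positions where a and b match.
    static int64_t count_equal_ref(const std::vector<float>& a,
                                   const std::vector<float>& b) {
        int64_t n = 0;
        for (size_t i = 0; i < a.size(); i++) n += (a[i] == b[i]);
        return n; // count_equal_ref({1,2,3}, {1,0,3}) == 2
    }

    // step(x) = 1 if x > 0, else 0 -- exactly the GtScalar(x, 0) compare.
    static float step_ref(float x) { return x > 0.0f ? 1.0f : 0.0f; }

Because InplaceEqTensor overwrites src0, the comparison mask needs no separate allocation; the trade-off is that src0's contents are clobbered during the forward pass.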
ggml/src/ggml-cann/aclnn_ops.h CHANGED
 
@@ -42,6 +42,8 @@
 #include <aclnnop/aclnn_sqrt.h>
 #include <aclnnop/aclnn_sin.h>
 #include <aclnnop/aclnn_cos.h>
+#include <aclnnop/aclnn_log.h>
+#include <aclnnop/aclnn_sign.h>
 #include "acl_tensor.h"
 #include "common.h"
 
 
@@ -650,6 +652,67 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  */
 void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+/**
+ * @brief Computes the mean of a ggml tensor using the CANN backend.
+ *
+ * @details This function reduces the input tensor by averaging its values
+ * along the row dimension, keeping the reduced dimension with size 1, and
+ * writes the result to the destination tensor `dst`.
+ *
+ * This operation is accelerated by the CANN backend.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the mean result will be stored.
+ *            dst->op is expected to be `GGML_OP_MEAN`.
+ */
+void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Applies 1D reflect padding to a ggml tensor using the CANN backend.
+ *
+ * @details This function performs 1D reflect padding on the input tensor.
+ * The amount of padding on each side is specified in `dst->op_params`.
+ * The operation mirrors the values at the borders of the tensor to generate
+ * the padded output.
+ *
+ * This operation is accelerated by the CANN backend.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the padded result will be stored.
+ *            dst->op is expected to be `GGML_OP_PAD_REFLECT_1D`.
+ */
+void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Counts the number of equal elements in two ggml tensors using the
+ * CANN backend.
+ *
+ * @details This function compares two input tensors element-wise and counts
+ * the positions where the elements are equal. The count is stored in the
+ * destination tensor `dst` as a scalar.
+ *
+ * This operation is accelerated by the CANN backend.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the result will be stored.
+ *            dst->op is expected to be `GGML_OP_COUNT_EQUAL`.
+ */
+void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Applies the Step activation function to a ggml tensor using the
+ * CANN backend.
+ *
+ * @details This function applies the step function element-wise to the input
+ * tensor: each element maps to 1.0 if it is greater than 0, and to 0.0
+ * otherwise. The result is stored in the destination tensor `dst`.
+ *
+ * This operation is accelerated by the CANN backend.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the result will be stored.
+ *            dst->op is expected to be `GGML_OP_STEP`.
+ */
+void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 /**
  * @brief Applies an element-wise operation to two input tensors using the
  * CANN backend.
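For reference, reflect padding mirrors interior elements around the borders without duplicating the edge element. A minimal reference implementation, illustrative only and assuming pad sizes smaller than the input length:

    #include <vector>

    // Mirrors indices around the borders; the edge element is not repeated.
    static std::vector<float> pad_reflect_1d_ref(const std::vector<float>& x,
                                                 int left, int right) {
        const int n = (int) x.size();
        std::vector<float> out;
        for (int i = -left; i < n + right; i++) {
            int j = i < 0 ? -i : (i >= n ? 2 * n - 2 - i : i);
            out.push_back(x[j]);
        }
        return out;
    }
    // pad_reflect_1d_ref({1,2,3,4}, 1, 2) == {2,1,2,3,4,3,2}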
ggml/src/ggml-cann/ggml-cann.cpp CHANGED
 
@@ -1358,6 +1358,12 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
                 case GGML_UNARY_OP_ELU:
                     ggml_cann_elu(ctx, dst);
                     break;
+                case GGML_UNARY_OP_SGN:
+                    GGML_CANN_CALL_UNARY_OP(Sign);
+                    break;
+                case GGML_UNARY_OP_STEP:
+                    ggml_cann_step(ctx, dst);
+                    break;
                 default:
                     return false;
             }
 
@@ -1456,6 +1462,18 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_CONV_TRANSPOSE_1D:
             ggml_cann_conv_transpose_1d(ctx, dst);
             break;
+        case GGML_OP_LOG:
+            GGML_CANN_CALL_UNARY_OP(Log);
+            break;
+        case GGML_OP_MEAN:
+            ggml_cann_mean(ctx, dst);
+            break;
+        case GGML_OP_PAD_REFLECT_1D:
+            ggml_cann_pad_reflect_1d(ctx, dst);
+            break;
+        case GGML_OP_COUNT_EQUAL:
+            ggml_cann_count_equal(ctx, dst);
+            break;
         default:
             return false;
     }
 
@@ -1718,6 +1736,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_EXP:
                 case GGML_UNARY_OP_ELU:
+                case GGML_UNARY_OP_SGN:
+                case GGML_UNARY_OP_STEP:
                     return true;
                 default:
                     return false;
 
@@ -1854,6 +1874,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_COS:
         case GGML_OP_SIN:
         case GGML_OP_CONV_TRANSPOSE_1D:
+        case GGML_OP_LOG:
+        case GGML_OP_MEAN:
+        case GGML_OP_PAD_REFLECT_1D:
+        case GGML_OP_COUNT_EQUAL:
             return true;
         default:
             return false;
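The support comes in two halves: ggml_backend_cann_supports_op must admit an op so the scheduler assigns it to the CANN device, and ggml_cann_compute_forward then dispatches it to the kernel. A rough host-side sketch of how one of the new ops (GGML_OP_MEAN) travels through these two gates, following the usual ggml backend pattern; the init and allocation boilerplate varies by ggml revision, so treat this as illustrative rather than canonical:

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include "ggml-cann.h"

    int main() {
        ggml_backend_t backend = ggml_backend_cann_init(0); // NPU device 0

        struct ggml_init_params params = {
            /*mem_size  =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
            /*mem_buffer=*/ NULL,
            /*no_alloc  =*/ true, // tensors live in the backend buffer
        };
        struct ggml_context * ctx = ggml_init(params);

        // one row of 4 floats; GGML_OP_MEAN averages along the row
        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        struct ggml_tensor * m = ggml_mean(ctx, x);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, m);

        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
        const float xv[4] = {1, 2, 3, 4};
        ggml_backend_tensor_set(x, xv, 0, sizeof(xv));

        // supports_op admits GGML_OP_MEAN; compute_forward calls ggml_cann_mean
        ggml_backend_graph_compute(backend, gf);

        float mv = 0;
        ggml_backend_tensor_get(m, &mv, 0, sizeof(mv)); // mv == 2.5f

        ggml_backend_buffer_free(buf);
        ggml_free(ctx);
        ggml_backend_free(backend);
        return 0;
    }

In llama.cpp these paths are normally exercised with the test-backend-ops tool, which compares each backend op against the CPU reference implementation.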