amritahs-ibm committed
Commit 6f18eed · Parent: 623b74d

llamafile : ppc64le MMA INT8 implementation (llama/10912)


This change upstreams llamafile's CPU matrix
multiplication kernels for ppc64le, using MMA
builtins for the quantised int8 data type.

This change results in a 10% - 70% improvement
in total speed (i.e. all tokens / total time)
across various batch sizes.

The patch was tested with the Meta-Llama-3-8B,
Mistral-7B, and Llama-2-7B-chat-hf models on an
IBM POWER10 machine.

Signed-off-by: Amrita H S <[email protected]>
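
Note (not part of the patch, added for context): in the new tinyBLAS_Q0_PPC kernels, the A operand is packed as signed bytes while the B operand is flipped to unsigned by XORing each byte with 0x80 (packNormal with flip=true); after the __builtin_mma_xvi8ger4pp accumulation, the bias is undone by adding -128 times the per-row sum of A's quants (comparray) before the two Q8_0 scales are applied. The scalar sketch below shows that flip-and-compensate math under simplified assumptions (a plain float scale instead of fp16; names such as block_q8_0_ref and dot_flip_compensate are illustrative, not from this commit).

#include <cstdint>
#include <cstdio>

constexpr int QK8_0 = 32;              // quants per Q8_0 block

struct block_q8_0_ref {                // simplified stand-in for block_q8_0
    float  d;                          // block scale (fp16 in ggml)
    int8_t qs[QK8_0];                  // quantized values
};

// Reference result: d_a * d_b * sum(a_i * b_i)
static float dot_reference(const block_q8_0_ref &a, const block_q8_0_ref &b) {
    int sum = 0;
    for (int i = 0; i < QK8_0; i++) sum += (int)a.qs[i] * (int)b.qs[i];
    return a.d * b.d * (float)sum;
}

// Same result via the flip-and-compensate scheme used by the MMA kernels:
// accumulate signed(a) * (b ^ 0x80), then add -128 * sum(a) before scaling.
static float dot_flip_compensate(const block_q8_0_ref &a, const block_q8_0_ref &b) {
    int acc  = 0;                      // plays the role of the MMA accumulator
    int comp = 0;                      // plays the role of one comparray entry
    for (int i = 0; i < QK8_0; i++) {
        uint8_t bu = (uint8_t)((uint8_t)b.qs[i] ^ 0x80);  // b_i + 128, unsigned
        acc  += (int)a.qs[i] * (int)bu;
        comp += (int)a.qs[i];
    }
    float res = (float)acc + (float)comp * -128.0f;       // undo the +128 bias
    return res * (a.d * b.d);                             // apply both scales
}

int main() {
    block_q8_0_ref a{0.5f, {}}, b{0.25f, {}};
    for (int i = 0; i < QK8_0; i++) {
        a.qs[i] = (int8_t)(i - 16);
        b.qs[i] = (int8_t)(3 * i - 40);
    }
    printf("reference=%f flip+compensate=%f\n",
           (double)dot_reference(a, b), (double)dot_flip_compensate(a, b));
    return 0;
}

Both functions return the same value; the flipped form matches what the accumulator holds in KERNEL_8x8 and the other kernels before compute() adds the -128 * comparray correction and multiplies by the per-block scales.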

Files changed (1)
  1. ggml/src/ggml-cpu/llamafile/sgemm.cpp +773 -69
ggml/src/ggml-cpu/llamafile/sgemm.cpp CHANGED
@@ -54,6 +54,7 @@
54
  #include "ggml-quants.h"
55
 
56
  #include <atomic>
 
57
 
58
  #ifdef _MSC_VER
59
  #define NOINLINE __declspec(noinline)
@@ -1051,6 +1052,704 @@ class tinyBLAS_Q0_AVX {
1051
  } \
1052
  } \
1053
 
1054
  template <typename TA, typename TB, typename TC>
1055
  class tinyBLAS_PPC {
1056
  public:
@@ -1070,13 +1769,17 @@ class tinyBLAS_PPC {
1070
 
1071
  void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
1072
 
1073
- void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
 
1074
  int64_t i, j;
1075
- float *aoffset = NULL, *boffset = NULL;
1076
- float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1077
- float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1078
-
1079
- aoffset = const_cast<float*>(a);
 
 
 
1080
  boffset = vec;
1081
  j = (rows >> 3);
1082
  if (j > 0) {
@@ -1092,9 +1795,6 @@ class tinyBLAS_PPC {
1092
  aoffset += 8 * lda;
1093
  i = (cols >> 3);
1094
  if (i > 0) {
1095
- __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
1096
- vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
1097
- vector float t1, t2, t3, t4, t5, t6, t7, t8;
1098
  do {
1099
  C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
1100
  C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
@@ -1174,21 +1874,19 @@ class tinyBLAS_PPC {
1174
  } while(i > 0);
1175
  }
1176
  if (cols & 4) {
1177
- vector float c1, c2, c3, c4, c5, c6, c7, c8;
1178
- vector float t1, t2, t3, t4, t5, t6, t7, t8;
1179
- c1 = vec_xl(0, aoffset1);
1180
- c2 = vec_xl(0, aoffset2);
1181
- c3 = vec_xl(0, aoffset3);
1182
- c4 = vec_xl(0, aoffset4);
1183
- c5 = vec_xl(0, aoffset5);
1184
- c6 = vec_xl(0, aoffset6);
1185
- c7 = vec_xl(0, aoffset7);
1186
- c8 = vec_xl(0, aoffset8);
1187
-
1188
- t1 = vec_mergeh(c1, c2);
1189
- t2 = vec_mergeh(c3, c4);
1190
- t3 = vec_mergeh(c5, c6);
1191
- t4 = vec_mergeh(c7, c8);
1192
  t5 = vec_xxpermdi(t1, t2, 0);
1193
  t6 = vec_xxpermdi(t3, t4, 0);
1194
  t7 = vec_xxpermdi(t1, t2, 3);
@@ -1198,10 +1896,10 @@ class tinyBLAS_PPC {
1198
  vec_xst(t7, 0, boffset+8);
1199
  vec_xst(t8, 0, boffset+12);
1200
 
1201
- t1 = vec_mergel(c1, c2);
1202
- t2 = vec_mergel(c3, c4);
1203
- t3 = vec_mergel(c5, c6);
1204
- t4 = vec_mergel(c7, c8);
1205
  t5 = vec_xxpermdi(t1, t2, 0);
1206
  t6 = vec_xxpermdi(t3, t4, 0);
1207
  t7 = vec_xxpermdi(t1, t2, 3);
@@ -1223,9 +1921,6 @@ class tinyBLAS_PPC {
1223
  aoffset += 4 * lda;
1224
  i = (cols >> 3);
1225
  if (i > 0) {
1226
- __vector_pair C1, C2, C3, C4;
1227
- vector float c1[2], c2[2], c3[2], c4[2];
1228
- vector float t1, t2, t3, t4, t5, t6, t7, t8;
1229
  do {
1230
  C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
1231
  C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
@@ -1272,22 +1967,20 @@ class tinyBLAS_PPC {
1272
  }
1273
 
1274
  if (cols & 4) {
1275
- vector float c1, c2, c3, c4;
1276
- vector float t1, t2, t3, t4;
1277
- c1 = vec_xl(0, aoffset1);
1278
- c2 = vec_xl(0, aoffset2);
1279
- c3 = vec_xl(0, aoffset3);
1280
- c4 = vec_xl(0, aoffset4);
1281
-
1282
- t1 = vec_mergeh(c1, c2);
1283
- t2 = vec_mergeh(c3, c4);
1284
  t3 = vec_xxpermdi(t1, t2, 0);
1285
  t4 = vec_xxpermdi(t1, t2, 3);
1286
  vec_xst(t3, 0, boffset);
1287
  vec_xst(t4, 0, boffset+4);
1288
 
1289
- t1 = vec_mergel(c1, c2);
1290
- t2 = vec_mergel(c3, c4);
1291
  t3 = vec_xxpermdi(t1, t2, 0);
1292
  t4 = vec_xxpermdi(t1, t2, 3);
1293
  vec_xst(t3, 0, boffset+8);
@@ -1299,21 +1992,19 @@ class tinyBLAS_PPC {
1299
  aoffset2 = aoffset1 + lda;
1300
  aoffset3 = aoffset2 + lda;
1301
  if (cols & 4) {
1302
- vector float c1, c2, c3, c4 = {0};
1303
- vector float t1, t2, t3, t4;
1304
- c1 = vec_xl(0, aoffset1);
1305
- c2 = vec_xl(0, aoffset2);
1306
- c3 = vec_xl(0, aoffset3);
1307
-
1308
- t1 = vec_mergeh(c1, c2);
1309
- t2 = vec_mergeh(c3, c4);
1310
  t3 = vec_xxpermdi(t1, t2, 0);
1311
  t4 = vec_xxpermdi(t1, t2, 3);
1312
  vec_xst(t3, 0, boffset);
1313
  vec_xst(t4, 0, boffset+4);
1314
 
1315
- t1 = vec_mergel(c1, c2);
1316
- t2 = vec_mergel(c3, c4);
1317
  t3 = vec_xxpermdi(t1, t2, 0);
1318
  t4 = vec_xxpermdi(t1, t2, 3);
1319
  vec_xst(t3, 0, boffset+8);
@@ -1321,14 +2012,13 @@ class tinyBLAS_PPC {
1321
  }
1322
  }
1323
  }
1324
-
1325
  void KERNEL_4x4(int64_t ii, int64_t jj) {
1326
  vec_t vec_A[4], vec_B[4], vec_C[4];
1327
  acc_t acc_0;
1328
  __builtin_mma_xxsetaccz(&acc_0);
1329
  for (int l = 0; l < k; l+=4) {
1330
- READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
1331
- READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
1332
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
1333
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
1334
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
@@ -1343,8 +2033,8 @@ class tinyBLAS_PPC {
1343
  __builtin_mma_xxsetaccz(&acc_0);
1344
  __builtin_mma_xxsetaccz(&acc_1);
1345
  for (int64_t l = 0; l < k; l+=4) {
1346
- READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
1347
- READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
1348
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
1349
  __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
1350
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -1364,8 +2054,8 @@ class tinyBLAS_PPC {
1364
  __builtin_mma_xxsetaccz(&acc_0);
1365
  __builtin_mma_xxsetaccz(&acc_1);
1366
  for (int64_t l = 0; l < k; l+=4) {
1367
- READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
1368
- READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
1369
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
1370
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
1371
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -1387,8 +2077,8 @@ class tinyBLAS_PPC {
1387
  __builtin_mma_xxsetaccz(&acc_2);
1388
  __builtin_mma_xxsetaccz(&acc_3);
1389
  for (int l = 0; l < k; l+=8) {
1390
- READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
1391
- READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
1392
  for(int x = 0; x < 16; x+=2) {
1393
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
1394
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
@@ -1571,15 +2261,15 @@ class tinyBLAS_PPC {
1571
  vec_t vec_A[4], vec_B[4];
1572
  for (int l=0; l<k; l+=4) {
1573
  if (RN >= 4 && RM == 1) {
1574
- float* a = const_cast<float*>(A+(ii)*lda+l);
1575
- READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
1576
  vec_A[0] = (vec_t)vec_xl(0,a);
1577
- vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
1578
- vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
1579
- vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
1580
  } else {
1581
- READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
1582
- READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
1583
  }
1584
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
1585
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -1589,7 +2279,7 @@ class tinyBLAS_PPC {
1589
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
1590
  for (int I = 0; I < RM; I++) {
1591
  for (int J = 0; J < RN; J++) {
1592
- *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
1593
  }
1594
  }
1595
  }
@@ -1812,6 +2502,20 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
1812
  params->ith, params->nth};
1813
  tb.matmul(m, n);
1814
  return true;
1815
  #else
1816
  return false;
1817
  #endif
 
54
  #include "ggml-quants.h"
55
 
56
  #include <atomic>
57
+ #include <array>
58
 
59
  #ifdef _MSC_VER
60
  #define NOINLINE __declspec(noinline)
 
1052
  } \
1053
  } \
1054
 
1055
+ template <typename TA, typename TB, typename TC>
1056
+ class tinyBLAS_Q0_PPC {
1057
+ public:
1058
+ tinyBLAS_Q0_PPC(int64_t k,
1059
+ const TA *A, int64_t lda,
1060
+ const TB *B, int64_t ldb,
1061
+ TC *C, int64_t ldc,
1062
+ int ith, int nth)
1063
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
1064
+ }
1065
+
1066
+ void matmul(int64_t m, int64_t n) {
1067
+ mnpack(0, m, 0, n);
1068
+ }
1069
+
1070
+ private:
1071
+
1072
+ template<int RM, int RN>
1073
+ inline void save_res(int ii, int jj, int idx, vector float* fin_res) {
1074
+ for (int I = 0; I < RM; I++) {
1075
+ for (int J = 0; J < RN; J++) {
1076
+ *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
1077
+ }
1078
+ }
1079
+ }
1080
+
1081
+ template<int size>
1082
+ inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array<int, size>& comparray, vector float* vs, vector float* fin_res) {
1083
+ vector signed int vec_C[4];
1084
+ vector float CA[4] = {0};
1085
+ vector float res[4] = {0};
1086
+ __builtin_mma_disassemble_acc(vec_C, ACC);
1087
+ for (int i = 0; i < 4; i++) {
1088
+ CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
1089
+ res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
1090
+ fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
1091
+ }
1092
+ }
1093
+
1094
+ template<typename VA, typename VB>
1095
+ void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
1096
+ int64_t i, j;
1097
+ TA *aoffset = NULL;
1098
+ VA *vecOffset = NULL;
1099
+ TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1100
+ TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1101
+ __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
1102
+ VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
1103
+ VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
1104
+ VB t1, t2, t3, t4, t5, t6, t7, t8;
1105
+ vector unsigned char xor_vector;
1106
+ uint8_t flip_vec = 0x80;
1107
+ xor_vector = vec_splats(flip_vec);
1108
+ vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
1109
+ vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
1110
+ vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
1111
+ vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
1112
+
1113
+ aoffset = const_cast<TA*>(a);
1114
+ vecOffset = vec;
1115
+ j = (rows >> 3);
1116
+ if (j > 0) {
1117
+ do {
1118
+ aoffset1 = aoffset;
1119
+ aoffset2 = aoffset1 + lda;
1120
+ aoffset3 = aoffset2 + lda;
1121
+ aoffset4 = aoffset3 + lda;
1122
+ aoffset5 = aoffset4 + lda;
1123
+ aoffset6 = aoffset5 + lda;
1124
+ aoffset7 = aoffset6 + lda;
1125
+ aoffset8 = aoffset7 + lda;
1126
+ aoffset += 8 * lda;
1127
+
1128
+ i = (cols >> 3);
1129
+ if (i > 0) {
1130
+ do {
1131
+ C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
1132
+ C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
1133
+ C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
1134
+ C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
1135
+ C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs);
1136
+ C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs);
1137
+ C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs);
1138
+ C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs);
1139
+
1140
+ __builtin_vsx_disassemble_pair(c1, &C1);
1141
+ __builtin_vsx_disassemble_pair(c2, &C2);
1142
+ __builtin_vsx_disassemble_pair(c3, &C3);
1143
+ __builtin_vsx_disassemble_pair(c4, &C4);
1144
+ __builtin_vsx_disassemble_pair(c5, &C5);
1145
+ __builtin_vsx_disassemble_pair(c6, &C6);
1146
+ __builtin_vsx_disassemble_pair(c7, &C7);
1147
+ __builtin_vsx_disassemble_pair(c8, &C8);
1148
+
1149
+ t1 = vec_perm(c1[0], c2[0], swiz1);
1150
+ t2 = vec_perm(c1[0], c2[0], swiz2);
1151
+ t3 = vec_perm(c3[0], c4[0], swiz1);
1152
+ t4 = vec_perm(c3[0], c4[0], swiz2);
1153
+ t5 = vec_perm(t1, t3, swiz3);
1154
+ t6 = vec_perm(t1, t3, swiz4);
1155
+ t7 = vec_perm(t2, t4, swiz3);
1156
+ t8 = vec_perm(t2, t4, swiz4);
1157
+ if (flip == true) {
1158
+ t5 = vec_xor(t5, xor_vector);
1159
+ t6 = vec_xor(t6, xor_vector);
1160
+ t7 = vec_xor(t7, xor_vector);
1161
+ t8 = vec_xor(t8, xor_vector);
1162
+ }
1163
+ vec_xst(t5, 0, vecOffset);
1164
+ vec_xst(t6, 0, vecOffset+16);
1165
+ vec_xst(t7, 0, vecOffset+32);
1166
+ vec_xst(t8, 0, vecOffset+48);
1167
+
1168
+ t1 = vec_perm(c1[1], c2[1], swiz1);
1169
+ t2 = vec_perm(c1[1], c2[1], swiz2);
1170
+ t3 = vec_perm(c3[1], c4[1], swiz1);
1171
+ t4 = vec_perm(c3[1], c4[1], swiz2);
1172
+ t5 = vec_perm(t1, t3, swiz3);
1173
+ t6 = vec_perm(t1, t3, swiz4);
1174
+ t7 = vec_perm(t2, t4, swiz3);
1175
+ t8 = vec_perm(t2, t4, swiz4);
1176
+ if (flip == true) {
1177
+ t5 = vec_xor(t5, xor_vector);
1178
+ t6 = vec_xor(t6, xor_vector);
1179
+ t7 = vec_xor(t7, xor_vector);
1180
+ t8 = vec_xor(t8, xor_vector);
1181
+ }
1182
+ vec_xst(t5, 0, vecOffset+64);
1183
+ vec_xst(t6, 0, vecOffset+80);
1184
+ vec_xst(t7, 0, vecOffset+96);
1185
+ vec_xst(t8, 0, vecOffset+112);
1186
+
1187
+ t1 = vec_perm(c5[0], c6[0], swiz1);
1188
+ t2 = vec_perm(c5[0], c6[0], swiz2);
1189
+ t3 = vec_perm(c7[0], c8[0], swiz1);
1190
+ t4 = vec_perm(c7[0], c8[0], swiz2);
1191
+ t5 = vec_perm(t1, t3, swiz3);
1192
+ t6 = vec_perm(t1, t3, swiz4);
1193
+ t7 = vec_perm(t2, t4, swiz3);
1194
+ t8 = vec_perm(t2, t4, swiz4);
1195
+ if (flip == true) {
1196
+ t5 = vec_xor(t5, xor_vector);
1197
+ t6 = vec_xor(t6, xor_vector);
1198
+ t7 = vec_xor(t7, xor_vector);
1199
+ t8 = vec_xor(t8, xor_vector);
1200
+ }
1201
+ vec_xst(t5, 0, vecOffset+128);
1202
+ vec_xst(t6, 0, vecOffset+144);
1203
+ vec_xst(t7, 0, vecOffset+160);
1204
+ vec_xst(t8, 0, vecOffset+176);
1205
+
1206
+ t1 = vec_perm(c5[1], c6[1], swiz1);
1207
+ t2 = vec_perm(c5[1], c6[1], swiz2);
1208
+ t3 = vec_perm(c7[1], c8[1], swiz1);
1209
+ t4 = vec_perm(c7[1], c8[1], swiz2);
1210
+ t5 = vec_perm(t1, t3, swiz3);
1211
+ t6 = vec_perm(t1, t3, swiz4);
1212
+ t7 = vec_perm(t2, t4, swiz3);
1213
+ t8 = vec_perm(t2, t4, swiz4);
1214
+ if (flip == true) {
1215
+ t5 = vec_xor(t5, xor_vector);
1216
+ t6 = vec_xor(t6, xor_vector);
1217
+ t7 = vec_xor(t7, xor_vector);
1218
+ t8 = vec_xor(t8, xor_vector);
1219
+ }
1220
+ vec_xst(t5, 0, vecOffset+192);
1221
+ vec_xst(t6, 0, vecOffset+208);
1222
+ vec_xst(t7, 0, vecOffset+224);
1223
+ vec_xst(t8, 0, vecOffset+240);
1224
+
1225
+ aoffset1 += lda;
1226
+ aoffset2 += lda;
1227
+ aoffset3 += lda;
1228
+ aoffset4 += lda;
1229
+ aoffset5 += lda;
1230
+ aoffset6 += lda;
1231
+ aoffset7 += lda;
1232
+ aoffset8 += lda;
1233
+ vecOffset += 256;
1234
+ i--;
1235
+ } while(i > 0);
1236
+ }
1237
+ j--;
1238
+ } while(j > 0);
1239
+ }
1240
+
1241
+ if (rows & 4) {
1242
+ aoffset1 = aoffset;
1243
+ aoffset2 = aoffset1 + lda;
1244
+ aoffset3 = aoffset2 + lda;
1245
+ aoffset4 = aoffset3 + lda;
1246
+ aoffset += 4 * lda;
1247
+
1248
+ i = (cols >> 3);
1249
+ if (i > 0) {
1250
+ do {
1251
+ C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
1252
+ C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
1253
+ C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
1254
+ C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
1255
+
1256
+ __builtin_vsx_disassemble_pair(c1, &C1);
1257
+ __builtin_vsx_disassemble_pair(c2, &C2);
1258
+ __builtin_vsx_disassemble_pair(c3, &C3);
1259
+ __builtin_vsx_disassemble_pair(c4, &C4);
1260
+
1261
+ t1 = vec_perm(c1[0], c2[0], swiz1);
1262
+ t2 = vec_perm(c1[0], c2[0], swiz2);
1263
+ t3 = vec_perm(c3[0], c4[0], swiz1);
1264
+ t4 = vec_perm(c3[0], c4[0], swiz2);
1265
+ t5 = vec_perm(t1, t3, swiz3);
1266
+ t6 = vec_perm(t1, t3, swiz4);
1267
+ t7 = vec_perm(t2, t4, swiz3);
1268
+ t8 = vec_perm(t2, t4, swiz4);
1269
+ if (flip == true) {
1270
+ t5 = vec_xor(t5, xor_vector);
1271
+ t6 = vec_xor(t6, xor_vector);
1272
+ t7 = vec_xor(t7, xor_vector);
1273
+ t8 = vec_xor(t8, xor_vector);
1274
+ }
1275
+ vec_xst(t5, 0, vecOffset);
1276
+ vec_xst(t6, 0, vecOffset+16);
1277
+ vec_xst(t7, 0, vecOffset+32);
1278
+ vec_xst(t8, 0, vecOffset+48);
1279
+
1280
+ t1 = vec_perm(c1[1], c2[1], swiz1);
1281
+ t2 = vec_perm(c1[1], c2[1], swiz2);
1282
+ t3 = vec_perm(c3[1], c4[1], swiz1);
1283
+ t4 = vec_perm(c3[1], c4[1], swiz2);
1284
+ t5 = vec_perm(t1, t3, swiz3);
1285
+ t6 = vec_perm(t1, t3, swiz4);
1286
+ t7 = vec_perm(t2, t4, swiz3);
1287
+ t8 = vec_perm(t2, t4, swiz4);
1288
+ if (flip == true) {
1289
+ t5 = vec_xor(t5, xor_vector);
1290
+ t6 = vec_xor(t6, xor_vector);
1291
+ t7 = vec_xor(t7, xor_vector);
1292
+ t8 = vec_xor(t8, xor_vector);
1293
+ }
1294
+ vec_xst(t5, 0, vecOffset+64);
1295
+ vec_xst(t6, 0, vecOffset+80);
1296
+ vec_xst(t7, 0, vecOffset+96);
1297
+ vec_xst(t8, 0, vecOffset+112);
1298
+
1299
+ aoffset1 += lda;
1300
+ aoffset2 += lda;
1301
+ aoffset3 += lda;
1302
+ aoffset4 += lda;
1303
+ vecOffset += 128;
1304
+ i--;
1305
+ } while(i > 0);
1306
+ }
1307
+ }
1308
+ if (rows & 3) {
1309
+ aoffset1 = aoffset;
1310
+ aoffset2 = aoffset1 + lda;
1311
+ aoffset3 = aoffset2 + lda;
1312
+ i = (cols >> 3);
1313
+ if (i > 0) {
1314
+ do {
1315
+ switch(rows) {
1316
+ case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
1317
+ __builtin_vsx_disassemble_pair(c3, &C3);
1318
+ case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
1319
+ __builtin_vsx_disassemble_pair(c2, &C2);
1320
+ case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
1321
+ __builtin_vsx_disassemble_pair(c1, &C1);
1322
+ break;
1323
+ }
1324
+ t1 = vec_perm(c1[0], c2[0], swiz1);
1325
+ t2 = vec_perm(c1[0], c2[0], swiz2);
1326
+ t3 = vec_perm(c3[0], c4[0], swiz1);
1327
+ t4 = vec_perm(c3[0], c4[0], swiz2);
1328
+ t5 = vec_perm(t1, t3, swiz3);
1329
+ t6 = vec_perm(t1, t3, swiz4);
1330
+ t7 = vec_perm(t2, t4, swiz3);
1331
+ t8 = vec_perm(t2, t4, swiz4);
1332
+ if (flip == true) {
1333
+ t5 = vec_xor(t5, xor_vector);
1334
+ t6 = vec_xor(t6, xor_vector);
1335
+ t7 = vec_xor(t7, xor_vector);
1336
+ t8 = vec_xor(t8, xor_vector);
1337
+ }
1338
+ vec_xst(t5, 0, vecOffset);
1339
+ vec_xst(t6, 0, vecOffset+16);
1340
+ vec_xst(t7, 0, vecOffset+32);
1341
+ vec_xst(t8, 0, vecOffset+48);
1342
+
1343
+ t1 = vec_perm(c1[1], c2[1], swiz1);
1344
+ t2 = vec_perm(c1[1], c2[1], swiz2);
1345
+ t3 = vec_perm(c3[1], c4[1], swiz1);
1346
+ t4 = vec_perm(c3[1], c4[1], swiz2);
1347
+ t5 = vec_perm(t1, t3, swiz3);
1348
+ t6 = vec_perm(t1, t3, swiz4);
1349
+ t7 = vec_perm(t2, t4, swiz3);
1350
+ t8 = vec_perm(t2, t4, swiz4);
1351
+ if (flip == true) {
1352
+ t5 = vec_xor(t5, xor_vector);
1353
+ t6 = vec_xor(t6, xor_vector);
1354
+ t7 = vec_xor(t7, xor_vector);
1355
+ t8 = vec_xor(t8, xor_vector);
1356
+ }
1357
+ vec_xst(t5, 0, vecOffset+64);
1358
+ vec_xst(t6, 0, vecOffset+80);
1359
+ vec_xst(t7, 0, vecOffset+96);
1360
+ vec_xst(t8, 0, vecOffset+112);
1361
+
1362
+ aoffset1 += lda;
1363
+ aoffset2 += lda;
1364
+ aoffset3 += lda;
1365
+ vecOffset += 128;
1366
+ i--;
1367
+ } while(i > 0);
1368
+ }
1369
+ }
1370
+ }
1371
+
1372
+ void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1373
+ int64_t mc, nc, mp, np;
1374
+ int m_rem = MIN(m - m0, 8);
1375
+ int n_rem = MIN(n - n0, 8);
1376
+ // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance
1377
+ // issues. After resolving them, below code will be enabled.
1378
+ /*if (m_rem >= 16 && n_rem >= 8) {
1379
+ mc = 16;
1380
+ nc = 8;
1381
+ gemm<16,8>(m0, m, n0, n);
1382
+ } else if(m_rem >= 8 && n_rem >= 16) {
1383
+ mc = 8;
1384
+ nc = 16;
1385
+ gemm<8,16>(m0, m, n0, n);
1386
+ }*/
1387
+ if (m_rem >= 8 && n_rem >= 8) {
1388
+ mc = 8;
1389
+ nc = 8;
1390
+ gemm<8,8>(m0, m, n0, n);
1391
+ } else if (m_rem >= 4 && n_rem >= 8) {
1392
+ mc = 4;
1393
+ nc = 8;
1394
+ gemm<4,8>(m0, m, n0, n);
1395
+ } else if (m_rem >= 8 && n_rem >= 4) {
1396
+ mc = 8;
1397
+ nc = 4;
1398
+ gemm<8,4>(m0, m, n0, n);
1399
+ } else if (m_rem >= 4 && n_rem >= 4) {
1400
+ mc = 4;
1401
+ nc = 4;
1402
+ gemm_small<4, 4>(m0, m, n0, n);
1403
+ } else if ((m_rem < 4) && (n_rem > 4)) {
1404
+ nc = 4;
1405
+ switch(m_rem) {
1406
+ case 1:
1407
+ mc = 1;
1408
+ gemm_small<1, 4>(m0, m, n0, n);
1409
+ break;
1410
+ case 2:
1411
+ mc = 2;
1412
+ gemm_small<2, 4>(m0, m, n0, n);
1413
+ break;
1414
+ case 3:
1415
+ mc = 3;
1416
+ gemm_small<3, 4>(m0, m, n0, n);
1417
+ break;
1418
+ default:
1419
+ return;
1420
+ }
1421
+ } else if ((m_rem > 4) && (n_rem < 4)) {
1422
+ mc = 4;
1423
+ switch(n_rem) {
1424
+ case 1:
1425
+ nc = 1;
1426
+ gemm_small<4, 1>(m0, m, n0, n);
1427
+ break;
1428
+ case 2:
1429
+ nc = 2;
1430
+ gemm_small<4, 2>(m0, m, n0, n);
1431
+ break;
1432
+ case 3:
1433
+ nc = 3;
1434
+ gemm_small<4, 3>(m0, m, n0, n);
1435
+ break;
1436
+ default:
1437
+ return;
1438
+ }
1439
+ } else {
1440
+ switch((m_rem << 4) | n_rem) {
1441
+ case 0x43:
1442
+ mc = 4;
1443
+ nc = 3;
1444
+ gemm_small<4, 3>(m0, m, n0, n);
1445
+ break;
1446
+ case 0x42:
1447
+ mc = 4;
1448
+ nc = 2;
1449
+ gemm_small<4, 2>(m0, m, n0, n);
1450
+ break;
1451
+ case 0x41:
1452
+ mc = 4;
1453
+ nc = 1;
1454
+ gemm_small<4, 1>(m0, m, n0, n);
1455
+ break;
1456
+ case 0x34:
1457
+ mc = 3;
1458
+ nc = 4;
1459
+ gemm_small<3, 4>(m0, m, n0, n);
1460
+ break;
1461
+ case 0x33:
1462
+ mc = 3;
1463
+ nc = 3;
1464
+ gemm_small<3, 3>(m0, m, n0, n);
1465
+ break;
1466
+ case 0x32:
1467
+ mc = 3;
1468
+ nc = 2;
1469
+ gemm_small<3, 2>(m0, m, n0, n);
1470
+ break;
1471
+ case 0x31:
1472
+ mc = 3;
1473
+ nc = 1;
1474
+ gemm_small<3, 1>(m0, m, n0, n);
1475
+ break;
1476
+ case 0x24:
1477
+ mc = 2;
1478
+ nc = 4;
1479
+ gemm_small<2, 4>(m0, m, n0, n);
1480
+ break;
1481
+ case 0x23:
1482
+ mc = 2;
1483
+ nc = 3;
1484
+ gemm_small<2, 3>(m0, m, n0, n);
1485
+ break;
1486
+ case 0x22:
1487
+ mc = 2;
1488
+ nc = 2;
1489
+ gemm_small<2, 2>(m0, m, n0, n);
1490
+ break;
1491
+ case 0x21:
1492
+ mc = 2;
1493
+ nc = 1;
1494
+ gemm_small<2, 1>(m0, m, n0, n);
1495
+ break;
1496
+ case 0x14:
1497
+ mc = 1;
1498
+ nc = 4;
1499
+ gemm_small<1, 4>(m0, m, n0, n);
1500
+ break;
1501
+ case 0x13:
1502
+ mc = 1;
1503
+ nc = 3;
1504
+ gemm_small<1, 3>(m0, m, n0, n);
1505
+ break;
1506
+ case 0x12:
1507
+ mc = 1;
1508
+ nc = 2;
1509
+ gemm_small<1, 2>(m0, m, n0, n);
1510
+ break;
1511
+ case 0x11:
1512
+ mc = 1;
1513
+ nc = 1;
1514
+ gemm_small<1, 1>(m0, m, n0, n);
1515
+ break;
1516
+ default:
1517
+ return;
1518
+ }
1519
+ }
1520
+ mp = m0 + (m - m0) / mc * mc;
1521
+ np = n0 + (n - n0) / nc * nc;
1522
+ mnpack(mp, m, n0, np);
1523
+ mnpack(m0, m, np, n);
1524
+ }
1525
+
1526
+ void KERNEL_4x8(int64_t ii, int64_t jj) {
1527
+ vec_t vec_A[8], vec_B[16] = {0};
1528
+ acc_t acc_0, acc_1;
1529
+ std::array<int, 4> comparray;
1530
+ vector float fin_res[8] = {0};
1531
+ vector float vs[8] = {0};
1532
+ for (int l = 0; l < k; l++) {
1533
+ __builtin_mma_xxsetaccz(&acc_0);
1534
+ __builtin_mma_xxsetaccz(&acc_1);
1535
+ packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
1536
+ packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
1537
+ for(int x = 0; x < 8; x++) {
1538
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
1539
+ __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
1540
+ }
1541
+ for (int I = 0; I<4; I++) {
1542
+ for (int J = 0; J<4; J++) {
1543
+ *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
1544
+ *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
1545
+ }
1546
+ }
1547
+ auto aoffset = A+(ii*lda)+l;
1548
+ for (int i = 0; i < 4; i++) {
1549
+ comparray[i] = 0;
1550
+ int ca = 0;
1551
+ const int8_t *at = aoffset->qs;
1552
+ for (int j = 0; j < 32; j++)
1553
+ ca += (int)*at++;
1554
+ comparray[i] = ca;
1555
+ aoffset += lda;
1556
+ }
1557
+ compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
1558
+ compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
1559
+ }
1560
+ save_res<4, 4>(ii, jj, 0, fin_res);
1561
+ save_res<4, 4>(ii, jj+4, 4, fin_res);
1562
+ }
1563
+
1564
+ void KERNEL_8x4(int64_t ii, int64_t jj) {
1565
+ vec_t vec_A[16], vec_B[8] = {0};
1566
+ acc_t acc_0, acc_1;
1567
+ std::array<int, 8> comparray;
1568
+ vector float fin_res[8] = {0};
1569
+ vector float vs[8] = {0};
1570
+ for (int l = 0; l < k; l++) {
1571
+ __builtin_mma_xxsetaccz(&acc_0);
1572
+ __builtin_mma_xxsetaccz(&acc_1);
1573
+ packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
1574
+ packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
1575
+ for(int x = 0; x < 8; x++) {
1576
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
1577
+ __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
1578
+ }
1579
+ for (int I = 0; I<8; I++) {
1580
+ for (int J = 0; J<4; J++) {
1581
+ *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
1582
+ }
1583
+ }
1584
+ auto aoffset = A+(ii*lda)+l;
1585
+ for (int i = 0; i < 8; i++) {
1586
+ comparray[i] = 0;
1587
+ int ca = 0;
1588
+ const int8_t *at = aoffset->qs;
1589
+ for (int j = 0; j < 32; j++)
1590
+ ca += (int)*at++;
1591
+ comparray[i] = ca;
1592
+ aoffset += lda;
1593
+ }
1594
+ compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
1595
+ compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
1596
+ }
1597
+ save_res<4, 4>(ii, jj, 0, fin_res);
1598
+ save_res<4, 4>(ii+4, jj, 4, fin_res);
1599
+ }
1600
+
1601
+ void KERNEL_8x8(int64_t ii, int64_t jj) {
1602
+ vec_t vec_A[16], vec_B[16] = {0};
1603
+ acc_t acc_0, acc_1, acc_2, acc_3;
1604
+ std::array<int, 8> comparray;
1605
+ vector float fin_res[16] = {0};
1606
+ vector float vs[16] = {0};
1607
+ for (int l = 0; l < k; l++) {
1608
+ __builtin_mma_xxsetaccz(&acc_0);
1609
+ __builtin_mma_xxsetaccz(&acc_1);
1610
+ __builtin_mma_xxsetaccz(&acc_2);
1611
+ __builtin_mma_xxsetaccz(&acc_3);
1612
+ packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
1613
+ packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
1614
+ for(int x = 0; x < 8; x++) {
1615
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
1616
+ __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
1617
+ __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
1618
+ __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
1619
+ }
1620
+ for (int I = 0; I<8; I++) {
1621
+ for (int J = 0; J<4; J++) {
1622
+ *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
1623
+ *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
1624
+ }
1625
+ }
1626
+ auto aoffset = A+(ii*lda)+l;
1627
+ for (int i = 0; i < 8; i++) {
1628
+ comparray[i] = 0;
1629
+ int ca = 0;
1630
+ const int8_t *at = aoffset->qs;
1631
+ for (int j = 0; j < 32; j++)
1632
+ ca += (int)*at++;
1633
+ comparray[i] = ca;
1634
+ aoffset += lda;
1635
+ }
1636
+ compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
1637
+ compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
1638
+ compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
1639
+ compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
1640
+ }
1641
+ save_res<4, 4>(ii, jj, 0, fin_res);
1642
+ save_res<4, 4>(ii+4, jj, 4, fin_res);
1643
+ save_res<4, 4>(ii, jj+4, 8, fin_res);
1644
+ save_res<4, 4>(ii+4, jj+4, 12, fin_res);
1645
+ }
1646
+
1647
+ template<int RM, int RN>
1648
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1649
+ int64_t ytiles = (m - m0) / RM;
1650
+ int64_t xtiles = (n - n0) / RN;
1651
+ int64_t tiles = xtiles * ytiles;
1652
+ int64_t duty = (tiles + nth - 1) / nth;
1653
+ int64_t start = duty * ith;
1654
+ int64_t end = start + duty;
1655
+ vec_t vec_A[8], vec_B[8] = {0};
1656
+ vector signed int vec_C[4];
1657
+ acc_t acc_0;
1658
+
1659
+ if (end > tiles)
1660
+ end = tiles;
1661
+ for (int64_t job = start; job < end; ++job) {
1662
+ int64_t ii = m0 + job / xtiles * RM;
1663
+ int64_t jj = n0 + job % xtiles * RN;
1664
+ std::array<int, RM> comparray;
1665
+ vector float res[4] = {0};
1666
+ vector float fin_res[4] = {0};
1667
+ vector float vs[4] = {0};
1668
+ vector float CA[4] = {0};
1669
+ __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
1670
+ __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
1671
+ for (int l = 0; l < k; l++) {
1672
+ __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
1673
+ __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
1674
+ __builtin_mma_xxsetaccz(&acc_0);
1675
+ packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
1676
+ packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
1677
+ for(int x = 0; x < 8; x+=4) {
1678
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
1679
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
1680
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
1681
+ __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
1682
+ }
1683
+ for (int I = 0; I<RM; I++) {
1684
+ for (int J = 0; J<RN; J++) {
1685
+ *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
1686
+ }
1687
+ }
1688
+ __builtin_mma_disassemble_acc(vec_C, &acc_0);
1689
+ auto aoffset = A+(ii*lda)+l;
1690
+ for (int i = 0; i < RM; i++) {
1691
+ comparray[i] = 0;
1692
+ int ca = 0;
1693
+ const int8_t *at = aoffset->qs;
1694
+ for (int j = 0; j < 32; j++)
1695
+ ca += (int)*at++;
1696
+ comparray[i] = ca;
1697
+ aoffset += lda;
1698
+ }
1699
+
1700
+ for (int i = 0; i < RM; i++) {
1701
+ CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
1702
+ res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
1703
+ fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
1704
+ }
1705
+ }
1706
+ save_res<RM, RN>(ii, jj, 0, fin_res);
1707
+ }
1708
+ }
1709
+
1710
+ template<int RM, int RN>
1711
+ inline void kernel(int64_t ii, int64_t jj) {
1712
+ if constexpr(RM == 4 && RN == 8) {
1713
+ KERNEL_4x8(ii,jj);
1714
+ } else if constexpr(RM == 8 && RN == 4) {
1715
+ KERNEL_8x4(ii,jj);
1716
+ } else if constexpr(RM == 8 && RN == 8) {
1717
+ KERNEL_8x8(ii,jj);
1718
+ } else {
1719
+ static_assert(false, "RN/RM values not supported");
1720
+ }
1721
+ }
1722
+
1723
+ template <int RM, int RN>
1724
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1725
+ int64_t ytiles = (m - m0) / RM;
1726
+ int64_t xtiles = (n - n0) / RN;
1727
+ int64_t tiles = xtiles * ytiles;
1728
+ int64_t duty = (tiles + nth - 1) / nth;
1729
+ int64_t start = duty * ith;
1730
+ int64_t end = start + duty;
1731
+ if (end > tiles)
1732
+ end = tiles;
1733
+ for (int64_t job = start; job < end; ++job) {
1734
+ int64_t ii = m0 + job / xtiles * RM;
1735
+ int64_t jj = n0 + job % xtiles * RN;
1736
+ kernel<RM, RN>(ii, jj);
1737
+ }
1738
+ }
1739
+
1740
+ const TA *const A;
1741
+ const TB *const B;
1742
+ TC *C;
1743
+ TA *At;
1744
+ TB *Bt;
1745
+ const int64_t k;
1746
+ const int64_t lda;
1747
+ const int64_t ldb;
1748
+ const int64_t ldc;
1749
+ const int ith;
1750
+ const int nth;
1751
+ };
1752
+
1753
  template <typename TA, typename TB, typename TC>
1754
  class tinyBLAS_PPC {
1755
  public:
 
1769
 
1770
  void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
1771
 
1772
+ template<typename VA>
1773
+ void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
1774
  int64_t i, j;
1775
+ TA *aoffset = NULL, *boffset = NULL;
1776
+ TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1777
+ TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1778
+ __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
1779
+ VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
1780
+ VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
1781
+ VA t1, t2, t3, t4, t5, t6, t7, t8;
1782
+ aoffset = const_cast<TA*>(a);
1783
  boffset = vec;
1784
  j = (rows >> 3);
1785
  if (j > 0) {
 
1795
  aoffset += 8 * lda;
1796
  i = (cols >> 3);
1797
  if (i > 0) {
 
 
 
1798
  do {
1799
  C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
1800
  C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
 
1874
  } while(i > 0);
1875
  }
1876
  if (cols & 4) {
1877
+ c1[0] = vec_xl(0, aoffset1);
1878
+ c2[0] = vec_xl(0, aoffset2);
1879
+ c3[0] = vec_xl(0, aoffset3);
1880
+ c4[0] = vec_xl(0, aoffset4);
1881
+ c5[0] = vec_xl(0, aoffset5);
1882
+ c6[0] = vec_xl(0, aoffset6);
1883
+ c7[0] = vec_xl(0, aoffset7);
1884
+ c8[0] = vec_xl(0, aoffset8);
1885
+
1886
+ t1 = vec_mergeh(c1[0], c2[0]);
1887
+ t2 = vec_mergeh(c3[0], c4[0]);
1888
+ t3 = vec_mergeh(c5[0], c6[0]);
1889
+ t4 = vec_mergeh(c7[0], c8[0]);
 
 
1890
  t5 = vec_xxpermdi(t1, t2, 0);
1891
  t6 = vec_xxpermdi(t3, t4, 0);
1892
  t7 = vec_xxpermdi(t1, t2, 3);
 
1896
  vec_xst(t7, 0, boffset+8);
1897
  vec_xst(t8, 0, boffset+12);
1898
 
1899
+ t1 = vec_mergel(c1[0], c2[0]);
1900
+ t2 = vec_mergel(c3[0], c4[0]);
1901
+ t3 = vec_mergel(c5[0], c6[0]);
1902
+ t4 = vec_mergel(c7[0], c8[0]);
1903
  t5 = vec_xxpermdi(t1, t2, 0);
1904
  t6 = vec_xxpermdi(t3, t4, 0);
1905
  t7 = vec_xxpermdi(t1, t2, 3);
 
1921
  aoffset += 4 * lda;
1922
  i = (cols >> 3);
1923
  if (i > 0) {
 
 
 
1924
  do {
1925
  C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
1926
  C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
 
1967
  }
1968
 
1969
  if (cols & 4) {
1970
+ c1[0] = vec_xl(0, aoffset1);
1971
+ c2[0] = vec_xl(0, aoffset2);
1972
+ c3[0] = vec_xl(0, aoffset3);
1973
+ c4[0] = vec_xl(0, aoffset4);
1974
+
1975
+ t1 = vec_mergeh(c1[0], c2[0]);
1976
+ t2 = vec_mergeh(c3[0], c4[0]);
 
 
1977
  t3 = vec_xxpermdi(t1, t2, 0);
1978
  t4 = vec_xxpermdi(t1, t2, 3);
1979
  vec_xst(t3, 0, boffset);
1980
  vec_xst(t4, 0, boffset+4);
1981
 
1982
+ t1 = vec_mergel(c1[0], c2[0]);
1983
+ t2 = vec_mergel(c3[0], c4[0]);
1984
  t3 = vec_xxpermdi(t1, t2, 0);
1985
  t4 = vec_xxpermdi(t1, t2, 3);
1986
  vec_xst(t3, 0, boffset+8);
 
1992
  aoffset2 = aoffset1 + lda;
1993
  aoffset3 = aoffset2 + lda;
1994
  if (cols & 4) {
1995
+ c1[0] = vec_xl(0, aoffset1);
1996
+ c2[0] = vec_xl(0, aoffset2);
1997
+ c3[0] = vec_xl(0, aoffset3);
1998
+
1999
+ t1 = vec_mergeh(c1[0], c2[0]);
2000
+ t2 = vec_mergeh(c3[0], c4[0]);
 
 
2001
  t3 = vec_xxpermdi(t1, t2, 0);
2002
  t4 = vec_xxpermdi(t1, t2, 3);
2003
  vec_xst(t3, 0, boffset);
2004
  vec_xst(t4, 0, boffset+4);
2005
 
2006
+ t1 = vec_mergel(c1[0], c2[0]);
2007
+ t2 = vec_mergel(c3[0], c4[0]);
2008
  t3 = vec_xxpermdi(t1, t2, 0);
2009
  t4 = vec_xxpermdi(t1, t2, 3);
2010
  vec_xst(t3, 0, boffset+8);
 
2012
  }
2013
  }
2014
  }
 
2015
  void KERNEL_4x4(int64_t ii, int64_t jj) {
2016
  vec_t vec_A[4], vec_B[4], vec_C[4];
2017
  acc_t acc_0;
2018
  __builtin_mma_xxsetaccz(&acc_0);
2019
  for (int l = 0; l < k; l+=4) {
2020
+ packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
2021
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
2022
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
2023
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
2024
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
 
2033
  __builtin_mma_xxsetaccz(&acc_0);
2034
  __builtin_mma_xxsetaccz(&acc_1);
2035
  for (int64_t l = 0; l < k; l+=4) {
2036
+ packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
2037
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B);
2038
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
2039
  __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
2040
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
 
2054
  __builtin_mma_xxsetaccz(&acc_0);
2055
  __builtin_mma_xxsetaccz(&acc_1);
2056
  for (int64_t l = 0; l < k; l+=4) {
2057
+ packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A);
2058
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
2059
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
2060
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
2061
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
 
2077
  __builtin_mma_xxsetaccz(&acc_2);
2078
  __builtin_mma_xxsetaccz(&acc_3);
2079
  for (int l = 0; l < k; l+=8) {
2080
+ packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A);
2081
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B);
2082
  for(int x = 0; x < 16; x+=2) {
2083
  __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
2084
  __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
 
2261
  vec_t vec_A[4], vec_B[4];
2262
  for (int l=0; l<k; l+=4) {
2263
  if (RN >= 4 && RM == 1) {
2264
+ TA* a = const_cast<TA*>(A+(ii)*lda+l);
2265
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
2266
  vec_A[0] = (vec_t)vec_xl(0,a);
2267
+ vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
2268
+ vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
2269
+ vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
2270
  } else {
2271
+ packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
2272
+ packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
2273
  }
2274
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
2275
  __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
 
2279
  __builtin_mma_disassemble_acc(vec_C, &acc_0);
2280
  for (int I = 0; I < RM; I++) {
2281
  for (int J = 0; J < RN; J++) {
2282
+ *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
2283
  }
2284
  }
2285
  }
 
2502
  params->ith, params->nth};
2503
  tb.matmul(m, n);
2504
  return true;
2505
+
2506
+ #elif defined(__MMA__)
2507
+ if (n < 8 && n != 4)
2508
+ return false;
2509
+ if (m < 8 && m != 4)
2510
+ return false;
2511
+ tinyBLAS_Q0_PPC<block_q8_0, block_q8_0, float> tb{
2512
+ k, (const block_q8_0 *)A, lda,
2513
+ (const block_q8_0 *)B, ldb,
2514
+ (float *)C, ldc,
2515
+ params->ith, params->nth};
2516
+ tb.matmul(m, n);
2517
+ return true;
2518
+
2519
  #else
2520
  return false;
2521
  #endif