Commit 9794ea7 · committed by junchao-loongson (Jinyang He)
1 Parent(s): cf52931

ggml : add loongarch lsx and lasx support (llama/6454)


* add loongarch lsx and lasx optimize code

* Add loongarch compilation support to makefile

* revert stb_image.h

* opt bytes_from_nibbles_32 and sum_i16_pairs_float

* fix undeclared

* format code

* update

* update 2

---------

Co-authored-by: Jinyang He <[email protected]>
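
These LSX/LASX paths are selected entirely at compile time: the compiler predefines __loongarch64, __loongarch_sx and __loongarch_asx when targeting LoongArch with the corresponding vector extensions enabled (typically via -mlsx / -mlasx; the exact flags depend on the toolchain and are an assumption here, not part of this commit). A minimal standalone check, purely illustrative, that reports which of the paths added below a given build would take:

    /* Illustrative only (not part of this commit): print which LoongArch
     * SIMD path the ggml code below would compile, based on the same
     * predefined macros it tests. */
    #include <stdio.h>

    int main(void) {
    #if defined(__loongarch_asx)
        puts("LASX path: GGML_SIMD with 256-bit vectors (GGML_F32_EPR = 8)");
    #elif defined(__loongarch_sx)
        puts("LSX path: GGML_SIMD with 128-bit vectors (GGML_F32_EPR = 4)");
    #else
        puts("no LoongArch SIMD extension enabled for this build");
    #endif
        return 0;
    }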

Files changed (3)
  1. ggml-impl.h +28 -0
  2. ggml-quants.c +0 -0
  3. ggml.c +189 -0
ggml-impl.h CHANGED
@@ -455,6 +455,34 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
 #include <riscv_vector.h>
 #endif
 
+#if defined(__loongarch64)
+#if defined(__loongarch_asx)
+#include <lasxintrin.h>
+#endif
+#if defined(__loongarch_sx)
+#include <lsxintrin.h>
+#endif
+#endif
+
+#if defined(__loongarch_asx)
+
+typedef union {
+    int32_t i;
+    float f;
+} ft_union;
+
+/* float type data load instructions */
+static __m128 __lsx_vreplfr2vr_s(float val) {
+    ft_union fi_tmpval = {.f = val};
+    return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
+}
+
+static __m256 __lasx_xvreplfr2vr_s(float val) {
+    ft_union fi_tmpval = {.f = val};
+    return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
+}
+#endif
+
 #ifdef __F16C__
 
 #ifdef _MSC_VER
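
The helpers added above work around the lack of a direct float-splat intrinsic: they type-pun the scalar through ft_union and use the integer replicate builtins (__lsx_vreplgr2vr_w / __lasx_xvreplgr2vr_w) to broadcast it. A hypothetical usage sketch, not part of this commit, showing the same broadcast feeding an LASX multiply (function name and loop are illustrative; assumes a LoongArch toolchain with LASX enabled and n a multiple of 8):

    #include <stdint.h>
    #include <lasxintrin.h>

    /* Illustrative sketch: scale n floats in place with LASX. */
    static void scale_f32_lasx(float *x, float s, int n) {
        union { int32_t i; float f; } u = {.f = s};
        const __m256 vs = (__m256)__lasx_xvreplgr2vr_w(u.i); /* broadcast s to 8 lanes */
        for (int i = 0; i < n; i += 8) {
            __m256 vx = (__m256)__lasx_xvld(x + i, 0);       /* load 8 floats */
            vx = __lasx_xvfmul_s(vx, vs);                    /* multiply by the scalar */
            __lasx_xvst((__m256i)vx, x + i, 0);              /* store 8 floats back */
        }
    }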
ggml-quants.c CHANGED
The diff for this file is too large to render. See raw diff
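
The commit message also mentions optimizing bytes_from_nibbles_32 and sum_i16_pairs_float in this file. For orientation, a scalar sketch of what bytes_from_nibbles_32 computes, based on the existing x86 helper of the same name in ggml-quants.c (the LASX version added by this commit is not shown here because the diff is too large to render, so treat the layout as an assumption):

    #include <stdint.h>

    /* Scalar reference (illustrative): unpack 32 packed 4-bit values from
     * 16 bytes into 32 bytes, low nibbles first, then high nibbles. */
    static void bytes_from_nibbles_32_ref(const uint8_t *src, uint8_t *dst) {
        for (int i = 0; i < 16; ++i) {
            dst[i]      = src[i] & 0x0F; /* low nibble of byte i  */
            dst[i + 16] = src[i] >> 4;   /* high nibble of byte i */
        }
    }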
 
ggml.c CHANGED
@@ -1523,6 +1523,195 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
 #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
 #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
 
+#elif defined(__loongarch_asx)
+
+#define GGML_SIMD
+
+// F32 LASX
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR 8
+
+#define GGML_F32x8 __m256
+#define GGML_F32x8_ZERO (__m256)__lasx_xvldi(0)
+#define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
+#define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
+#define GGML_F32x8_STORE(x,y) __lasx_xvst((y), (x), 0)
+#define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
+#define GGML_F32x8_ADD __lasx_xvfadd_s
+#define GGML_F32x8_MUL __lasx_xvfmul_s
+#define GGML_F32x8_REDUCE(res, x) \
+do { \
+    int offset = GGML_F32_ARR >> 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+    } \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+    } \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+    } \
+    float *tmp_p = (float *)&x[0]; \
+    res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \
+} while (0)
+// TODO: is this optimal ?
+
+#define GGML_F32_VEC GGML_F32x8
+#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
+#define GGML_F32_VEC_SET1 GGML_F32x8_SET1
+#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
+#define GGML_F32_VEC_STORE GGML_F32x8_STORE
+#define GGML_F32_VEC_FMA GGML_F32x8_FMA
+#define GGML_F32_VEC_ADD GGML_F32x8_ADD
+#define GGML_F32_VEC_MUL GGML_F32x8_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
+
+// F16 LASX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR 8
+
+// F16 arithmetic is not supported by AVX, so we use F32 instead
+
+#define GGML_F32Cx8 __m256
+#define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
+#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
+
+static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {
+    float tmp[8];
+
+    for (int i = 0; i < 8; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    return (__m256)__lasx_xvld(tmp, 0);
+}
+static inline void __lasx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
+    float arr[8];
+
+    __lasx_xvst(y, arr, 0);
+
+    for (int i = 0; i < 8; i++)
+        x[i] = GGML_FP32_TO_FP16(arr[i]);
+}
+#define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
+#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
+
+#define GGML_F32Cx8_FMA GGML_F32x8_FMA
+#define GGML_F32Cx8_ADD __lasx_xvfadd_s
+#define GGML_F32Cx8_MUL __lasx_xvfmul_s
+#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
+
+#define GGML_F16_VEC GGML_F32Cx8
+#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
+#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
+
+#elif defined(__loongarch_sx)
+
+#define GGML_SIMD
+
+// F32 LSX
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR 4
+
+#define GGML_F32x4 __m128
+#define GGML_F32x4_ZERO __lsx_vldi(0)
+#define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
+#define GGML_F32x4_STORE((x),(y)) __lsx_vst((y), (x), 0)
+#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
+#define GGML_F32x4_ADD __lsx_vfadd_s
+#define GGML_F32x4_MUL __lsx_vfmul_s
+#define GGML_F32x4_REDUCE(res, x) \
+{ \
+    int offset = GGML_F32_ARR >> 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+    } \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+    } \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+    } \
+    __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \
+    tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \
+    tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
+    const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
+    tmp = __lsx_vsrli_d((__m128i)t0, 32); \
+    tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \
+    tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
+    res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
+}
+
+#define GGML_F32_VEC GGML_F32x4
+#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 LSX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR 4
+
+static inline __m128 __lsx_f16x4_load(ggml_fp16_t *x) {
+    float tmp[4];
+
+    tmp[0] = GGML_FP16_TO_FP32(x[0]);
+    tmp[1] = GGML_FP16_TO_FP32(x[1]);
+    tmp[2] = GGML_FP16_TO_FP32(x[2]);
+    tmp[3] = GGML_FP16_TO_FP32(x[3]);
+
+    return __lsx_vld(tmp, 0);
+}
+
+static inline void __lsx_f16x4_store(ggml_fp16_t *x, __m128 y) {
+    float arr[4];
+
+    __lsx_vst(y, arr, 0);
+
+    x[0] = GGML_FP32_TO_FP16(arr[0]);
+    x[1] = GGML_FP32_TO_FP16(arr[1]);
+    x[2] = GGML_FP32_TO_FP16(arr[2]);
+    x[3] = GGML_FP32_TO_FP16(arr[3]);
+}
+
+#define GGML_F32Cx4 __m128
+#define GGML_F32Cx4_ZERO __lsx_vldi(0)
+#define GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x)
+#define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
+#define GGML_F32Cx4_FMA GGML_F32x4_FMA
+#define GGML_F32Cx4_ADD __lsx_vfadd_s
+#define GGML_F32Cx4_MUL __lsx_vfmul_s
+#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
+
+#define GGML_F16_VEC GGML_F32Cx4
+#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
+#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
+#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
+#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
+#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
+#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
+
 #endif
 
 // GGML_F32_ARR / GGML_F16_ARR
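
Once GGML_SIMD and the GGML_F32_VEC_* aliases above are defined, the architecture-independent loops in ggml.c use them unchanged; GGML_F32_ARR is defined a few lines below this hunk as GGML_F32_STEP/GGML_F32_EPR. A simplified sketch of that pattern, modeled on ggml_vec_dot_f32 (the function name and signature here are illustrative, and the macros are assumed to be in scope):

    /* Illustrative sketch of the generic ggml F32 dot-product loop that the
     * LASX/LSX macro blocks above plug into. */
    static void vec_dot_f32_sketch(int n, float *s, const float *x, const float *y) {
        float sumf = 0.0f;
        const int np = (n & ~(GGML_F32_STEP - 1));      /* largest multiple of GGML_F32_STEP */

        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };

        for (int i = 0; i < np; i += GGML_F32_STEP) {
            for (int j = 0; j < GGML_F32_ARR; j++) {
                GGML_F32_VEC ax = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
                GGML_F32_VEC ay = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
                sum[j] = GGML_F32_VEC_FMA(sum[j], ax, ay);  /* sum[j] += ax*ay */
            }
        }

        GGML_F32_VEC_REDUCE(sumf, sum);                 /* horizontal add of the partial sums */

        for (int i = np; i < n; ++i) {                  /* scalar leftovers */
            sumf += x[i]*y[i];
        }

        *s = sumf;
    }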