Spaces:
Running
Running
uvos
commited on
Commit
·
72c6f1d
1
Parent(s):
82bb7f3
HIP: Supress transformation warning in softmax.cu
Browse filesloops with bounds not known at compile time can not be unrolled.
when ncols_template == 0, the bounds of the loop are not constexpr, thus llvm cant unroll the loops here.
ggml/src/ggml-cuda/softmax.cu
CHANGED
|
@@ -13,6 +13,12 @@ __device__ float __forceinline__ t2f32<half>(half val) {
|
|
| 13 |
return __half2float(val);
|
| 14 |
}
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
template <bool use_shared, int ncols_template, int block_size_template, typename T>
|
| 17 |
static __global__ void soft_max_f32(
|
| 18 |
const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
|
|
@@ -118,6 +124,9 @@ static __global__ void soft_max_f32(
|
|
| 118 |
dst[col] = vals[col] * inv_sum;
|
| 119 |
}
|
| 120 |
}
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
static __global__ void soft_max_back_f32(
|
| 123 |
const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
|
|
|
|
| 13 |
return __half2float(val);
|
| 14 |
}
|
| 15 |
|
| 16 |
+
// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
|
| 17 |
+
// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
|
| 18 |
+
#ifdef __clang__
|
| 19 |
+
#pragma clang diagnostic push
|
| 20 |
+
#pragma clang diagnostic ignored "-Wpass-failed"
|
| 21 |
+
#endif
|
| 22 |
template <bool use_shared, int ncols_template, int block_size_template, typename T>
|
| 23 |
static __global__ void soft_max_f32(
|
| 24 |
const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
|
|
|
|
| 124 |
dst[col] = vals[col] * inv_sum;
|
| 125 |
}
|
| 126 |
}
|
| 127 |
+
#ifdef __clang__
|
| 128 |
+
#pragma clang diagnostic pop
|
| 129 |
+
#endif
|
| 130 |
|
| 131 |
static __global__ void soft_max_back_f32(
|
| 132 |
const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
|