Spaces:
Running
Running
Commit
·
0be0329
1
Parent(s):
5885084
vulkan/cuda: Fix im2col when KW!=KH (llama/14789)
Browse files
The tid is decomposed into "ow + ky*OW + kx*OW*KH". Change "ksize" to match.
ggml/src/ggml-cuda/im2col.cu
CHANGED
|
@@ -10,7 +10,7 @@ static __global__ void im2col_kernel(
|
|
| 10 |
return;
|
| 11 |
}
|
| 12 |
|
| 13 |
-
const int64_t ksize = OW *
|
| 14 |
const int64_t kx = i / ksize;
|
| 15 |
const int64_t kd = kx * ksize;
|
| 16 |
const int64_t ky = (i - kd) / OW;
|
|
|
|
| 10 |
return;
|
| 11 |
}
|
| 12 |
|
| 13 |
+
const int64_t ksize = OW * KH;
|
| 14 |
const int64_t kx = i / ksize;
|
| 15 |
const int64_t kd = kx * ksize;
|
| 16 |
const int64_t ky = (i - kd) / OW;
|
ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
CHANGED
|
@@ -40,12 +40,10 @@ void main() {
|
|
| 40 |
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
|
| 41 |
const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
|
| 42 |
const int oh_s1 = int(oh) * p.s1;
|
| 43 |
-
const uint ksize = p.OW *
|
| 44 |
|
| 45 |
const uint base_linear_idx = gidx * NUM_ITER;
|
| 46 |
|
| 47 |
-
const uint max_ky = ksize / p.OW;
|
| 48 |
-
|
| 49 |
uint current_kx = base_linear_idx / ksize;
|
| 50 |
const uint rem = base_linear_idx - (current_kx * ksize);
|
| 51 |
uint current_ky = rem / p.OW;
|
|
@@ -76,7 +74,7 @@ void main() {
|
|
| 76 |
|
| 77 |
if (++current_ix == p.OW) {
|
| 78 |
current_ix = 0;
|
| 79 |
-
if (++current_ky == max_ky) {
|
| 80 |
current_ky = 0;
|
| 81 |
current_kx++;
|
| 82 |
}
|
|
|
|
| 40 |
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
|
| 41 |
const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
|
| 42 |
const int oh_s1 = int(oh) * p.s1;
|
| 43 |
+
const uint ksize = p.OW * p.KH;
|
| 44 |
|
| 45 |
const uint base_linear_idx = gidx * NUM_ITER;
|
| 46 |
|
|
|
|
|
|
|
| 47 |
uint current_kx = base_linear_idx / ksize;
|
| 48 |
const uint rem = base_linear_idx - (current_kx * ksize);
|
| 49 |
uint current_ky = rem / p.OW;
|
|
|
|
| 74 |
|
| 75 |
if (++current_ix == p.OW) {
|
| 76 |
current_ix = 0;
|
| 77 |
+
if (++current_ky == p.KH) {
|
| 78 |
current_ky = 0;
|
| 79 |
current_kx++;
|
| 80 |
}
|