jeffbolznv commited on
Commit
0be0329
·
1 Parent(s): 5885084

vulkan/cuda: Fix im2col when KW!=KH (llama/14789)

Browse files

The tid is decomposed into "ow + ky*OW + kx*OW*KH". Change "ksize" to match.

ggml/src/ggml-cuda/im2col.cu CHANGED
@@ -10,7 +10,7 @@ static __global__ void im2col_kernel(
10
  return;
11
  }
12
 
13
- const int64_t ksize = OW * (KH > 1 ? KW : 1);
14
  const int64_t kx = i / ksize;
15
  const int64_t kd = kx * ksize;
16
  const int64_t ky = (i - kd) / OW;
 
10
  return;
11
  }
12
 
13
+ const int64_t ksize = OW * KH;
14
  const int64_t kx = i / ksize;
15
  const int64_t kd = kx * ksize;
16
  const int64_t ky = (i - kd) / OW;
ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp CHANGED
@@ -40,12 +40,10 @@ void main() {
40
  const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
41
  const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
42
  const int oh_s1 = int(oh) * p.s1;
43
- const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
44
 
45
  const uint base_linear_idx = gidx * NUM_ITER;
46
 
47
- const uint max_ky = ksize / p.OW;
48
-
49
  uint current_kx = base_linear_idx / ksize;
50
  const uint rem = base_linear_idx - (current_kx * ksize);
51
  uint current_ky = rem / p.OW;
@@ -76,7 +74,7 @@ void main() {
76
 
77
  if (++current_ix == p.OW) {
78
  current_ix = 0;
79
- if (++current_ky == max_ky) {
80
  current_ky = 0;
81
  current_kx++;
82
  }
 
40
  const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
41
  const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
42
  const int oh_s1 = int(oh) * p.s1;
43
+ const uint ksize = p.OW * p.KH;
44
 
45
  const uint base_linear_idx = gidx * NUM_ITER;
46
 
 
 
47
  uint current_kx = base_linear_idx / ksize;
48
  const uint rem = base_linear_idx - (current_kx * ksize);
49
  uint current_ky = rem / p.OW;
 
74
 
75
  if (++current_ix == p.OW) {
76
  current_ix = 0;
77
+ if (++current_ky == p.KH) {
78
  current_ky = 0;
79
  current_kx++;
80
  }