jeffbolznv commited on
Commit
09dd86a
·
1 Parent(s): 608b377

vulkan: fix mul_mat_vec failure in backend tests (llama/12529)

Browse files

The OOB calculation could be wrong if the last iteration was during one of
the unrolled loops. Adjust the unrolling counts to avoid this. Add a couple
new backend tests that hit this failure on NVIDIA GPUs.

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp CHANGED
@@ -105,6 +105,16 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
105
  int unroll_count = 4;
106
  uint unrolled_iters = num_iters & ~(unroll_count - 1);
107
 
 
 
 
 
 
 
 
 
 
 
108
  uint i = 0;
109
  while (i < unrolled_iters) {
110
  // Manually partially unroll the loop
@@ -113,8 +123,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
113
  i++;
114
  }
115
  }
 
116
  unroll_count = 2;
117
  unrolled_iters = num_iters & ~(unroll_count - 1);
 
 
 
 
 
 
 
 
 
118
  while (i < unrolled_iters) {
119
  // Manually partially unroll the loop
120
  [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
 
105
  int unroll_count = 4;
106
  uint unrolled_iters = num_iters & ~(unroll_count - 1);
107
 
108
+ #if K_PER_ITER == 2
109
+ // If the K dimension is odd, we need lastiter==true on the last iteration
110
+ // so OOB is computed correctly. Skip some unrolling to make that happen.
111
+ if ((p.ncols & 1) != 0 &&
112
+ unrolled_iters == num_iters &&
113
+ unrolled_iters > 0) {
114
+ unrolled_iters -= unroll_count;
115
+ }
116
+ #endif
117
+
118
  uint i = 0;
119
  while (i < unrolled_iters) {
120
  // Manually partially unroll the loop
 
123
  i++;
124
  }
125
  }
126
+
127
  unroll_count = 2;
128
  unrolled_iters = num_iters & ~(unroll_count - 1);
129
+
130
+ #if K_PER_ITER == 2
131
+ if ((p.ncols & 1) != 0 &&
132
+ unrolled_iters == num_iters &&
133
+ unrolled_iters > 0) {
134
+ unrolled_iters -= unroll_count;
135
+ }
136
+ #endif
137
+
138
  while (i < unrolled_iters) {
139
  // Manually partially unroll the loop
140
  [[unroll]] for (uint k = 0; k < unroll_count; ++k) {