Spaces:
Sleeping
vulkan : mul_mat: fix UB with small warps (ggml/952)
Browse filesWhen the device's warp size is less than 16,
it is possible for loadstride_a (mul_mm.comp:114)
and loadstride_b (mul_mm.comp:115) to be set to 0.
Because they are calculated as: the workgroup size,
multiplied by LOAD_VEC_* (which can be 1) and divided by 16.
And the workgroup size is set to be the same as the
warp/subgroup size.
The loadstride_* variables are used as increments in the
loops that populate the buffers used for the multiplication.
When they are 0 they cause an infinite loop.
But infinite loops without side-effects are UB and the
values of loadstride_* are known at compile time.
So, the compiler quietly optimizes all the loops away.
As a consequence, the buffers are not populated and
the multiplication result is just a matrix with all elements
set to 0.
We prevent the UB by making sure that the workgroup size
will never be less than 16, even if our device has a
smaller warp size (e.g. 8).
Signed-off-by: Salvatore Mesoraca <[email protected]>
- ggml/src/ggml-vulkan.cpp +2 -2
|
@@ -1164,11 +1164,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
| 1164 |
// mulmat
|
| 1165 |
std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
|
| 1166 |
std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
|
| 1167 |
-
std::initializer_list<uint32_t> warptile_s = { device->subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
|
| 1168 |
|
| 1169 |
std::initializer_list<uint32_t> warptile_mmq_l = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
|
| 1170 |
std::initializer_list<uint32_t> warptile_mmq_m = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
|
| 1171 |
-
std::initializer_list<uint32_t> warptile_mmq_s = { device->subgroup_size, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
|
| 1172 |
|
| 1173 |
std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
|
| 1174 |
std::array<uint32_t, 3> m_wg_denoms = { 64, 64, 1 };
|
|
|
|
| 1164 |
// mulmat
|
| 1165 |
std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
|
| 1166 |
std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
|
| 1167 |
+
std::initializer_list<uint32_t> warptile_s = { std::max(device->subgroup_size, 16u), 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
|
| 1168 |
|
| 1169 |
std::initializer_list<uint32_t> warptile_mmq_l = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
|
| 1170 |
std::initializer_list<uint32_t> warptile_mmq_m = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
|
| 1171 |
+
std::initializer_list<uint32_t> warptile_mmq_s = { std::max(device->subgroup_size, 16u), 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
|
| 1172 |
|
| 1173 |
std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
|
| 1174 |
std::array<uint32_t, 3> m_wg_denoms = { 64, 64, 1 };
|