Commit 18a0ad1 · 1 parent: cb018d4
vulkan: Handle GPUs with less shared memory (llama/10468)

There have been reports of failure to compile on systems with <= 32KB of shared memory (e.g. #10037). This change makes the large tile size fall back to a smaller size if necessary, and makes mul_mat_id fall back to CPU if there's only 16KB of shared memory.
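
In outline: each tile size has a fixed shared-memory footprint, and the loader steps down to a smaller configuration when the device limit is too small. Below is a minimal standalone sketch of that selection logic using the footprint formula from the diff; the `TileConfig` struct and `pick_tile` helper are invented for illustration, the real code mutates the `l_*` warptile vectors in place.

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for one warptile configuration; the real code
// keeps these as entries of std::vector<uint32_t> (see the diff below).
struct TileConfig {
    uint32_t bm, bn, bk;  // rows, cols, depth of the shared-memory tiles
};

// Same footprint formula as the diff:
// (warptile[1] + warptile[2]) * (warptile[3] + 1) * sizeof(float)
// The +1 pads each row by one column.
static uint32_t shmem_bytes(const TileConfig & t) {
    return (t.bm + t.bn) * (t.bk + 1) * sizeof(float);
}

// Step down from large to medium to small until the tile fits the limit.
static TileConfig pick_tile(uint32_t max_shmem) {
    const TileConfig candidates[] = {
        {128, 128, 16},  // large  -> 256 * 17 * 4 = 17408 bytes
        { 64,  64, 32},  // medium -> 128 * 33 * 4 = 16896 bytes
        { 32,  32, 32},  // small  ->  64 * 33 * 4 =  8448 bytes
    };
    for (const TileConfig & t : candidates) {
        if (shmem_bytes(t) <= max_shmem) {
            return t;
        }
    }
    return candidates[2];  // smallest; the real code asserts instead
}

int main() {
    const uint32_t limits[] = { 16384, 32768, 49152 };  // 16KB, 32KB, 48KB
    for (uint32_t limit : limits) {
        TileConfig t = pick_tile(limit);
        std::printf("%5u bytes of shared memory -> %ux%ux%u tile\n",
                    limit, t.bm, t.bn, t.bk);
    }
}
```
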
ggml/src/ggml-vulkan/ggml-vulkan.cpp (+111 -63)
```diff
@@ -1232,8 +1232,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
     std::cerr << "ggml_vulkan: Compiling shaders";
 
     // mulmat
-    std::vector<uint32_t> l_warptile, m_warptile, s_warptile, l_warptile_mmq, m_warptile_mmq, s_warptile_mmq;
-    std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms;
+    std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
+                          l_warptile_mmq, m_warptile_mmq, s_warptile_mmq;
+    std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
+                            l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms;
     uint32_t l_align, m_align, s_align;
 
     l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
@@ -1244,14 +1246,48 @@ static void ggml_vk_load_shaders(vk_device& device) {
     m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
     s_warptile_mmq = { std::max(device->subgroup_size, 16u), 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
 
-    l_wg_denoms = {128, 128, 1 };
-    m_wg_denoms = { 64,  64, 1 };
-    s_wg_denoms = { 32,  32, 1 };
+    l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
+    m_mmq_wg_denoms = m_wg_denoms = { 64,  64, 1 };
+    s_mmq_wg_denoms = s_wg_denoms = { 32,  32, 1 };
 
     l_align = 128;
     m_align = 64;
     s_align = 32;
 
+    // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
+    // and tile sizes, this should handle 16KB, 32KB, and 48KB+.
+    // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
+    // But the numbers happen to work out for 32KB shared memory size that when using the medium
+    // size there's enough room for everything, and we assert for this.
+    uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
+    if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
+        l_warptile = m_warptile;
+        l_wg_denoms = m_wg_denoms;
+        shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
+        GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
+    }
+    if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
+        // assert mul_mat_mat_id shaders will fit.
+        GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
+    }
+
+    shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
+    if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
+        if (device->properties.limits.maxComputeSharedMemorySize == 32768) {
+            l_warptile_mmq = m_warptile_mmq;
+            l_mmq_wg_denoms = m_mmq_wg_denoms;
+        } else {
+            l_warptile_mmq = s_warptile_mmq;
+            l_mmq_wg_denoms = s_mmq_wg_denoms;
+        }
+        shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
+        GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
+    }
+    if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
+        // assert mul_mat_mat_id shaders will fit.
+        GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
+    }
+
     device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
     device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
 
@@ -1299,35 +1335,38 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
         CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
 
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-
-        CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
-        CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
-        CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
-
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+
+        // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
+        if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
+            CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
+            CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
+            CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
+
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+        }
 #undef CREATE_MM
     } else {
         // Create 6 variants, {s,m,l}x{unaligned,aligned}
@@ -1344,35 +1383,38 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
        CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
 
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-
-        CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
-        CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
-        CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
-
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+
+        // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
+        if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
+            CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
+            CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
+            CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
+
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+        }
 #undef CREATE_MM
     }
 
@@ -6541,6 +6583,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
            {
+                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+                if (op->op == GGML_OP_MUL_MAT_ID &&
+                    ggml_vk_get_device(ctx->device)->properties.limits.maxComputeSharedMemorySize < 32768) {
+                    // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
+                    return false;
+                }
                switch (op->src[0]->type) {
                    case GGML_TYPE_F32:
                    case GGML_TYPE_F16:
```
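The magic numbers in the asserts check out by hand: the "12KB row_ids" buffer is 3072 * 4 = 12288 bytes, and the medium tile needs (64 + 64) * (32 + 1) * 4 = 16896 bytes, so 16896 + 12288 = 29184 <= 32768 and a 32KB device can still run the mul_mat_mat_id shaders at the medium size. A compile-time restatement of that arithmetic (this snippet is mine, not part of the commit):

```cpp
#include <cstdint>

// Footprint of the medium warptile from the diff: warptile[1] = 64,
// warptile[2] = 64, warptile[3] = 32, elements of sizeof(float) = 4.
constexpr uint32_t medium_tile_bytes = (64 + 64) * (32 + 1) * 4;  // 16896

// The 12KB row_ids buffer used by the mul_mat_mat_id shaders: 3072 * 4.
constexpr uint32_t row_ids_bytes = 3072 * 4;                      // 12288

// On a 32KB device the medium tile plus row_ids must still fit, which is
// exactly what the "GGML_ASSERT(shmem_needed + 3072*4 <= ...)" encodes.
static_assert(medium_tile_bytes + row_ids_bytes <= 32768,
              "medium tile + row_ids must fit in 32KB of shared memory");
```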
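`device->properties.limits.maxComputeSharedMemorySize` is the standard Vulkan device limit the whole patch keys on. Outside ggml's `vk_device` wrapper it is queried like this (plain Vulkan API; a valid `VkPhysicalDevice` is assumed):

```cpp
#include <vulkan/vulkan.h>
#include <cstdio>

// Query the compute shared-memory limit that the fallback logic keys on.
// `physical_device` is assumed to have been selected during instance setup.
void print_shmem_limit(VkPhysicalDevice physical_device) {
    VkPhysicalDeviceProperties props;
    vkGetPhysicalDeviceProperties(physical_device, &props);
    // Typical values: 16384 on some older or mobile GPUs, 32768 on many
    // desktop GPUs, 49152 or more on recent hardware.
    std::printf("maxComputeSharedMemorySize = %u bytes\n",
                props.limits.maxComputeSharedMemorySize);
}
```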
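Note the coupling: pipeline creation in `ggml_vk_load_shaders` and the `supports_op` check in the last hunk must agree on the same >= 32768 predicate, or the backend could be asked to dispatch a pipeline that was never created. A hypothetical shared helper (not in the commit) makes the invariant explicit:

```cpp
#include <cstdint>

// Both sites must agree on this predicate: the mul_mat_mat_id pipeline
// creation in ggml_vk_load_shaders and the supports_op rejection.
static bool vk_device_supports_mul_mat_id(uint32_t max_compute_shared_memory_size) {
    // Below 32KB there is no room for the 12KB row_ids buffer next to
    // the result tile, so the pipelines are never created.
    return max_compute_shared_memory_size >= 32768;
}
```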
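Returning `false` from `supports_op` is the whole CPU-fallback mechanism: ggml's scheduler places an op on the first backend that accepts it, so a rejected GGML_OP_MUL_MAT_ID lands on the CPU backend. A toy model of that placement (all names invented for the sketch; the real scheduler is considerably more involved):

```cpp
#include <cstdint>
#include <cstdio>

// Toy model of backend placement: ask each backend in priority order and
// place the op on the first one whose supports() accepts it.
enum class Op { MUL_MAT, MUL_MAT_ID };

struct Backend {
    const char * name;
    uint32_t     max_shmem;  // 0 marks the CPU backend (no relevant limit)
    bool supports(Op op) const {
        if (max_shmem == 0) return true;  // CPU accepts everything
        // Mirror of the patch: reject MUL_MAT_ID below 32KB of shared memory.
        if (op == Op::MUL_MAT_ID && max_shmem < 32768) return false;
        return true;
    }
};

static const char * place(Op op, const Backend * backends, int n) {
    for (int i = 0; i < n; i++) {
        if (backends[i].supports(op)) return backends[i].name;
    }
    return "unschedulable";
}

int main() {
    // A 16KB Vulkan device ahead of the CPU backend in priority order.
    Backend backends[] = { {"vulkan", 16384}, {"cpu", 0} };
    std::printf("MUL_MAT    -> %s\n", place(Op::MUL_MAT,    backends, 2));  // vulkan
    std::printf("MUL_MAT_ID -> %s\n", place(Op::MUL_MAT_ID, backends, 2)); // cpu
}
```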