ggml: adds CONV_2D op and direct GEMM Vulkan implementation (llama/14316)
* ggml/ggml-vulkan/test-backend-ops: adds CONV_2D for Vulkan
* ggml-vulkan: adds f32 scalar shader to compute 2D convolution directly with GEMM (no need for im2col)
* test-backend-ops: adds test_case_ref to check the validity/performance of ops against reference implementations having different graphs; adds tests
* Performance fixes: minimized branch divergence, uses collectives to eliminate redundant calculations, macros removed
* Kernel shared memory size check
* Updates test-backend-ops to support graphs for performance measurement
* Apple/Win32 compile errors fixed
* Subgroup size used to determine tile size -> fixes llvmpipe errors
* Collectives disabled by default
* Intel support is disabled as the performance is poor
* Conv2d enabled for Intel with disabled collectives, disabled for Apple
* test-backend-ops modifications are reverted
* Trailing spaces and missing override fixed
* Triggering pipeline relaunch
* Code formatted with .clang-format
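The heart of the change is the direct (implicit GEMM) formulation: the kernel tensor is read as a Cout x (Cin*KH*KW) matrix, the output as Cout x (N*OH*OW), and the (Cin*KH*KW) x (N*OH*OW) input operand is gathered on the fly from the [W, H, Cin, N] tensor instead of being materialized by im2col. A minimal CPU reference that mirrors the shader's index mapping (layouts, names and the helper itself are illustrative, not code from this commit):

#include <cstdint>
#include <vector>

// Direct-GEMM 2D convolution on the CPU, for intuition only.
// knl: kernel [Cout][Cin][KH][KW] -> viewed as M x K with M = Cout, K = Cin*KH*KW
// src: input  [N][Cin][H][W]      -> viewed as K x N' with N' = N*OH*OW, gathered on the fly
// dst: output [N][Cout][OH][OW]   -> M x N'
static void conv2d_gemm_ref(const std::vector<float> & knl, const std::vector<float> & src,
                            std::vector<float> & dst, int Cout, int Cin, int N, int H, int W,
                            int KH, int KW, int OH, int OW,
                            int s0, int s1, int p0, int p1, int d0, int d1) {
    const int CRS = Cin * KH * KW;  // GEMM inner dimension
    const int NPQ = N * OH * OW;    // GEMM column dimension
    for (int m = 0; m < Cout; m++) {
        for (int col = 0; col < NPQ; col++) {
            // Decompose the flat NPQ column index, exactly like the shader does.
            const int n  = col / (OH * OW);
            const int oh = (col % (OH * OW)) / OW;
            const int ow = col % OW;
            float acc = 0.0f;
            for (int crs = 0; crs < CRS; crs++) {
                // Decompose the flat CRS index into (cin, kh, kw).
                const int cin = crs / (KW * KH);
                const int kh  = (crs % (KW * KH)) / KW;
                const int kw  = crs % KW;
                const int h = oh * s1 + kh * d1 - p1;  // input row, may fall into padding
                const int w = ow * s0 + kw * d0 - p0;  // input column
                if (h < 0 || h >= H || w < 0 || w >= W) {
                    continue;                          // zero padding contributes nothing
                }
                acc += knl[((m * Cin + cin) * KH + kh) * KW + kw] *
                       src[((n * Cin + cin) * H + h) * W + w];
            }
            dst[((n * Cout + m) * OH + oh) * OW + ow] = acc;
        }
    }
}

The Vulkan shader added below computes exactly this product, tiled through shared memory and registers.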
ggml-vulkan.cpp
@@ -483,6 +483,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
+    vk_pipeline pipeline_conv2d_f32;
     vk_pipeline pipeline_conv2d_dw_whcn_f32;
     vk_pipeline pipeline_conv2d_dw_cwhn_f32;
@@ -876,6 +877,38 @@ struct vk_op_rwkv_wkv7_push_constants {
     uint32_t H;
 };

+struct vk_op_conv2d_push_constants {
+    uint32_t Cout;
+    uint32_t Cin;
+    uint32_t N;
+
+    uint32_t KW;
+    uint32_t KH;
+    uint32_t W;
+    uint32_t H;
+    uint32_t OW;
+    uint32_t OH;
+
+    uint32_t s0;
+    uint32_t s1;
+    uint32_t p0;
+    uint32_t p1;
+    uint32_t d0;
+    uint32_t d1;
+
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb03;
+
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
+
+    uint32_t nb1;
+    uint32_t nb2;
+    uint32_t nb3;
+};
+
 struct vk_op_conv2d_dw_push_constants {
     uint32_t ne;
     uint32_t batches;
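The struct above carries strides in elements, not bytes. In ggml, ne[i] are extents with ne[0] fastest-moving and nb[i] are byte strides; the host divides nb by the element size before filling the push constants. A quick illustration under the usual contiguous-f32 layout (the constants here are made up):

#include <cstdint>
#include <cstdio>

int main() {
    // Contiguous f32 kernel of shape [KW, KH, Cin, Cout] in ggml's ne-order.
    const uint32_t KW = 3, KH = 3, Cin = 64;
    const uint32_t elt = sizeof(float);
    // Byte strides as ggml would compute them:
    uint32_t nb[4] = { elt, elt * KW, elt * KW * KH, elt * KW * KH * Cin };
    // Element strides as sent to the shader (nb01, nb02, nb03 in the struct):
    printf("nb01=%u nb02=%u nb03=%u\n", nb[1] / nb[0], nb[2] / nb[0], nb[3] / nb[0]);
    // -> nb01=3 nb02=9 nb03=576 for a 3x3 kernel with 64 input channels
}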
@@ -975,18 +1008,45 @@
 #endif // GGML_VULKAN_MEMORY_DEBUG

 class vk_perf_logger {
+  public:
     void print_timings() {
+        if (timings.empty()) {
+            return;
+        }
+        uint64_t total_all_op_times = 0;
         std::cerr << "----------------\nVulkan Timings:" << std::endl;
+        for (const auto & t : timings) {
+            uint64_t total_op_times = 0;
+            for (const auto & time : t.second) {
+                total_op_times += time;
+            }
+            std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0)
+                      << " us";
+
+            // If we have as many flops entries as timing entries for the op, then compute and log the flops/S.
+            auto it = flops.find(t.first);
+            if (it != flops.end() && (it->second).size() == t.second.size()) {
+                uint64_t total_op_flops = 0;
+                for (const auto & elem : it->second) {
+                    total_op_flops += elem;
+                }
+                std::cerr << " ("
+                          << (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) /
+                                 (double(total_op_times) / (1000.0 * 1000.0 * 1000.0))
+                          << " GFLOPS/s)";
             }
+
+            total_all_op_times += total_op_times;
+
+            std::cerr << std::endl;
+        }
+
+        if (timings.size() > 0) {
+            std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl;
         }

         timings.clear();
+        flops.clear();
     }

     void log_timing(const ggml_tensor * node, uint64_t time) {

@@ -995,22 +1055,45 @@
             return;
         }
         if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
+            const uint64_t m = node->src[0]->ne[1];
+            const uint64_t n = node->src[1]->ne[1];
+            const uint64_t k = node->src[1]->ne[0];
+            std::string name = ggml_op_name(node->op);
             if (n == 1) {
                 name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k);
             } else {
                 name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
             }
             timings[name].push_back(time);
+            flops[name].push_back(m * n * (k + (k - 1)));
+            return;
+        }
+        if (node->op == GGML_OP_CONV_2D) {
+            std::string name = ggml_op_name(node->op);
+            ggml_tensor * knl = node->src[0];
+            uint64_t OW = node->ne[0];
+            uint64_t OH = node->ne[1];
+            uint64_t N = node->ne[3];
+            uint64_t Cout = node->ne[2];
+            uint64_t KW = knl->ne[0];
+            uint64_t KH = knl->ne[1];
+            uint64_t Cin = knl->ne[2];
+            // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
+            uint64_t size_M = Cout;
+            uint64_t size_K = Cin * KW * KH;
+            uint64_t size_N = N * OW * OH;
+            uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1));
+            name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
+                    ", N=N*OW*OH=" + std::to_string(size_N);
+            flops[name].push_back(n_flops);
+            timings[name].push_back(time);
             return;
         }
         timings[ggml_op_name(node->op)].push_back(time);
     }
+  private:
     std::map<std::string, std::vector<uint64_t>> timings;
+    std::map<std::string, std::vector<uint64_t>> flops;
 };

 struct ggml_backend_vk_context {
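The logger prices a length-k dot product at k multiplies plus k-1 adds, hence the m * n * (k + (k - 1)) term above. A standalone illustration of turning the accumulated numbers into GFLOPS/s (every figure here is invented):

#include <cstdint>
#include <cstdio>

// FLOP model used by the logger: a length-K dot product costs K multiplies and
// K-1 additions, so an M x N matmul over inner dimension K performs M*N*(2K-1) ops.
int main() {
    const uint64_t M = 320, N = 16 * 64 * 64, K = 320 * 3 * 3;  // a made-up CONV_2D-as-GEMM shape
    const uint64_t n_flops = M * N * (K + (K - 1));
    const uint64_t nanosec = 25'000'000;  // pretend the op took 25 ms in total
    const double gflops_s = (double(n_flops) / 1e9) / (double(nanosec) / 1e9);
    printf("%.1f GFLOPS/s\n", gflops_s);
}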
@@ -2113,6 +2196,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         }
         compile_count++;
     }
+
     compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
                                   parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
 };
@@ -2962,6 +3046,42 @@ static void ggml_vk_load_shaders(vk_device& device) {

     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

+    // conv2d
+    uint32_t conv2d_WG_SIZE = 256;
+    uint32_t conv2d_BS_K = 128;
+    uint32_t conv2d_BS_CRS = 16;
+    uint32_t use_collectives = 0;  // Enables subgroup ops for preventing the re-calculation of indices.
+    if (device->subgroup_shuffle &&
+        device->vendor_id != VK_VENDOR_ID_INTEL) {  // Do not enable collectives on Intel, see PR 14316
+        use_collectives = 1;
+        conv2d_BS_CRS = std::min(
+            device->subgroup_size,
+            conv2d_BS_CRS);  // CRS block size should be capped at subgroup size for correctness when shuffle is used.
+    }
+    uint32_t conv2d_BS_NPQ = 128;
+    uint32_t conv2d_TS_K = 8;
+    uint32_t conv2d_shmem_req =
+        (conv2d_BS_K * (conv2d_BS_CRS + 1) + conv2d_BS_CRS * (conv2d_BS_NPQ + 1)) * sizeof(float);
+    if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
+        conv2d_BS_CRS = 8;
+        if (use_collectives) {
+            conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
+        }
+    }
+
+    if (use_collectives) {
+        ggml_vk_create_pipeline(
+            device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
+            sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
+            { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
+    } else {
+        ggml_vk_create_pipeline(
+            device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
+            sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
+            { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
+            false);
+    }
+
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
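The fallback above exists because the two shared-memory tiles must fit in maxComputeSharedMemorySize: the A tile holds BS_K x BS_CRS floats and the B tile BS_CRS x BS_NPQ, each budgeted with one extra column of padding headroom. A host-side rehearsal of that sizing logic (the device limits below are made up):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Shrink the CRS block until both padded shared-memory tiles fit the limit,
// mirroring the pipeline-creation logic above. Values are illustrative.
int main() {
    uint32_t BS_K = 128, BS_CRS = 16, BS_NPQ = 128;
    const uint32_t subgroup_size   = 32;          // assumed device property
    const uint32_t max_shmem       = 16 * 1024;   // a deliberately small limit
    const bool     use_collectives = true;

    uint32_t req = (BS_K * (BS_CRS + 1) + BS_CRS * (BS_NPQ + 1)) * sizeof(float);
    printf("need %u bytes, have %u\n", req, max_shmem);
    if (max_shmem < req) {
        BS_CRS = 8;                               // fall back to a narrower CRS block
        if (use_collectives) {
            BS_CRS = std::min(subgroup_size, BS_CRS);  // shuffle lanes must cover the block
        }
        req = (BS_K * (BS_CRS + 1) + BS_CRS * (BS_NPQ + 1)) * sizeof(float);
        printf("fallback BS_CRS=%u -> %u bytes\n", BS_CRS, req);
    }
}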
@@ -6837,6 +6957,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_leaky_relu_f32;
         }
         return nullptr;
+    case GGML_OP_CONV_2D:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+            ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+            return ctx->device->pipeline_conv2d_f32;
+        }
+        return nullptr;
     case GGML_OP_CONV_2D_DW:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             if (ggml_is_contiguous(src1)) {
@@ -7159,6 +7285,31 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         const uint32_t OW = dst->ne[0];
         elements = { N * OC * OH * OW, 1, 1};
     } break;
+    case GGML_OP_CONV_2D:
+        {
+            // src0 - kernel: [KW, KH, Cin, Cout]
+            // src1 - input: [W, H, Cin, N]
+            // dst - result: [OW, OH, Cout, N]
+
+            // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
+            auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+                return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+            };
+            // parallelize in {OW/BS_K, OH/BS_NPQ, 1}
+            int64_t W    = src1->ne[0];
+            int64_t H    = src1->ne[1];
+            int64_t KW   = src0->ne[0];
+            int64_t KH   = src0->ne[1];
+            int64_t Cout = src0->ne[3];
+            int64_t N    = src1->ne[3];
+            int64_t OH   = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
+            int64_t OW   = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
+            int64_t NPQ  = N * OW * OH;
+
+            // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
+            elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
+        }
+        break;
     case GGML_OP_ADD:
     case GGML_OP_SUB:
     case GGML_OP_DIV:
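The lambda reuses ggml's closed-form output size. Two worked instances of the same formula (parameters chosen purely for illustration):

#include <cstdint>
#include <cstdio>

// Output extent of a convolution given input size, kernel size, stride,
// padding and dilation; same expression the hunk copies from ggml.c.
static int64_t calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
    return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
}

int main() {
    // 64-wide input, 3-wide kernel, stride 1, padding 1, dilation 1 -> 64 ("same" conv)
    printf("OW=%lld\n", (long long) calc_conv_output_size(64, 3, 1, 1, 1));
    // stride 2 halves it: 64 -> 32
    printf("OW=%lld\n", (long long) calc_conv_output_size(64, 3, 2, 1, 1));
}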
@@ -8025,6 +8176,55 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
     }, dryrun);
 }

+static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
+                            const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    vk_op_conv2d_push_constants p{};
+    p.Cout = static_cast<uint32_t>(ne03);
+    p.Cin  = static_cast<uint32_t>(ne02);
+    p.N    = static_cast<uint32_t>(ne13);
+
+    p.KW = static_cast<uint32_t>(ne00);
+    p.KH = static_cast<uint32_t>(ne01);
+    p.W  = static_cast<uint32_t>(ne10);
+    p.H  = static_cast<uint32_t>(ne11);
+    p.OW = static_cast<uint32_t>(ne0);
+    p.OH = static_cast<uint32_t>(ne1);
+
+    p.s0 = static_cast<uint32_t>(dst->op_params[0]);
+    p.s1 = static_cast<uint32_t>(dst->op_params[1]);
+    p.p0 = static_cast<uint32_t>(dst->op_params[2]);
+    p.p1 = static_cast<uint32_t>(dst->op_params[3]);
+    p.d0 = static_cast<uint32_t>(dst->op_params[4]);
+    p.d1 = static_cast<uint32_t>(dst->op_params[5]);
+
+    p.nb01 = static_cast<uint32_t>(nb01 / nb00);
+    p.nb02 = static_cast<uint32_t>(nb02 / nb00);
+    p.nb03 = static_cast<uint32_t>(nb03 / nb00);
+
+    p.nb11 = static_cast<uint32_t>(nb11 / nb10);
+    p.nb12 = static_cast<uint32_t>(nb12 / nb10);
+    p.nb13 = static_cast<uint32_t>(nb13 / nb10);
+
+    p.nb1 = static_cast<uint32_t>(nb1 / nb0);
+    p.nb2 = static_cast<uint32_t>(nb2 / nb0);
+    p.nb3 = static_cast<uint32_t>(nb3 / nb0);
+
+    GGML_ASSERT(ne03 == ne2);
+    GGML_ASSERT(ne02 == ne12);
+
+    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
+}
+
 static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     vk_op_conv2d_dw_push_constants p{};
     p.ne = ggml_nelements(dst);
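The grid handed to the dispatch via elements is one workgroup per output blocktile: Cout is split into BS_K-wide rows and N*OW*OH into BS_NPQ-wide columns by ceil division, matching the shader's splitWork(). A sketch of the resulting dispatch shape (tensor sizes invented):

#include <cstdint>
#include <cstdio>

// Ceil division, as used to tile the Cout x NPQ output matrix into workgroups.
static uint32_t split_work(uint32_t work, uint32_t block) {
    return (work + block - 1) / block;
}

int main() {
    const uint32_t Cout = 320, N = 16, OH = 64, OW = 64;
    const uint32_t BS_K = 128, BS_NPQ = 128;
    const uint32_t NPQ = N * OH * OW;
    printf("grid = %u x %u workgroups\n", split_work(Cout, BS_K), split_work(NPQ, BS_NPQ));
    // -> 3 x 512 workgroups, each computing one 128x128 tile of the output matrix
}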
@@ -9087,6 +9287,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_CONV_TRANSPOSE_1D:
         case GGML_OP_POOL_2D:
+        case GGML_OP_CONV_2D:
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
@@ -9154,6 +9355,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_CONV_TRANSPOSE_1D:
         case GGML_OP_POOL_2D:
+        case GGML_OP_CONV_2D:
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_LEAKY_RELU:
             {
@@ -9360,6 +9562,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_POOL_2D:
         ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);

+        break;
+    case GGML_OP_CONV_2D:
+        ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_CONV_2D_DW:
         ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9490,6 +9696,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_CONV_TRANSPOSE_1D:
         case GGML_OP_POOL_2D:
+        case GGML_OP_CONV_2D:
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
@@ -10071,6 +10278,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
         if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
             total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
+        } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) {
+            // Return CRSxNPQxsizeof(*) to account as many bytes as mul_mat has in im2col->mul_mat mode.
+            auto CRS_size =
+                cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2];
+            auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3];
+            total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type);
         }
         i += ctx->num_additional_fused_ops;
         ctx->num_additional_fused_ops = 0;
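For the scheduling heuristic, CONV_2D is charged the bytes an im2col buffer would have occupied, CRS x NPQ elements, so it is weighted like the equivalent im2col->mul_mat pair. For example (sizes invented):

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t KW = 3, KH = 3, Cin = 256;  // kernel (src0) dims, illustrative
    const int64_t OW = 64, OH = 64, N = 1;    // output (dst) dims
    const int64_t CRS = KW * KH * Cin;        // 2304
    const int64_t NPQ = OW * OH * N;          // 4096
    // Roughly 36 MiB that the im2col path would have materialized:
    printf("%lld bytes\n", (long long)(CRS * NPQ * (int64_t) sizeof(float)));
}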
@@ -10647,6 +10860,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return true;
         case GGML_OP_CONV_TRANSPOSE_1D:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
+        case GGML_OP_CONV_2D:
+            {
+                // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
+                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+                const vk_device& device = ggml_vk_get_device(ctx->device);
+                bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE;
+                // Channel-contiguous format is not supported yet.
+                return (op->src[0]->type == GGML_TYPE_F32 &&
+                        op->src[1]->type == GGML_TYPE_F32 &&
+                        op->type == GGML_TYPE_F32 &&
+                        ggml_is_contiguous(op->src[0]) &&
+                        ggml_is_contiguous(op->src[1]) &&
+                        ggml_is_contiguous(op)) && !is_Apple;
+            }
         default:
             return false;
     }
@@ -11205,6 +11432,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         const int32_t p1 = tensor->op_params[6];

         tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
+    } else if (tensor->op == GGML_OP_CONV_2D) {
+        const int32_t s0 = tensor->op_params[0];
+        const int32_t s1 = tensor->op_params[1];
+        const int32_t p0 = tensor->op_params[2];
+        const int32_t p1 = tensor->op_params[3];
+        const int32_t d0 = tensor->op_params[4];
+        const int32_t d1 = tensor->op_params[5];
+        tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
     } else if (tensor->op == GGML_OP_LEAKY_RELU) {
         const float * op_params = (const float *)tensor->op_params;
         tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
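What follows is the new compute shader. Its workgroup layout is easiest to see numerically: with the default specialization constants, each thread owns a TS_K x TS_NPQ register tile, and the whole workgroup exactly covers one BS_K x BS_NPQ blocktile. A host-side check of that arithmetic (defaults copied from the shader; the check itself is only illustrative):

#include <cstdint>
#include <cstdio>

// Tiling sanity check for the shader's defaults: a 256-thread workgroup covers
// a BS_K x BS_NPQ = 128x128 blocktile with per-thread TS_K x TS_NPQ register tiles.
int main() {
    const uint32_t WG_SIZE = 256, BS_K = 128, BS_NPQ = 128, TS_K = 8;
    const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;  // = 8
    const uint32_t NT_K   = BS_K / TS_K;                     // thread rows = 16
    const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;                 // thread cols = 16
    printf("TS_NPQ=%u, %ux%u threads, %u results/thread\n",
           TS_NPQ, NT_K, NT_NPQ, TS_K * TS_NPQ);
    // 16*16 = 256 threads each producing an 8x8 tile -> the full 128x128 blocktile
}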
vulkan-shaders/conv2d_mm.comp (new file)

@@ -0,0 +1,265 @@
#version 450

#ifdef USE_COLLECTIVES
# extension GL_KHR_shader_subgroup_shuffle : enable
#endif

#include "types.comp"

// Make spec constant
#define SHMEM_PAD 0

// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
layout(binding = 0) readonly buffer A {
    A_TYPE knl_data[];
};  // src0 - kernel: [KW, KH, Cin, Cout]

layout(binding = 1) readonly buffer B {
    B_TYPE src_data[];
};  // src1 - input: [W, H, Cin, N] -- channel_first format

layout(binding = 2) writeonly buffer D {
    D_TYPE dst_data[];
};  // dst - result: [OW, OH, Cout, N]

layout(push_constant) uniform parameter {
    // I/O channels, batch size
    uint32_t Cout;
    uint32_t Cin;
    uint32_t N;

    // Tensor spatial sizes: kernel, input, output
    uint32_t KW;
    uint32_t KH;
    uint32_t W;
    uint32_t H;
    uint32_t OW;
    uint32_t OH;

    // Parameters: stride, padding, dilation - 0=y, 1=x
    uint32_t s0;
    uint32_t s1;
    uint32_t p0;
    uint32_t p1;
    uint32_t d0;
    uint32_t d1;

    // Strides in elements
    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb03;

    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;

    uint32_t nb1;
    uint32_t nb2;
    uint32_t nb3;
}

p;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Blocktile sizes
layout(constant_id = 1) const uint BS_K = 128;
layout(constant_id = 2) const uint BS_CRS = 16;
layout(constant_id = 3) const uint BS_NPQ = 128;
// Thread-tile sizes
layout(constant_id = 4) const uint TS_K = 8;
layout(constant_id = 5) const uint use_collectives = 1;

uint32_t tid = gl_LocalInvocationID.x;
const uint32_t WG_SIZE = gl_WorkGroupSize.x;

uint splitWork(uint work_size, uint block_size) {
    return (block_size + work_size - 1) / block_size;
}

uint32_t K = p.Cout;
uint32_t CRS = p.Cin * p.KH * p.KW;
uint32_t NPQ = p.N * p.OH * p.OW;

uint32_t n_elems_out = K * NPQ;

// Number of blocktiles per input
uint32_t NB_CRS = splitWork(CRS, BS_CRS);

const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;

const uint32_t Ash_numel = BS_K * BS_CRS;
const uint32_t Bsh_numel = BS_CRS * BS_NPQ;

const uint32_t Ash_len = BS_K * Ash_stride;
const uint32_t Bsh_len = BS_CRS * Bsh_stride;

shared float Ash[Ash_len];  // K x CRS
shared float Bsh[Bsh_len];  // CRS x NPQ

// Threadtile sizes
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;

// Number of threadtiles per blocktile
const uint32_t NT_K = BS_K / TS_K;
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;

float regA[TS_K];
float regB[TS_NPQ];
float regC[TS_K][TS_NPQ];

/*
Compute
KxCRS @ CRSxNPQ = K x NPQ
K=Cout
C=Cin
R,S=KH,KW
P,Q=OH,OW
*/

uint32_t B_idx_K = gl_WorkGroupID.x;
uint32_t B_idx_NPQ = gl_WorkGroupID.y;

uint32_t T_y = tid / NT_NPQ;
uint32_t T_x = tid % NT_NPQ;

uint32_t Ar = tid / BS_CRS;
uint32_t Ac = tid % BS_CRS;
const uint32_t ArpWg = WG_SIZE / BS_CRS;

uint32_t Br = tid / BS_NPQ;
uint32_t Bc = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;

void main() {
    for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
        for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
            regC[T_ly][T_lx] = 0.0;
        }
    }
    /* Advance block in CRS dim */
    for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
        uint32_t CRS_idx_a;
        uint32_t Cin_idx_a;
        uint32_t KH_idx_a;
        uint32_t KW_idx_a;

#ifdef USE_COLLECTIVES
        uint32_t cached_CRS_idx;
        uint32_t cached_Cin_idx;
        uint32_t cached_KH_idx;
        uint32_t cached_KW_idx;
        if (use_collectives == 1) {
            cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
            cached_Cin_idx = cached_CRS_idx / (p.KW * p.KH);
            uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
            cached_KH_idx = cached_CRS_remainder / p.KW;
            cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;

            CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
            Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
            KH_idx_a = subgroupShuffle(cached_KH_idx, Ac);
            KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
        } else {
            CRS_idx_a = B_idx_CRS * BS_CRS + Ac;  // Global CRS_idx_a (column index of A)
            Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
            uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
            KH_idx_a = CRS_remainder / p.KW;
            KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
        }
#else
        CRS_idx_a = B_idx_CRS * BS_CRS + Ac;  // Global CRS_idx_a (column index of A)
        Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
        CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
        KH_idx_a = CRS_remainder / p.KW;
        KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
#endif

        /* Load kernel to A_block: (BS_K x BS_CRS)*/
        for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
            uint32_t B_ly = r_offset + Ar;
            uint32_t B_lx = Ac;
            uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
            uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
            float val = knl_data[knl_idx];
            if (K_idx >= K || CRS_idx_a >= CRS) {
                val = 0.0;
            }
            Ash[B_ly * Ash_stride + B_lx] = val;
        }
        /* Load input to B_block: (BS_CRS x BS_NPQ) */
        for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
            uint32_t B_ly = r_offset + Br; /* Row index of B block */
            uint32_t B_lx = Bc;
            uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
            uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
            uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
            uint32_t OH_idx = NPQ_remainder / p.OW;
            uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW;

            uint32_t CRS_idx_b;
            uint32_t Cin_idx_b;
            uint32_t KH_idx_b;
            uint32_t KW_idx_b;
#ifdef USE_COLLECTIVES
            if (use_collectives == 1) {
                CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
                Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
                KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br);
                KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
            } else {
                CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
                Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
                uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
                KH_idx_b = CRS_remainder / p.KW;
                KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
            }
#else
            CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
            Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
            uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
            KH_idx_b = CRS_remainder / p.KW;
            KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
#endif

            uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
            uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
            uint32_t src_idx =
                min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
            float val = src_data[src_idx];
            if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) {
                val = 0.0;
            }
            Bsh[B_ly * Bsh_stride + B_lx] = val;
        }
        barrier();
        for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
            for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
                regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
            }
            for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
                regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
            }
            for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
                for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
                    regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
                }
            }
        }
        barrier();
    }
    /* Save C* */
    for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
        for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
            uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
            uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
            uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
            uint32_t OH_idx = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW;
            uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
            uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
            if (K_idx < K && NPQ_idx < NPQ) {
                dst_data[dst_idx] = regC[T_ly][T_lx];
            }
        }
    }
}
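The USE_COLLECTIVES path above spends one div/mod chain per subgroup lane and then broadcasts the cached (Cin, KH, KW) decompositions with subgroupShuffle, instead of recomputing them in every tile-load iteration. A CPU analogy in which a plain array stands in for the subgroup (hypothetical code, not from the patch):

#include <cstdint>
#include <cstdio>

// CPU simulation of the collectives trick: every lane decomposes one flat CRS
// index into (cin, kh, kw) once; subgroupShuffle(x, lane) then lets any thread
// read that lane's cached result instead of repeating the divisions.
struct CRSIdx { uint32_t cin, kh, kw; };

int main() {
    const uint32_t KW = 3, KH = 3, SUBGROUP = 32, block_base = 64;
    CRSIdx cache[SUBGROUP];
    for (uint32_t lane = 0; lane < SUBGROUP; lane++) {  // each lane: one div/mod chain
        const uint32_t crs = block_base + lane;
        cache[lane].cin = crs / (KW * KH);
        const uint32_t rem = crs - cache[lane].cin * KW * KH;
        cache[lane].kh = rem / KW;
        cache[lane].kw = rem - cache[lane].kh * KW;
    }
    // "subgroupShuffle(cached_KH_idx, Ac)" becomes a plain lookup:
    const uint32_t Ac = 5;
    printf("lane %u -> cin=%u kh=%u kw=%u\n", Ac, cache[Ac].cin, cache[Ac].kh, cache[Ac].kw);
}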
vulkan-shaders-gen.cpp

@@ -655,6 +655,8 @@ void process_shaders() {
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

+    string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
+
     string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
     string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));