etasnadi committed
Commit 5885084 · Parent: 0855a18

ggml: adds CONV_2D op and direct GEMM Vulkan implementation (llama/14316)


* ggml/ggml-vulkan/test-backend-ops: adds CONV_2D for Vulkan

* ggml-vulkan: adds an f32 scalar shader that computes the 2D convolution directly with GEMM (no need for im2col); see the sketch after this list

* test-backend-ops: adds test_case_ref to check the validity and performance of ops
against reference implementations that use different graphs, adds tests

* Performance fixes: minimized branch divergence, uses collectives to
eliminate redundant calculations, macros removed.

* Kernel shared memory size check

* Updates test-backend-ops to support graphs for performance
measurement.

* Apple/Win32 compile errors fixed

* Subgroup size used to determine tile size -> fixes llvmpipe errors.

* Collectives disabled by default.

* Intel support is disabled as the performance is poor.

* Conv2d enabled for Intel with collectives disabled; disabled for Apple

* test-backend-ops modifications are reverted

* Trailing spaces and missing override fixed.

* Triggering pipeline relaunch.

* Code formatted with .clang-format.
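
As a point of reference, the direct approach treats the convolution as a single implicit GEMM: the kernel is read as a Cout x (Cin*KH*KW) matrix and the input is gathered on the fly into a (Cin*KH*KW) x (N*OH*OW) matrix, so no im2col buffer is ever materialized. The sketch below is an illustrative CPU reference of that index mapping only; the helper names (Conv2dParams, conv2d_direct_gemm) are hypothetical and contiguous f32 tensors in ggml's WHCN layout are assumed. It is not part of the patch.

    // conv2d_ref.cpp -- mirrors the index mapping used by conv2d_mm.comp.
    //   kernel: [KW, KH, Cin, Cout], input: [W, H, Cin, N], output: [OW, OH, Cout, N]
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct Conv2dParams {
        int64_t Cout, Cin, N;
        int64_t KW, KH, W, H, OW, OH;
        int64_t s0, s1, p0, p1, d0, d1;   // stride, padding, dilation (0 along W, 1 along H)
    };

    // Direct GEMM: (Cout x CRS) @ (CRS x NPQ), with the B matrix gathered on the fly.
    static void conv2d_direct_gemm(const Conv2dParams & p,
                                   const std::vector<float> & knl,   // KW*KH*Cin*Cout
                                   const std::vector<float> & src,   // W*H*Cin*N
                                   std::vector<float> & dst) {       // OW*OH*Cout*N
        const int64_t CRS = p.Cin * p.KH * p.KW;   // GEMM K dimension
        const int64_t NPQ = p.N * p.OH * p.OW;     // GEMM N dimension
        for (int64_t k = 0; k < p.Cout; ++k) {         // GEMM M
            for (int64_t npq = 0; npq < NPQ; ++npq) {  // GEMM N
                const int64_t n  = npq / (p.OH * p.OW);
                const int64_t oh = (npq - n * p.OH * p.OW) / p.OW;
                const int64_t ow = npq - n * p.OH * p.OW - oh * p.OW;
                float acc = 0.0f;
                for (int64_t crs = 0; crs < CRS; ++crs) {  // GEMM K, decoded into (cin, kh, kw)
                    const int64_t cin = crs / (p.KW * p.KH);
                    const int64_t kh  = (crs - cin * p.KW * p.KH) / p.KW;
                    const int64_t kw  = crs - cin * p.KW * p.KH - kh * p.KW;
                    const int64_t ih  = oh * p.s1 + kh * p.d1 - p.p1;
                    const int64_t iw  = ow * p.s0 + kw * p.d0 - p.p0;
                    if (ih < 0 || ih >= p.H || iw < 0 || iw >= p.W) {
                        continue;  // implicit zero padding
                    }
                    acc += knl[kw + kh * p.KW + cin * p.KW * p.KH + k * p.KW * p.KH * p.Cin] *
                           src[iw + ih * p.W + cin * p.W * p.H + n * p.W * p.H * p.Cin];
                }
                dst[ow + oh * p.OW + k * p.OW * p.OH + n * p.OW * p.OH * p.Cout] = acc;
            }
        }
    }

    int main() {
        // Tiny smoke test: 3x3 kernel of ones over a 5x5 input of ones -> every output is 9.
        const Conv2dParams p = { 1, 1, 1,  3, 3, 5, 5, 3, 3,  1, 1, 0, 0, 1, 1 };
        std::vector<float> knl(3 * 3, 1.0f), src(5 * 5, 1.0f), dst(3 * 3, 0.0f);
        conv2d_direct_gemm(p, knl, src, dst);
        printf("dst[0] = %.1f\n", dst[0]);
        return 0;
    }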

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -483,6 +483,7 @@ struct vk_device_struct {
483
  vk_pipeline pipeline_rwkv_wkv6_f32;
484
  vk_pipeline pipeline_rwkv_wkv7_f32;
485
  vk_pipeline pipeline_opt_step_adamw_f32;
 
486
  vk_pipeline pipeline_conv2d_dw_whcn_f32;
487
  vk_pipeline pipeline_conv2d_dw_cwhn_f32;
488
 
@@ -876,6 +877,38 @@ struct vk_op_rwkv_wkv7_push_constants {
876
  uint32_t H;
877
  };
878
879
  struct vk_op_conv2d_dw_push_constants {
880
  uint32_t ne;
881
  uint32_t batches;
@@ -975,18 +1008,45 @@ private:
975
  #endif // GGML_VULKAN_MEMORY_DEBUG
976
 
977
  class vk_perf_logger {
978
- public:
979
  void print_timings() {
980
  std::cerr << "----------------\nVulkan Timings:" << std::endl;
981
- for (const auto& t : timings) {
982
- uint64_t total = 0;
983
- for (const auto& time : t.second) {
984
- total += time;
985
  }
986
- std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " us" << std::endl;
987
  }
988
 
989
  timings.clear();
 
990
  }
991
 
992
  void log_timing(const ggml_tensor * node, uint64_t time) {
@@ -995,22 +1055,45 @@ public:
995
  return;
996
  }
997
  if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
998
- const uint64_t m = node->src[0]->ne[1];
999
- const uint64_t n = node->src[1]->ne[1];
1000
- const uint64_t k = node->src[1]->ne[0];
1001
- std::string name = ggml_op_name(node->op);
1002
  if (n == 1) {
1003
  name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k);
1004
  } else {
1005
  name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
1006
  }
1007
  timings[name].push_back(time);
1008
  return;
1009
  }
1010
  timings[ggml_op_name(node->op)].push_back(time);
1011
  }
1012
- private:
1013
  std::map<std::string, std::vector<uint64_t>> timings;
 
1014
  };
1015
 
1016
  struct ggml_backend_vk_context {
@@ -2113,6 +2196,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2113
  }
2114
  compile_count++;
2115
  }
 
2116
  compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
2117
  parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
2118
  };
@@ -2962,6 +3046,42 @@ static void ggml_vk_load_shaders(vk_device& device) {
2962
 
2963
  ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2964
2965
  ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
2966
  ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
2967
 
@@ -6837,6 +6957,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6837
  return ctx->device->pipeline_leaky_relu_f32;
6838
  }
6839
  return nullptr;
6840
  case GGML_OP_CONV_2D_DW:
6841
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6842
  if (ggml_is_contiguous(src1)) {
@@ -7159,6 +7285,31 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
7159
  const uint32_t OW = dst->ne[0];
7160
  elements = { N * OC * OH * OW, 1, 1};
7161
  } break;
7162
  case GGML_OP_ADD:
7163
  case GGML_OP_SUB:
7164
  case GGML_OP_DIV:
@@ -8025,6 +8176,55 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
8025
  }, dryrun);
8026
  }
8027
8028
  static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
8029
  vk_op_conv2d_dw_push_constants p{};
8030
  p.ne = ggml_nelements(dst);
@@ -9087,6 +9287,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9087
  case GGML_OP_TIMESTEP_EMBEDDING:
9088
  case GGML_OP_CONV_TRANSPOSE_1D:
9089
  case GGML_OP_POOL_2D:
 
9090
  case GGML_OP_CONV_2D_DW:
9091
  case GGML_OP_RWKV_WKV6:
9092
  case GGML_OP_RWKV_WKV7:
@@ -9154,6 +9355,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9154
  case GGML_OP_TIMESTEP_EMBEDDING:
9155
  case GGML_OP_CONV_TRANSPOSE_1D:
9156
  case GGML_OP_POOL_2D:
 
9157
  case GGML_OP_CONV_2D_DW:
9158
  case GGML_OP_LEAKY_RELU:
9159
  {
@@ -9360,6 +9562,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9360
  case GGML_OP_POOL_2D:
9361
  ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
9362
9363
  break;
9364
  case GGML_OP_CONV_2D_DW:
9365
  ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9490,6 +9696,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
9490
  case GGML_OP_TIMESTEP_EMBEDDING:
9491
  case GGML_OP_CONV_TRANSPOSE_1D:
9492
  case GGML_OP_POOL_2D:
 
9493
  case GGML_OP_CONV_2D_DW:
9494
  case GGML_OP_RWKV_WKV6:
9495
  case GGML_OP_RWKV_WKV7:
@@ -10071,6 +10278,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
10071
  ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
10072
  if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
10073
  total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
10074
  }
10075
  i += ctx->num_additional_fused_ops;
10076
  ctx->num_additional_fused_ops = 0;
@@ -10647,6 +10860,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10647
  return true;
10648
  case GGML_OP_CONV_TRANSPOSE_1D:
10649
  return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10650
  default:
10651
  return false;
10652
  }
@@ -11205,6 +11432,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
11205
  const int32_t p1 = tensor->op_params[6];
11206
 
11207
  tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
11208
  } else if (tensor->op == GGML_OP_LEAKY_RELU) {
11209
  const float * op_params = (const float *)tensor->op_params;
11210
  tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
 
483
  vk_pipeline pipeline_rwkv_wkv6_f32;
484
  vk_pipeline pipeline_rwkv_wkv7_f32;
485
  vk_pipeline pipeline_opt_step_adamw_f32;
486
+ vk_pipeline pipeline_conv2d_f32;
487
  vk_pipeline pipeline_conv2d_dw_whcn_f32;
488
  vk_pipeline pipeline_conv2d_dw_cwhn_f32;
489
 
 
877
  uint32_t H;
878
  };
879
 
880
+ struct vk_op_conv2d_push_constants {
881
+ uint32_t Cout;
882
+ uint32_t Cin;
883
+ uint32_t N;
884
+
885
+ uint32_t KW;
886
+ uint32_t KH;
887
+ uint32_t W;
888
+ uint32_t H;
889
+ uint32_t OW;
890
+ uint32_t OH;
891
+
892
+ uint32_t s0;
893
+ uint32_t s1;
894
+ uint32_t p0;
895
+ uint32_t p1;
896
+ uint32_t d0;
897
+ uint32_t d1;
898
+
899
+ uint32_t nb01;
900
+ uint32_t nb02;
901
+ uint32_t nb03;
902
+
903
+ uint32_t nb11;
904
+ uint32_t nb12;
905
+ uint32_t nb13;
906
+
907
+ uint32_t nb1;
908
+ uint32_t nb2;
909
+ uint32_t nb3;
910
+ };
911
+
912
  struct vk_op_conv2d_dw_push_constants {
913
  uint32_t ne;
914
  uint32_t batches;
 
1008
  #endif // GGML_VULKAN_MEMORY_DEBUG
1009
 
1010
  class vk_perf_logger {
1011
+ public:
1012
  void print_timings() {
1013
+ if (timings.empty()) {
1014
+ return;
1015
+ }
1016
+ uint64_t total_all_op_times = 0;
1017
  std::cerr << "----------------\nVulkan Timings:" << std::endl;
1018
+ for (const auto & t : timings) {
1019
+ uint64_t total_op_times = 0;
1020
+ for (const auto & time : t.second) {
1021
+ total_op_times += time;
1022
+ }
1023
+ std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0)
1024
+ << " us";
1025
+
1026
+ // If we have as many flops entries as timing entries for the op, then compute and log the flops/S.
1027
+ auto it = flops.find(t.first);
1028
+ if (it != flops.end() && (it->second).size() == t.second.size()) {
1029
+ uint64_t total_op_flops = 0;
1030
+ for (const auto & elem : it->second) {
1031
+ total_op_flops += elem;
1032
+ }
1033
+ std::cerr << " ("
1034
+ << (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) /
1035
+ (double(total_op_times) / (1000.0 * 1000.0 * 1000.0))
1036
+ << " GFLOPS/s)";
1037
  }
1038
+
1039
+ total_all_op_times += total_op_times;
1040
+
1041
+ std::cerr << std::endl;
1042
+ }
1043
+
1044
+ if (timings.size() > 0) {
1045
+ std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl;
1046
  }
1047
 
1048
  timings.clear();
1049
+ flops.clear();
1050
  }
1051
 
1052
  void log_timing(const ggml_tensor * node, uint64_t time) {
 
1055
  return;
1056
  }
1057
  if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
1058
+ const uint64_t m = node->src[0]->ne[1];
1059
+ const uint64_t n = node->src[1]->ne[1];
1060
+ const uint64_t k = node->src[1]->ne[0];
1061
+ std::string name = ggml_op_name(node->op);
1062
  if (n == 1) {
1063
  name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k);
1064
  } else {
1065
  name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
1066
  }
1067
  timings[name].push_back(time);
1068
+ flops[name].push_back(m * n * (k + (k - 1)));
1069
+ return;
1070
+ }
1071
+ if (node->op == GGML_OP_CONV_2D) {
1072
+ std::string name = ggml_op_name(node->op);
1073
+ ggml_tensor * knl = node->src[0];
1074
+ uint64_t OW = node->ne[0];
1075
+ uint64_t OH = node->ne[1];
1076
+ uint64_t N = node->ne[3];
1077
+ uint64_t Cout = node->ne[2];
1078
+ uint64_t KW = knl->ne[0];
1079
+ uint64_t KH = knl->ne[1];
1080
+ uint64_t Cin = knl->ne[2];
1081
+ // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
1082
+ uint64_t size_M = Cout;
1083
+ uint64_t size_K = Cin * KW * KH;
1084
+ uint64_t size_N = N * OW * OH;
1085
+ uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1));
1086
+ name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
1087
+ ", N=N*OW*OH=" + std::to_string(size_N);
1088
+ flops[name].push_back(n_flops);
1089
+ timings[name].push_back(time);
1090
  return;
1091
  }
1092
  timings[ggml_op_name(node->op)].push_back(time);
1093
  }
1094
+ private:
1095
  std::map<std::string, std::vector<uint64_t>> timings;
1096
+ std::map<std::string, std::vector<uint64_t>> flops;
1097
  };
1098
 
1099
  struct ggml_backend_vk_context {
 
2196
  }
2197
  compile_count++;
2198
  }
2199
+
2200
  compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
2201
  parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
2202
  };
 
3046
 
3047
  ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
3048
 
3049
+ // conv2d
3050
+ uint32_t conv2d_WG_SIZE = 256;
3051
+ uint32_t conv2d_BS_K = 128;
3052
+ uint32_t conv2d_BS_CRS = 16;
3053
+ uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
3054
+ if (device->subgroup_shuffle &&
3055
+ device->vendor_id != VK_VENDOR_ID_INTEL) { // Do not enable collectives on Intel, see PR 14316
3056
+ use_collectives = 1;
3057
+ conv2d_BS_CRS = std::min(
3058
+ device->subgroup_size,
3059
+ conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
3060
+ }
3061
+ uint32_t conv2d_BS_NPQ = 128;
3062
+ uint32_t conv2d_TS_K = 8;
3063
+ uint32_t conv2d_shmem_req =
3064
+ (conv2d_BS_K * (conv2d_BS_CRS + 1) + conv2d_BS_CRS * (conv2d_BS_NPQ + 1)) * sizeof(float);
3065
+ if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
3066
+ conv2d_BS_CRS = 8;
3067
+ if (use_collectives) {
3068
+ conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
3069
+ }
3070
+ }
3071
+
3072
+ if (use_collectives) {
3073
+ ggml_vk_create_pipeline(
3074
+ device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
3075
+ sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3076
+ { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
3077
+ } else {
3078
+ ggml_vk_create_pipeline(
3079
+ device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
3080
+ sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3081
+ { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
3082
+ false);
3083
+ }
3084
+
3085
  ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
3086
  ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
3087
 
 
6957
  return ctx->device->pipeline_leaky_relu_f32;
6958
  }
6959
  return nullptr;
6960
+ case GGML_OP_CONV_2D:
6961
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
6962
+ ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
6963
+ return ctx->device->pipeline_conv2d_f32;
6964
+ }
6965
+ return nullptr;
6966
  case GGML_OP_CONV_2D_DW:
6967
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6968
  if (ggml_is_contiguous(src1)) {
 
7285
  const uint32_t OW = dst->ne[0];
7286
  elements = { N * OC * OH * OW, 1, 1};
7287
  } break;
7288
+ case GGML_OP_CONV_2D:
7289
+ {
7290
+ // src0 - kernel: [KW, KH, Cin, Cout]
7291
+ // src1 - input: [W, H, Cin, N]
7292
+ // dst - result: [OW, OH, Cout, N]
7293
+
7294
+ // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
7295
+ auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
7296
+ return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
7297
+ };
7298
+ // parallelize in {OW/BS_K, OH/BS_NPQ, 1}
7299
+ int64_t W = src1->ne[0];
7300
+ int64_t H = src1->ne[1];
7301
+ int64_t KW = src0->ne[0];
7302
+ int64_t KH = src0->ne[1];
7303
+ int64_t Cout = src0->ne[3];
7304
+ int64_t N = src1->ne[3];
7305
+ int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
7306
+ int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
7307
+ int64_t NPQ = N * OW * OH;
7308
+
7309
+ // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
7310
+ elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
7311
+ }
7312
+ break;
7313
  case GGML_OP_ADD:
7314
  case GGML_OP_SUB:
7315
  case GGML_OP_DIV:
 
8176
  }, dryrun);
8177
  }
8178
 
8179
+ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
8180
+ const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
8181
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
8182
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
8183
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
8184
+
8185
+ GGML_TENSOR_BINARY_OP_LOCALS
8186
+
8187
+ GGML_ASSERT(nb00 == sizeof(float));
8188
+ GGML_ASSERT(nb10 == sizeof(float));
8189
+ GGML_ASSERT(nb0 == sizeof(float));
8190
+
8191
+ vk_op_conv2d_push_constants p{};
8192
+ p.Cout = static_cast<uint32_t>(ne03);
8193
+ p.Cin = static_cast<uint32_t>(ne02);
8194
+ p.N = static_cast<uint32_t>(ne13);
8195
+
8196
+ p.KW = static_cast<uint32_t>(ne00);
8197
+ p.KH = static_cast<uint32_t>(ne01);
8198
+ p.W = static_cast<uint32_t>(ne10);
8199
+ p.H = static_cast<uint32_t>(ne11);
8200
+ p.OW = static_cast<uint32_t>(ne0);
8201
+ p.OH = static_cast<uint32_t>(ne1);
8202
+
8203
+ p.s0 = static_cast<uint32_t>(dst->op_params[0]);
8204
+ p.s1 = static_cast<uint32_t>(dst->op_params[1]);
8205
+ p.p0 = static_cast<uint32_t>(dst->op_params[2]);
8206
+ p.p1 = static_cast<uint32_t>(dst->op_params[3]);
8207
+ p.d0 = static_cast<uint32_t>(dst->op_params[4]);
8208
+ p.d1 = static_cast<uint32_t>(dst->op_params[5]);
8209
+
8210
+ p.nb01 = static_cast<uint32_t>(nb01 / nb00);
8211
+ p.nb02 = static_cast<uint32_t>(nb02 / nb00);
8212
+ p.nb03 = static_cast<uint32_t>(nb03 / nb00);
8213
+
8214
+ p.nb11 = static_cast<uint32_t>(nb11 / nb10);
8215
+ p.nb12 = static_cast<uint32_t>(nb12 / nb10);
8216
+ p.nb13 = static_cast<uint32_t>(nb13 / nb10);
8217
+
8218
+ p.nb1 = static_cast<uint32_t>(nb1 / nb0);
8219
+ p.nb2 = static_cast<uint32_t>(nb2 / nb0);
8220
+ p.nb3 = static_cast<uint32_t>(nb3 / nb0);
8221
+
8222
+ GGML_ASSERT(ne03 == ne2);
8223
+ GGML_ASSERT(ne02 == ne12);
8224
+
8225
+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
8226
+ }
8227
+
8228
  static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
8229
  vk_op_conv2d_dw_push_constants p{};
8230
  p.ne = ggml_nelements(dst);
 
9287
  case GGML_OP_TIMESTEP_EMBEDDING:
9288
  case GGML_OP_CONV_TRANSPOSE_1D:
9289
  case GGML_OP_POOL_2D:
9290
+ case GGML_OP_CONV_2D:
9291
  case GGML_OP_CONV_2D_DW:
9292
  case GGML_OP_RWKV_WKV6:
9293
  case GGML_OP_RWKV_WKV7:
 
9355
  case GGML_OP_TIMESTEP_EMBEDDING:
9356
  case GGML_OP_CONV_TRANSPOSE_1D:
9357
  case GGML_OP_POOL_2D:
9358
+ case GGML_OP_CONV_2D:
9359
  case GGML_OP_CONV_2D_DW:
9360
  case GGML_OP_LEAKY_RELU:
9361
  {
 
9562
  case GGML_OP_POOL_2D:
9563
  ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
9564
 
9565
+ break;
9566
+ case GGML_OP_CONV_2D:
9567
+ ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun);
9568
+
9569
  break;
9570
  case GGML_OP_CONV_2D_DW:
9571
  ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
 
9696
  case GGML_OP_TIMESTEP_EMBEDDING:
9697
  case GGML_OP_CONV_TRANSPOSE_1D:
9698
  case GGML_OP_POOL_2D:
9699
+ case GGML_OP_CONV_2D:
9700
  case GGML_OP_CONV_2D_DW:
9701
  case GGML_OP_RWKV_WKV6:
9702
  case GGML_OP_RWKV_WKV7:
 
10278
  ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
10279
  if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
10280
  total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
10281
+ } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) {
10282
+ // Account CRS x NPQ x sizeof(*) bytes, i.e. as many bytes as mul_mat would see in im2col->mul_mat mode.
10283
+ auto CRS_size =
10284
+ cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2];
10285
+ auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3];
10286
+ total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type);
10287
  }
10288
  i += ctx->num_additional_fused_ops;
10289
  ctx->num_additional_fused_ops = 0;
 
10860
  return true;
10861
  case GGML_OP_CONV_TRANSPOSE_1D:
10862
  return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
10863
+ case GGML_OP_CONV_2D:
10864
+ {
10865
+ // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
10866
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
10867
+ const vk_device& device = ggml_vk_get_device(ctx->device);
10868
+ bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE;
10869
+ // Channel-contiguous format is not supported yet.
10870
+ return (op->src[0]->type == GGML_TYPE_F32 &&
10871
+ op->src[1]->type == GGML_TYPE_F32 &&
10872
+ op->type == GGML_TYPE_F32 &&
10873
+ ggml_is_contiguous(op->src[0]) &&
10874
+ ggml_is_contiguous(op->src[1]) &&
10875
+ ggml_is_contiguous(op)) && !is_Apple;
10876
+ }
10877
  default:
10878
  return false;
10879
  }
 
11432
  const int32_t p1 = tensor->op_params[6];
11433
 
11434
  tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
11435
+ } else if (tensor->op == GGML_OP_CONV_2D) {
11436
+ const int32_t s0 = tensor->op_params[0];
11437
+ const int32_t s1 = tensor->op_params[1];
11438
+ const int32_t p0 = tensor->op_params[2];
11439
+ const int32_t p1 = tensor->op_params[3];
11440
+ const int32_t d0 = tensor->op_params[4];
11441
+ const int32_t d1 = tensor->op_params[5];
11442
+ tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
11443
  } else if (tensor->op == GGML_OP_LEAKY_RELU) {
11444
  const float * op_params = (const float *)tensor->op_params;
11445
  tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
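
On the dispatch side, the GGML_OP_CONV_2D case above sets elements = { Cout, N*OW*OH, 1 }, and the pipeline created in ggml_vk_load_shaders uses wg_denoms = { BS_K, BS_NPQ, 1 }, so (as the comment in that hunk notes) the output matrix is tiled into roughly ceil(Cout/BS_K) x ceil(NPQ/BS_NPQ) workgroups. A small self-contained sketch of that arithmetic; ceil_div is a hypothetical helper and the shape values are an arbitrary example, not taken from the patch:

    #include <cstdint>
    #include <cstdio>

    static uint32_t ceil_div(uint32_t n, uint32_t d) { return (n + d - 1) / d; }

    int main() {
        const uint32_t BS_K = 128, BS_NPQ = 128;             // output blocktile (spec constants)
        const uint32_t Cout = 320, N = 2, OW = 64, OH = 64;  // arbitrary example shape
        const uint32_t NPQ  = N * OW * OH;                   // one GEMM column per (n, oh, ow)
        printf("workgroups: %u x %u x 1\n", ceil_div(Cout, BS_K), ceil_div(NPQ, BS_NPQ));
        return 0;
    }
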
ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp ADDED
@@ -0,0 +1,265 @@
1
+ #version 450
2
+
3
+ #ifdef USE_COLLECTIVES
4
+ # extension GL_KHR_shader_subgroup_shuffle : enable
5
+ #endif
6
+
7
+ #include "types.comp"
8
+
9
+ // Make spec constant
10
+ #define SHMEM_PAD 0
11
+
12
+ // shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
13
+ layout(binding = 0) readonly buffer A {
14
+ A_TYPE knl_data[];
15
+ }; // src0 - kernel: [KW, KH, Cin, Cout]
16
+
17
+ layout(binding = 1) readonly buffer B {
18
+ B_TYPE src_data[];
19
+ }; // src1 - input: [W, H, Cin, N] -- channel_first format
20
+
21
+ layout(binding = 2) writeonly buffer D {
22
+ D_TYPE dst_data[];
23
+ }; // dst - result: [OW, OH, Cout, N]
24
+
25
+ layout(push_constant) uniform parameter {
26
+ // I/O channels, batch size
27
+ uint32_t Cout;
28
+ uint32_t Cin;
29
+ uint32_t N;
30
+
31
+ // Tensor spatial sizes: kernel, input, output
32
+ uint32_t KW;
33
+ uint32_t KH;
34
+ uint32_t W;
35
+ uint32_t H;
36
+ uint32_t OW;
37
+ uint32_t OH;
38
+
39
+ // Parameters: stride, padding, dilation - 0=y, 1=x
40
+ uint32_t s0;
41
+ uint32_t s1;
42
+ uint32_t p0;
43
+ uint32_t p1;
44
+ uint32_t d0;
45
+ uint32_t d1;
46
+
47
+ // Strides in elements
48
+ uint32_t nb01;
49
+ uint32_t nb02;
50
+ uint32_t nb03;
51
+
52
+ uint32_t nb11;
53
+ uint32_t nb12;
54
+ uint32_t nb13;
55
+
56
+ uint32_t nb1;
57
+ uint32_t nb2;
58
+ uint32_t nb3;
59
+ }
60
+
61
+ p;
62
+
63
+ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
64
+ // Blocktile sizes
65
+ layout(constant_id = 1) const uint BS_K = 128;
66
+ layout(constant_id = 2) const uint BS_CRS = 16;
67
+ layout(constant_id = 3) const uint BS_NPQ = 128;
68
+ // Thread-tile sizes
69
+ layout(constant_id = 4) const uint TS_K = 8;
70
+ layout(constant_id = 5) const uint use_collectives = 1;
71
+
72
+ uint32_t tid = gl_LocalInvocationID.x;
73
+ const uint32_t WG_SIZE = gl_WorkGroupSize.x;
74
+
75
+ uint splitWork(uint work_size, uint block_size) {
76
+ return (block_size + work_size - 1) / block_size;
77
+ }
78
+
79
+ uint32_t K = p.Cout;
80
+ uint32_t CRS = p.Cin * p.KH * p.KW;
81
+ uint32_t NPQ = p.N * p.OH * p.OW;
82
+
83
+ uint32_t n_elems_out = K * NPQ;
84
+
85
+ // Number of blocktiles per input
86
+ uint32_t NB_CRS = splitWork(CRS, BS_CRS);
87
+
88
+ const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
89
+ const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
90
+
91
+ const uint32_t Ash_numel = BS_K * BS_CRS;
92
+ const uint32_t Bsh_numel = BS_CRS * BS_NPQ;
93
+
94
+ const uint32_t Ash_len = BS_K * Ash_stride;
95
+ const uint32_t Bsh_len = BS_CRS * Bsh_stride;
96
+
97
+ shared float Ash[Ash_len]; // K x CRS
98
+ shared float Bsh[Bsh_len]; // CRS x NPQ
99
+
100
+ // Threadtile sizes
101
+ const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
102
+
103
+ // Number of threadtiles per blocktile
104
+ const uint32_t NT_K = BS_K / TS_K;
105
+ const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
106
+
107
+ float regA[TS_K];
108
+ float regB[TS_NPQ];
109
+ float regC[TS_K][TS_NPQ];
110
+
111
+ /*
112
+ Compute
113
+ KxCRS @ CRSxNPQ = K x NPQ
114
+ K=Cout
115
+ C=Cin
116
+ R,S=KH,KW
117
+ P,Q=OH,OW
118
+ */
119
+
120
+ uint32_t B_idx_K = gl_WorkGroupID.x;
121
+ uint32_t B_idx_NPQ = gl_WorkGroupID.y;
122
+
123
+ uint32_t T_y = tid / NT_NPQ;
124
+ uint32_t T_x = tid % NT_NPQ;
125
+
126
+ uint32_t Ar = tid / BS_CRS;
127
+ uint32_t Ac = tid % BS_CRS;
128
+ const uint32_t ArpWg = WG_SIZE / BS_CRS;
129
+
130
+ uint32_t Br = tid / BS_NPQ;
131
+ uint32_t Bc = tid % BS_NPQ;
132
+ const uint32_t BrpWg = WG_SIZE / BS_NPQ;
133
+
134
+ void main() {
135
+ for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
136
+ for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
137
+ regC[T_ly][T_lx] = 0.0;
138
+ }
139
+ }
140
+ /* Advance block in CRS dim */
141
+ for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
142
+ uint32_t CRS_idx_a;
143
+ uint32_t Cin_idx_a;
144
+ uint32_t KH_idx_a;
145
+ uint32_t KW_idx_a;
146
+
147
+ #ifdef USE_COLLECTIVES
148
+ uint32_t cached_CRS_idx;
149
+ uint32_t cached_Cin_idx;
150
+ uint32_t cached_KH_idx;
151
+ uint32_t cached_KW_idx;
152
+ if (use_collectives == 1) {
153
+ cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
154
+ cached_Cin_idx = cached_CRS_idx / (p.KW * p.KH);
155
+ uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
156
+ cached_KH_idx = cached_CRS_remainder / p.KW;
157
+ cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;
158
+
159
+ CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
160
+ Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
161
+ KH_idx_a = subgroupShuffle(cached_KH_idx, Ac);
162
+ KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
163
+ } else {
164
+ CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
165
+ Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
166
+ uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
167
+ KH_idx_a = CRS_remainder / p.KW;
168
+ KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
169
+ }
170
+ #else
171
+ CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
172
+ Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
173
+ CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
174
+ KH_idx_a = CRS_remainder / p.KW;
175
+ KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
176
+ #endif
177
+
178
+ /* Load kernel to A_block: (BS_K x BS_CRS)*/
179
+ for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
180
+ uint32_t B_ly = r_offset + Ar;
181
+ uint32_t B_lx = Ac;
182
+ uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
183
+ uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
184
+ float val = knl_data[knl_idx];
185
+ if (K_idx >= K || CRS_idx_a >= CRS) {
186
+ val = 0.0;
187
+ }
188
+ Ash[B_ly * Ash_stride + B_lx] = val;
189
+ }
190
+ /* Load input to B_block: (BS_CRS x BS_NPQ) */
191
+ for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
192
+ uint32_t B_ly = r_offset + Br; /* Row index of B block */
193
+ uint32_t B_lx = Bc;
194
+ uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
195
+ uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
196
+ uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
197
+ uint32_t OH_idx = NPQ_remainder / p.OW;
198
+ uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW;
199
+
200
+ uint32_t CRS_idx_b;
201
+ uint32_t Cin_idx_b;
202
+ uint32_t KH_idx_b;
203
+ uint32_t KW_idx_b;
204
+ #ifdef USE_COLLECTIVES
205
+ if (use_collectives == 1) {
206
+ CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
207
+ Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
208
+ KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br);
209
+ KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
210
+ } else {
211
+ CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
212
+ Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
213
+ uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
214
+ KH_idx_b = CRS_remainder / p.KW;
215
+ KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
216
+ }
217
+ #else
218
+ CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
219
+ Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
220
+ uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
221
+ KH_idx_b = CRS_remainder / p.KW;
222
+ KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
223
+ #endif
224
+
225
+ uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
226
+ uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
227
+ uint32_t src_idx =
228
+ min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
229
+ float val = src_data[src_idx];
230
+ if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) {
231
+ val = 0.0;
232
+ }
233
+ Bsh[B_ly * Bsh_stride + B_lx] = val;
234
+ }
235
+ barrier();
236
+ for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
237
+ for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
238
+ regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
239
+ }
240
+ for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
241
+ regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
242
+ }
243
+ for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
244
+ for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
245
+ regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
246
+ }
247
+ }
248
+ }
249
+ barrier();
250
+ }
251
+ /* Save C* */
252
+ for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
253
+ for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
254
+ uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
255
+ uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
256
+ uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
257
+ uint32_t OH_idx = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW;
258
+ uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
259
+ uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
260
+ if (K_idx < K && NPQ_idx < NPQ) {
261
+ dst_data[dst_idx] = regC[T_ly][T_lx];
262
+ }
263
+ }
264
+ }
265
+ }
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -655,6 +655,8 @@ void process_shaders() {
655
 
656
  string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
657
658
  string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
659
  string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
660
 
 
655
 
656
  string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
657
 
658
+ string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
659
+
660
  string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
661
  string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
662
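
Finally, a note on interpreting the perf-logger output added in this commit: for both MUL_MAT and CONV_2D, vk_perf_logger counts M * N * (2K - 1) floating-point operations (K multiplies plus K - 1 adds per output element) and divides by the measured time. A small self-contained check of that arithmetic; the shape and timing below are made-up illustration values:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint64_t size_M  = 128;            // Cout
        const uint64_t size_K  = 64 * 3 * 3;     // Cin * KW * KH
        const uint64_t size_N  = 1 * 56 * 56;    // N * OW * OH
        const uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1));
        const double   time_ns = 250000.0;       // pretend the kernel took 250 us
        // Same formula as print_timings(): (flops / 1e9) / seconds = GFLOPS/s.
        printf("%.1f GFLOPS/s\n", (double(n_flops) / 1e9) / (time_ns / 1e9));
        return 0;
    }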