jeffbolznv commited on
Commit
461484c
·
1 Parent(s): 1bbdb81

vulkan: request round-to-even for fp16 in im2col/rope_head (llama/10767)

Browse files

Vulkan doesn't mandate a specific rounding mode, but the shader_float_controls
feature allows rounding mode to be requested if the implementation supports it.

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -162,6 +162,7 @@ struct vk_device_struct {
162
  uint32_t subgroup_size;
163
  uint32_t shader_core_count;
164
  bool uma;
 
165
  bool coopmat2;
166
 
167
  bool coopmat_support;
@@ -1916,17 +1917,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
1916
  ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
1917
 
1918
  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1919
- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1920
-
1921
  ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1922
- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
 
 
 
 
 
 
 
1923
 
1924
  ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
1925
 
1926
  ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1927
 
1928
  ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
1929
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
 
 
 
 
1930
 
1931
  ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
1932
 
@@ -2007,11 +2017,13 @@ static vk_device ggml_vk_get_device(size_t idx) {
2007
  vk::PhysicalDeviceDriverProperties driver_props;
2008
  vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
2009
  vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
 
2010
  props2.pNext = &props3;
2011
  props3.pNext = &subgroup_props;
2012
  subgroup_props.pNext = &driver_props;
 
2013
 
2014
- VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
2015
 
2016
  if (maintenance4_support) {
2017
  last_struct->pNext = (VkBaseOutStructure *)&props4;
@@ -2057,6 +2069,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
2057
  } else {
2058
  device->shader_core_count = 0;
2059
  }
 
2060
 
2061
  const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
2062
 
 
162
  uint32_t subgroup_size;
163
  uint32_t shader_core_count;
164
  bool uma;
165
+ bool float_controls_rte_fp16;
166
  bool coopmat2;
167
 
168
  bool coopmat_support;
 
1917
  ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
1918
 
1919
  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
 
 
1920
  ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1921
+
1922
+ if (device->float_controls_rte_fp16) {
1923
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1924
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1925
+ } else {
1926
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1927
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1928
+ }
1929
 
1930
  ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
1931
 
1932
  ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1933
 
1934
  ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
1935
+ if (device->float_controls_rte_fp16) {
1936
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
1937
+ } else {
1938
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
1939
+ }
1940
 
1941
  ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
1942
 
 
2017
  vk::PhysicalDeviceDriverProperties driver_props;
2018
  vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
2019
  vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
2020
+ vk::PhysicalDeviceVulkan12Properties vk12_props;
2021
  props2.pNext = &props3;
2022
  props3.pNext = &subgroup_props;
2023
  subgroup_props.pNext = &driver_props;
2024
+ driver_props.pNext = &vk12_props;
2025
 
2026
+ VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
2027
 
2028
  if (maintenance4_support) {
2029
  last_struct->pNext = (VkBaseOutStructure *)&props4;
 
2069
  } else {
2070
  device->shader_core_count = 0;
2071
  }
2072
+ device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
2073
 
2074
  const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
2075
 
ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp CHANGED
@@ -1,6 +1,11 @@
1
  #version 450
2
 
3
  #extension GL_EXT_shader_16bit_storage : require
 
 
 
 
 
4
 
5
  layout (push_constant) uniform parameter
6
  {
 
1
  #version 450
2
 
3
  #extension GL_EXT_shader_16bit_storage : require
4
+ #extension GL_EXT_spirv_intrinsics: enable
5
+
6
+ #if RTE16
7
+ spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
8
+ #endif
9
 
10
  layout (push_constant) uniform parameter
11
  {
ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp CHANGED
@@ -1,6 +1,11 @@
1
  #include "types.comp"
2
 
3
  #extension GL_EXT_shader_16bit_storage : require
 
 
 
 
 
4
 
5
  layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
6
 
 
1
  #include "types.comp"
2
 
3
  #extension GL_EXT_shader_16bit_storage : require
4
+ #extension GL_EXT_spirv_intrinsics: enable
5
+
6
+ #if RTE16
7
+ spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
8
+ #endif
9
 
10
  layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
11
 
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -461,9 +461,11 @@ void process_shaders() {
461
 
462
  string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
463
  string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
 
464
 
465
  string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
466
  string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
 
467
 
468
  string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
469
 
@@ -471,6 +473,7 @@ void process_shaders() {
471
 
472
  string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
473
  string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
 
474
 
475
  string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
476
 
 
461
 
462
  string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
463
  string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
464
+ string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
465
 
466
  string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
467
  string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
468
+ string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
469
 
470
  string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
471
 
 
473
 
474
  string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
475
  string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
476
+ string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
477
 
478
  string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
479