ggerganov commited on
Commit
9542e42
·
1 Parent(s): 6bdbd69

ggml : add asserts for type conversion in fattn kernels (llama/9971)

Browse files
Files changed (1) hide show
  1. ggml/src/ggml.c +5 -1
ggml/src/ggml.c CHANGED
@@ -325,8 +325,9 @@ struct ggml_logger_state {
325
  static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};
326
 
327
  static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
328
- if (format == NULL)
329
  return;
 
330
  va_list args_copy;
331
  va_copy(args_copy, args);
332
  char buffer[128];
@@ -15690,6 +15691,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15690
  ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
15691
  ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
15692
 
 
 
 
15693
  // loop over n_batch and n_head
15694
  for (int ir = ir0; ir < ir1; ++ir) {
15695
  // q indices
 
325
  static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};
326
 
327
  static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
328
+ if (format == NULL) {
329
  return;
330
+ }
331
  va_list args_copy;
332
  va_copy(args_copy, args);
333
  char buffer[128];
 
15691
  ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
15692
  ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
15693
 
15694
+ GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
15695
+ GGML_ASSERT(v_to_float && "fattn: unsupported V-type");
15696
+
15697
  // loop over n_batch and n_head
15698
  for (int ir = ir0; ir < ir1; ++ir) {
15699
  // q indices