slaren commited on
Commit
a13c99b
·
1 Parent(s): d8c76ac

ggml : always check bounds on get_rows operations (llama/9354)

Browse files
Files changed (1) hide show
  1. ggml/src/ggml.c +4 -4
ggml/src/ggml.c CHANGED
@@ -13709,7 +13709,7 @@ static void ggml_compute_forward_get_rows_q(
13709
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
13710
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
13711
 
13712
- assert(i01 >= 0 && i01 < ne01);
13713
 
13714
  dequantize_row_q(
13715
  (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
@@ -13750,7 +13750,7 @@ static void ggml_compute_forward_get_rows_f16(
13750
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
13751
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
13752
 
13753
- assert(i01 >= 0 && i01 < ne01);
13754
 
13755
  ggml_fp16_to_fp32_row(
13756
  (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
@@ -13791,7 +13791,7 @@ static void ggml_compute_forward_get_rows_bf16(
13791
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
13792
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
13793
 
13794
- assert(i01 >= 0 && i01 < ne01);
13795
 
13796
  ggml_bf16_to_fp32_row(
13797
  (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
@@ -13832,7 +13832,7 @@ static void ggml_compute_forward_get_rows_f32(
13832
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
13833
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
13834
 
13835
- assert(i01 >= 0 && i01 < ne01);
13836
 
13837
  ggml_vec_cpy_f32(nc,
13838
  (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
 
13709
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
13710
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
13711
 
13712
+ GGML_ASSERT(i01 >= 0 && i01 < ne01);
13713
 
13714
  dequantize_row_q(
13715
  (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
 
13750
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
13751
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
13752
 
13753
+ GGML_ASSERT(i01 >= 0 && i01 < ne01);
13754
 
13755
  ggml_fp16_to_fp32_row(
13756
  (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
 
13791
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
13792
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
13793
 
13794
+ GGML_ASSERT(i01 >= 0 && i01 < ne01);
13795
 
13796
  ggml_bf16_to_fp32_row(
13797
  (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
 
13832
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
13833
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
13834
 
13835
+ GGML_ASSERT(i01 >= 0 && i01 < ne01);
13836
 
13837
  ggml_vec_cpy_f32(nc,
13838
  (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),