R0CKSTAR committed
Commit 5e508d2 · 1 Parent(s): 6868981

musa: refine compute capability (llama/12493)


* musa: refine compute capability

Signed-off-by: Xiaodong Ye <[email protected]>

* Address review comments

Signed-off-by: Xiaodong Ye <[email protected]>

---------

Signed-off-by: Xiaodong Ye <[email protected]>

ggml/src/ggml-cuda/common.cuh CHANGED

@@ -41,14 +41,17 @@
 #define CUDART_HMAX   11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
 #define CUDART_HMASK  12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons

-#define GGML_CUDA_CC_PASCAL       600
-#define GGML_CUDA_CC_DP4A         610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
-#define GGML_CUDA_CC_VOLTA        700
-#define GGML_CUDA_CC_TURING       750
-#define GGML_CUDA_CC_AMPERE       800
-#define GGML_CUDA_CC_ADA_LOVELACE 890
-#define GGML_CUDA_CC_OFFSET_AMD   0x1000000
-
+#define GGML_CUDA_CC_PASCAL          600
+#define GGML_CUDA_CC_DP4A            610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define GGML_CUDA_CC_VOLTA           700
+#define GGML_CUDA_CC_TURING          750
+#define GGML_CUDA_CC_AMPERE          800
+#define GGML_CUDA_CC_ADA_LOVELACE    890
+#define GGML_CUDA_CC_OFFSET_AMD      0x1000000
+#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
+#define GGML_CUDA_CC_IS_NVIDIA(cc)   (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
+
+// AMD
 // GCN/CNDA, wave size is 64
 #define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
 #define GGML_CUDA_CC_VEGA       (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue

@@ -70,8 +73,17 @@
 #define GGML_CUDA_CC_IS_GCN(cc)  (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
 #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)

-#define GGML_CUDA_CC_QY1 210
-#define GGML_CUDA_CC_QY2 220
+// Moore Threads
+#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
+
+#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
+#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
+#define GGML_CUDA_CC_NG  (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
+
+#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
+#define GGML_CUDA_CC_IS_QY1(cc)      (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
+#define GGML_CUDA_CC_IS_QY2(cc)      (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
+#define GGML_CUDA_CC_IS_NG(cc)       (cc >= GGML_CUDA_CC_NG)

 #ifdef __CUDA_ARCH_LIST__
 constexpr bool ggml_cuda_has_arch_impl(int) {

@@ -209,21 +221,21 @@ typedef float2 dfloat2;
 #define CP_ASYNC_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

-#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
 #define FLASH_ATTN_AVAILABLE
-#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)

 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
 }

 static bool fast_fp16_available(const int cc) {
-    return fp16_available(cc) && cc != 610;
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
 }

 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fast_fp16_hardware_available(const int cc) {
-    return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
 }

 // Any FP16 tensor core instructions are available for ggml code.

@@ -231,20 +243,20 @@ static bool fp16_mma_available(const int cc) {
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
     return false;
 #else
-    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
-        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }

 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA ||
-        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA ||
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
 }

 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
 static bool new_mma_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }

 static bool cp_async_available(const int cc) {

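Note: the defines above partition the single integer compute-capability value into vendor ranges: NVIDIA below 0x0100000, Moore Threads from 0x0100000 up to (but not including) 0x1000000, and AMD from 0x1000000 upward. The following is a minimal standalone sketch, not part of the commit; it re-declares the relevant constants with the same values as common.cuh and uses a purely illustrative vendor_name helper.

// Standalone sketch: classify cc values by vendor range (mirrors common.cuh).
#include <cstdio>

#define GGML_CUDA_CC_OFFSET_AMD      0x1000000
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
#define GGML_CUDA_CC_IS_NVIDIA(cc)   (cc <  GGML_CUDA_CC_OFFSET_MTHREADS)
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
#define GGML_CUDA_CC_IS_AMD(cc)      (cc >= GGML_CUDA_CC_OFFSET_AMD)

#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000

// Illustrative helper, not part of ggml.
static const char * vendor_name(int cc) {
    if (GGML_CUDA_CC_IS_NVIDIA(cc))   return "NVIDIA";
    if (GGML_CUDA_CC_IS_MTHREADS(cc)) return "Moore Threads";
    return "AMD";
}

int main() {
    const int ccs[] = { 750 /* Turing */, GGML_CUDA_CC_QY1, GGML_CUDA_CC_QY2, GGML_CUDA_CC_OFFSET_AMD + 0x906 /* Vega20 */ };
    for (int cc : ccs) {
        std::printf("cc = 0x%07x -> %s\n", cc, vendor_name(cc));
    }
    return 0;
}

This prints NVIDIA, Moore Threads, Moore Threads, and AMD for 0x00002ee, 0x0100210, 0x0100220, and 0x1000906 respectively.
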
ggml/src/ggml-cuda/fattn.cu CHANGED

@@ -253,7 +253,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

-    if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
+    if (GGML_CUDA_CC_IS_AMD(cc)) {
 #if defined(GGML_HIP_ROCWMMA_FATTN)
         if (fp16_mma_available(cc)) {
             ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);

ggml/src/ggml-cuda/ggml-cuda.cu CHANGED

@@ -264,9 +264,9 @@ static ggml_cuda_device_info ggml_cuda_init() {
 #elif defined(GGML_USE_MUSA)
         // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
         info.devices[id].warp_size = 32;
-        // TODO: refine the .cc to reflect MUSA's actual CC capabilities
         info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
-        info.devices[id].cc = 100*prop.major + 10*prop.minor;
+        info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
+        info.devices[id].cc += prop.minor * 0x10;
         GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
                       id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 #else

@@ -1188,11 +1188,11 @@ static void ggml_cuda_op_mul_mat_cublas(
     // ldc == nrows of the matrix that cuBLAS writes into
     int64_t ldc = id == ctx.device ? ne0 : row_diff;

-    const int compute_capability = ggml_cuda_info().devices[id].cc;
+    const int cc = ggml_cuda_info().devices[id].cc;

     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;

-    if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) {
+    if (((cc >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc)) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {

@@ -1216,7 +1216,7 @@ static void ggml_cuda_op_mul_mat_cublas(

     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));

-    if (GGML_CUDA_CC_IS_CDNA(compute_capability)) {
+    if (GGML_CUDA_CC_IS_CDNA(cc)) {
         const float alpha = 1.0f;
         const float beta = 0.0f;
         CUBLAS_CHECK(

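Note: the MUSA branch of ggml_cuda_init() now packs the device's reported major/minor into the Moore Threads range instead of the old 100*major + 10*minor scheme. As a quick sanity check, the sketch below (standalone, not code from the commit) assumes a QY1-class device such as the MTT S80 reports major 2, minor 1, as the GGML_CUDA_CC_QY1 value suggests, and verifies that the encoding lands exactly on GGML_CUDA_CC_QY1.

// Standalone sketch: reproduce the MUSA cc encoding used in ggml_cuda_init().
#include <cassert>
#include <cstdio>

#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000

int main() {
    // Assumption for illustration: a QY1-class device reports major 2, minor 1.
    const int major = 2;
    const int minor = 1;

    int cc = GGML_CUDA_CC_OFFSET_MTHREADS + major * 0x100; // same arithmetic as ggml-cuda.cu
    cc += minor * 0x10;

    assert(cc == GGML_CUDA_CC_QY1);   // 0x0100000 + 0x200 + 0x10 == 0x0100210
    std::printf("cc = 0x%07x\n", cc); // prints cc = 0x0100210
    return 0;
}
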
ggml/src/ggml-cuda/mmq.cu CHANGED

@@ -28,7 +28,7 @@ void ggml_cuda_op_mul_mat_q(
     // Also its fixup needs to allocate a temporary buffer in the memory pool.
     // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
     const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
-        cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
+        GGML_CUDA_CC_IS_NVIDIA(cc) && src1_ncols == ne11;
     const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};

     switch (src0->type) {

@@ -145,7 +145,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
     return true;
 #endif //GGML_CUDA_FORCE_MMQ

-    if (cc < GGML_CUDA_CC_OFFSET_AMD) {
+    if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
         return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }

ggml/src/ggml-cuda/mmq.cuh CHANGED

@@ -90,7 +90,7 @@ struct tile_x_sizes {

 static int get_mmq_x_max_host(const int cc) {
     return new_mma_available(cc) ? 128 :
-        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ?
+        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc) ?
 #ifdef GGML_CUDA_FORCE_MMQ
         128 : 64;
 #else

@@ -123,8 +123,8 @@ static constexpr __device__ int get_mmq_x_max_device() {
 }

 static int get_mmq_y_host(const int cc) {
-    return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
-        (ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64);
+    return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
+        ((ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc)) ? 128 : 64);
 }

 static constexpr __device__ int get_mmq_y_device() {

@@ -2772,14 +2772,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a

     const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);

-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
     static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
     if (!shmem_limit_raised[id]) {
         CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
         CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
         shmem_limit_raised[id] = true;
     }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)

     const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
     const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;

@@ -2832,7 +2832,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
     const int mmq_x_max = get_mmq_x_max_host(cc);
     const int mmq_y = get_mmq_y_host(cc);
     const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
-    const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
+    const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc);

     int mmq_x_best = 0;
     int nparts_best = INT_MAX;
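
Note: under the old scheme a MUSA cc such as 210 or 220 never reached GGML_CUDA_CC_VOLTA, so the cc < GGML_CUDA_CC_OFFSET_AMD guard was enough; with the offset-based encoding every MUSA cc is numerically above VOLTA, which is why these call sites switch to an explicit GGML_CUDA_CC_IS_NVIDIA(cc) test. A standalone sketch of the use_stream_k condition follows (not from the commit; ggml_cuda_highest_compiled_arch is stubbed to the identity purely for illustration).

// Standalone sketch: the stream-k guard with the new vendor predicate.
#include <cstdio>

#define GGML_CUDA_CC_VOLTA           700
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
#define GGML_CUDA_CC_IS_NVIDIA(cc)   (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
#define GGML_CUDA_CC_QY2             (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000

// Stub for illustration only; the real function maps cc to the best arch the binary was compiled for.
static int ggml_cuda_highest_compiled_arch(int cc) { return cc; }

int main() {
    const int ccs[] = { 890 /* Ada Lovelace */, GGML_CUDA_CC_QY2 /* MUSA, numerically above VOLTA */ };
    for (int cc : ccs) {
        // Same condition as in mul_mat_q_case(); without the vendor test the MUSA cc would qualify.
        const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
                                  GGML_CUDA_CC_IS_NVIDIA(cc);
        std::printf("cc = 0x%07x -> use_stream_k = %d\n", cc, use_stream_k);
    }
    return 0;
}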