Digipom commited on
Commit
a3cb126
·
unverified ·
1 Parent(s): 272633a

whisper.android : address ARM's big.LITTLE arch by checking cpu info (#1254)

Browse files
examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/LibWhisper.kt CHANGED
@@ -18,7 +18,9 @@ class WhisperContext private constructor(private var ptr: Long) {
18
 
19
  suspend fun transcribeData(data: FloatArray): String = withContext(scope.coroutineContext) {
20
  require(ptr != 0L)
21
- WhisperLib.fullTranscribe(ptr, data)
 
 
22
  val textCount = WhisperLib.getTextSegmentCount(ptr)
23
  return@withContext buildString {
24
  for (i in 0 until textCount) {
@@ -126,7 +128,7 @@ private class WhisperLib {
126
  external fun initContextFromAsset(assetManager: AssetManager, assetPath: String): Long
127
  external fun initContext(modelPath: String): Long
128
  external fun freeContext(contextPtr: Long)
129
- external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
130
  external fun getTextSegmentCount(contextPtr: Long): Int
131
  external fun getTextSegment(contextPtr: Long, index: Int): String
132
  external fun getSystemInfo(): String
 
18
 
19
  suspend fun transcribeData(data: FloatArray): String = withContext(scope.coroutineContext) {
20
  require(ptr != 0L)
21
+ val numThreads = WhisperCpuConfig.preferredThreadCount
22
+ Log.d(LOG_TAG, "Selecting $numThreads threads")
23
+ WhisperLib.fullTranscribe(ptr, numThreads, data)
24
  val textCount = WhisperLib.getTextSegmentCount(ptr)
25
  return@withContext buildString {
26
  for (i in 0 until textCount) {
 
128
  external fun initContextFromAsset(assetManager: AssetManager, assetPath: String): Long
129
  external fun initContext(modelPath: String): Long
130
  external fun freeContext(contextPtr: Long)
131
+ external fun fullTranscribe(contextPtr: Long, numThreads: Int, audioData: FloatArray)
132
  external fun getTextSegmentCount(contextPtr: Long): Int
133
  external fun getTextSegment(contextPtr: Long, index: Int): String
134
  external fun getSystemInfo(): String
examples/whisper.android/app/src/main/java/com/whispercppdemo/whisper/WhisperCpuConfig.kt ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package com.whispercppdemo.whisper
2
+
3
+ import android.util.Log
4
+ import java.io.BufferedReader
5
+ import java.io.FileReader
6
+
7
+ object WhisperCpuConfig {
8
+ val preferredThreadCount: Int
9
+ // Always use at least 2 threads:
10
+ get() = CpuInfo.getHighPerfCpuCount().coerceAtLeast(2)
11
+ }
12
+
13
+ private class CpuInfo(private val lines: List<String>) {
14
+ private fun getHighPerfCpuCount(): Int = try {
15
+ getHighPerfCpuCountByFrequencies()
16
+ } catch (e: Exception) {
17
+ Log.d(LOG_TAG, "Couldn't read CPU frequencies", e)
18
+ getHighPerfCpuCountByVariant()
19
+ }
20
+
21
+ private fun getHighPerfCpuCountByFrequencies(): Int =
22
+ getCpuValues(property = "processor") { getMaxCpuFrequency(it.toInt()) }
23
+ .also { Log.d(LOG_TAG, "Binned cpu frequencies (frequency, count): ${it.binnedValues()}") }
24
+ .countDroppingMin()
25
+
26
+ private fun getHighPerfCpuCountByVariant(): Int =
27
+ getCpuValues(property = "CPU variant") { it.substringAfter("0x").toInt(radix = 16) }
28
+ .also { Log.d(LOG_TAG, "Binned cpu variants (variant, count): ${it.binnedValues()}") }
29
+ .countKeepingMin()
30
+
31
+ private fun List<Int>.binnedValues() = groupingBy { it }.eachCount()
32
+
33
+ private fun getCpuValues(property: String, mapper: (String) -> Int) = lines
34
+ .asSequence()
35
+ .filter { it.startsWith(property) }
36
+ .map { mapper(it.substringAfter(':').trim()) }
37
+ .sorted()
38
+ .toList()
39
+
40
+
41
+ private fun List<Int>.countDroppingMin(): Int {
42
+ val min = min()
43
+ return count { it > min }
44
+ }
45
+
46
+ private fun List<Int>.countKeepingMin(): Int {
47
+ val min = min()
48
+ return count { it == min }
49
+ }
50
+
51
+ companion object {
52
+ private const val LOG_TAG = "WhisperCpuConfig"
53
+
54
+ fun getHighPerfCpuCount(): Int = try {
55
+ readCpuInfo().getHighPerfCpuCount()
56
+ } catch (e: Exception) {
57
+ Log.d(LOG_TAG, "Couldn't read CPU info", e)
58
+ // Our best guess -- just return the # of CPUs minus 4.
59
+ (Runtime.getRuntime().availableProcessors() - 4).coerceAtLeast(0)
60
+ }
61
+
62
+ private fun readCpuInfo() = CpuInfo(
63
+ BufferedReader(FileReader("/proc/cpuinfo"))
64
+ .useLines { it.toList() }
65
+ )
66
+
67
+ private fun getMaxCpuFrequency(cpuIndex: Int): Int {
68
+ val path = "/sys/devices/system/cpu/cpu${cpuIndex}/cpufreq/cpuinfo_max_freq"
69
+ val maxFreq = BufferedReader(FileReader(path)).use { it.readLine() }
70
+ return maxFreq.toInt()
71
+ }
72
+ }
73
+ }
examples/whisper.android/app/src/main/jni/whisper/jni.c CHANGED
@@ -163,16 +163,12 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_freeContext(
163
 
164
  JNIEXPORT void JNICALL
165
  Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
166
- JNIEnv *env, jobject thiz, jlong context_ptr, jfloatArray audio_data) {
167
  UNUSED(thiz);
168
  struct whisper_context *context = (struct whisper_context *) context_ptr;
169
  jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL);
170
  const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data);
171
 
172
- // Leave 2 processors free (i.e. the high-efficiency cores).
173
- int max_threads = max(1, min(8, get_nprocs() - 2));
174
- LOGI("Selecting %d threads", max_threads);
175
-
176
  // The below adapted from the Objective-C iOS sample
177
  struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
178
  params.print_realtime = true;
@@ -181,7 +177,7 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
181
  params.print_special = false;
182
  params.translate = false;
183
  params.language = "en";
184
- params.n_threads = max_threads;
185
  params.offset_ms = 0;
186
  params.no_context = true;
187
  params.single_segment = false;
 
163
 
164
  JNIEXPORT void JNICALL
165
  Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
166
+ JNIEnv *env, jobject thiz, jlong context_ptr, jint num_threads, jfloatArray audio_data) {
167
  UNUSED(thiz);
168
  struct whisper_context *context = (struct whisper_context *) context_ptr;
169
  jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL);
170
  const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data);
171
 
 
 
 
 
172
  // The below adapted from the Objective-C iOS sample
173
  struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
174
  params.print_realtime = true;
 
177
  params.print_special = false;
178
  params.translate = false;
179
  params.language = "en";
180
+ params.n_threads = num_threads;
181
  params.offset_ms = 0;
182
  params.no_context = true;
183
  params.single_segment = false;