Spaces:
Running
Running
| package main | |
import (
	"context"
	"errors"
	"flag"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"syscall"
	"time"
)
| /////////////////////////////////////////////////////////////////////////////// | |
| // CONSTANTS | |
// Parameters describing where models are fetched from and how they are read.
const (
	// srcUrl is the location of the models.
	srcUrl = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/"

	// srcExt is the filename extension of a model file.
	srcExt = ".bin"

	// bufSize is the size, in bytes, of the buffer used for downloading
	// the model (64 KiB).
	bufSize = 64 * 1024
)
// modelNames lists the models which will be downloaded if no model is
// specified as an argument. Names are given without the "ggml-" prefix and
// without the file extension.
var modelNames = []string{
	// tiny
	"tiny",
	"tiny-q5_1",
	"tiny-q8_0",
	"tiny.en",
	"tiny.en-q5_1",
	"tiny.en-q8_0",

	// base
	"base",
	"base-q5_1",
	"base-q8_0",
	"base.en",
	"base.en-q5_1",
	"base.en-q8_0",

	// small
	"small",
	"small-q5_1",
	"small-q8_0",
	"small.en",
	"small.en-q5_1",
	"small.en-q8_0",

	// medium
	"medium",
	"medium-q5_0",
	"medium-q8_0",
	"medium.en",
	"medium.en-q5_0",
	"medium.en-q8_0",

	// large
	"large-v1",
	"large-v2",
	"large-v2-q5_0",
	"large-v2-q8_0",
	"large-v3",
	"large-v3-q5_0",
	"large-v3-turbo",
	"large-v3-turbo-q5_0",
	"large-v3-turbo-q8_0",
}
// flagOut is the output folder. When it is left empty the current working
// directory is used.
var flagOut = flag.String("out", "", "Output folder")

// flagTimeout is the HTTP timeout: a request times out if downloading a
// model takes longer than this.
var flagTimeout = flag.Duration("timeout", 30*time.Minute, "HTTP timeout")

// flagQuiet suppresses progress output when set.
var flagQuiet = flag.Bool("quiet", false, "Quiet mode")
| /////////////////////////////////////////////////////////////////////////////// | |
| // MAIN | |
| func main() { | |
| flag.Usage = func() { | |
| name := filepath.Base(flag.CommandLine.Name()) | |
| fmt.Fprintf(flag.CommandLine.Output(), ` | |
| Usage: %s [options] [<model>...] | |
| Options: | |
| -out string Specify the output folder where models will be saved. | |
| Default: Current working directory. | |
| -timeout duration Set the maximum duration for downloading a model. | |
| Example: 10m, 1h (default: 30m0s). | |
| -quiet Suppress all output except errors. | |
| Examples: | |
| 1. Download a specific model: | |
| %s -out ./models tiny-q8_0 | |
| 2. Download all models: | |
| %s -out ./models | |
| `, name, name, name) | |
| flag.PrintDefaults() | |
| } | |
| flag.Parse() | |
| // Get output path | |
| out, err := GetOut() | |
| if err != nil { | |
| fmt.Fprintln(os.Stderr, "Error:", err) | |
| os.Exit(-1) | |
| } | |
| // Create context which quits on SIGINT or SIGQUIT | |
| ctx := ContextForSignal(os.Interrupt, syscall.SIGQUIT) | |
| // Progress filehandle | |
| progress := os.Stdout | |
| if *flagQuiet { | |
| progress, err = os.Open(os.DevNull) | |
| if err != nil { | |
| fmt.Fprintln(os.Stderr, "Error:", err) | |
| os.Exit(-1) | |
| } | |
| defer progress.Close() | |
| } | |
| // Download models - exit on error or interrupt | |
| for _, model := range GetModels() { | |
| url, err := URLForModel(model) | |
| if err != nil { | |
| fmt.Fprintln(os.Stderr, "Error:", err) | |
| continue | |
| } else if path, err := Download(ctx, progress, url, out); err == nil || err == io.EOF { | |
| continue | |
| } else if err == context.Canceled { | |
| os.Remove(path) | |
| fmt.Fprintln(progress, "\nInterrupted") | |
| break | |
| } else if err == context.DeadlineExceeded { | |
| os.Remove(path) | |
| fmt.Fprintln(progress, "Timeout downloading model") | |
| continue | |
| } else { | |
| os.Remove(path) | |
| fmt.Fprintln(os.Stderr, "Error:", err) | |
| break | |
| } | |
| } | |
| } | |
| /////////////////////////////////////////////////////////////////////////////// | |
| // PUBLIC METHODS | |
| // GetOut returns the path to the output directory | |
| func GetOut() (string, error) { | |
| if *flagOut == "" { | |
| return os.Getwd() | |
| } | |
| if info, err := os.Stat(*flagOut); err != nil { | |
| return "", err | |
| } else if !info.IsDir() { | |
| return "", fmt.Errorf("not a directory: %s", info.Name()) | |
| } else { | |
| return *flagOut, nil | |
| } | |
| } | |
| // GetModels returns the list of models to download | |
| func GetModels() []string { | |
| if flag.NArg() == 0 { | |
| fmt.Println("No model specified.") | |
| fmt.Println("Preparing to download all models...") | |
| // Calculate total download size | |
| fmt.Println("Calculating total download size...") | |
| totalSize, err := CalculateTotalDownloadSize(modelNames) | |
| if err != nil { | |
| fmt.Println("Error calculating download sizes:", err) | |
| os.Exit(1) | |
| } | |
| fmt.Println("View available models: https://huggingface.co/ggerganov/whisper.cpp/tree/main") | |
| fmt.Printf("Total download size: %.2f GB\n", float64(totalSize)/(1024*1024*1024)) | |
| fmt.Println("Would you like to download all models? (y/N)") | |
| // Prompt for user input | |
| var response string | |
| fmt.Scanln(&response) | |
| if response != "y" && response != "Y" { | |
| fmt.Println("Aborting. Specify a model to download.") | |
| os.Exit(0) | |
| } | |
| return modelNames // Return all models if confirmed | |
| } | |
| return flag.Args() // Return specific models if arguments are provided | |
| } | |
| func CalculateTotalDownloadSize(models []string) (int64, error) { | |
| var totalSize int64 | |
| client := http.Client{} | |
| for _, model := range models { | |
| modelURL, err := URLForModel(model) | |
| if err != nil { | |
| return 0, err | |
| } | |
| // Issue a HEAD request to get the file size | |
| req, err := http.NewRequest("HEAD", modelURL, nil) | |
| if err != nil { | |
| return 0, err | |
| } | |
| resp, err := client.Do(req) | |
| if err != nil { | |
| return 0, err | |
| } | |
| resp.Body.Close() | |
| if resp.StatusCode != http.StatusOK { | |
| fmt.Printf("Warning: Unable to fetch size for %s (HTTP %d)\n", model, resp.StatusCode) | |
| continue | |
| } | |
| size := resp.ContentLength | |
| totalSize += size | |
| } | |
| return totalSize, nil | |
| } | |
| // URLForModel returns the URL for the given model on huggingface.co | |
| func URLForModel(model string) (string, error) { | |
| // Ensure "ggml-" prefix is added only once | |
| if !strings.HasPrefix(model, "ggml-") { | |
| model = "ggml-" + model | |
| } | |
| // Ensure ".bin" extension is added only once | |
| if filepath.Ext(model) != srcExt { | |
| model += srcExt | |
| } | |
| // Parse the base URL | |
| url, err := url.Parse(srcUrl) | |
| if err != nil { | |
| return "", err | |
| } | |
| // Ensure no trailing slash in the base URL | |
| url.Path = fmt.Sprintf("%s/%s", strings.TrimSuffix(url.Path, "/"), model) | |
| return url.String(), nil | |
| } | |
| // Download downloads the model from the given URL to the given output directory | |
| func Download(ctx context.Context, p io.Writer, model, out string) (string, error) { | |
| // Create HTTP client | |
| client := http.Client{ | |
| Timeout: *flagTimeout, | |
| } | |
| // Initiate the download | |
| req, err := http.NewRequest("GET", model, nil) | |
| if err != nil { | |
| return "", err | |
| } | |
| resp, err := client.Do(req) | |
| if err != nil { | |
| return "", err | |
| } | |
| defer resp.Body.Close() | |
| if resp.StatusCode != http.StatusOK { | |
| return "", fmt.Errorf("%s: %s", model, resp.Status) | |
| } | |
| // If output file exists and is the same size as the model, skip | |
| path := filepath.Join(out, filepath.Base(model)) | |
| if info, err := os.Stat(path); err == nil && info.Size() == resp.ContentLength { | |
| fmt.Fprintln(p, "Skipping", model, "as it already exists") | |
| return "", nil | |
| } | |
| // Create file | |
| w, err := os.Create(path) | |
| if err != nil { | |
| return "", err | |
| } | |
| defer w.Close() | |
| // Report | |
| fmt.Fprintln(p, "Downloading", model, "to", out) | |
| // Progressively download the model | |
| data := make([]byte, bufSize) | |
| count, pct := int64(0), int64(0) | |
| ticker := time.NewTicker(5 * time.Second) | |
| for { | |
| select { | |
| case <-ctx.Done(): | |
| // Cancelled, return error | |
| return path, ctx.Err() | |
| case <-ticker.C: | |
| pct = DownloadReport(p, pct, count, resp.ContentLength) | |
| default: | |
| // Read body | |
| n, err := resp.Body.Read(data) | |
| if err != nil { | |
| DownloadReport(p, pct, count, resp.ContentLength) | |
| return path, err | |
| } else if m, err := w.Write(data[:n]); err != nil { | |
| return path, err | |
| } else { | |
| count += int64(m) | |
| } | |
| } | |
| } | |
| } | |
// DownloadReport writes a progress line to w and returns the percentage of
// the download completed so far.
//
// pct is the percentage returned by the previous call, count the number of
// bytes written so far and total the expected number of bytes. A line is
// printed only when the percentage has increased since the previous call.
//
// When total is zero or negative (length unknown) no percentage can be
// computed: the number of megabytes written is printed without one and pct
// is returned unchanged.
func DownloadReport(w io.Writer, pct, count, total int64) int64 {
	if total <= 0 {
		// Guard against division by zero and meaningless negative results.
		fmt.Fprintf(w, " ...%d MB written\n", count/1e6)
		return pct
	}

	current := count * 100 / total
	if current > pct {
		fmt.Fprintf(w, " ...%d MB written (%d%%)\n", count/1e6, current)
	}
	return current
}