aungkomyat commited on
Commit
74b32ce
·
verified ·
1 Parent(s): 37c2fc8

Create download_model.py

Browse files
Files changed (1) hide show
  1. download_model.py +115 -0
download_model.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import sys
4
+ import requests
5
+ import tarfile
6
+ import zipfile
7
+ from pathlib import Path
8
+
9
+ MODEL_DIR = "trained_model"
10
+ MODEL_CHECKPOINT = "checkpoint_latest.pth.tar"
11
+ CONFIG_FILE = "hparams.yml"
12
+
13
+ def download_file(url, destination):
14
+ """Download a file from url to destination."""
15
+ print(f"Downloading {url} to {destination}")
16
+ response = requests.get(url, stream=True)
17
+ response.raise_for_status()
18
+
19
+ total_size = int(response.headers.get('content-length', 0))
20
+ block_size = 1024 # 1 Kibibyte
21
+ downloaded = 0
22
+
23
+ with open(destination, 'wb') as file:
24
+ for data in response.iter_content(block_size):
25
+ downloaded += len(data)
26
+ file.write(data)
27
+
28
+ # Update progress bar
29
+ done = int(50 * downloaded / total_size) if total_size > 0 else 0
30
+ sys.stdout.write(f"\r[{'=' * done}{' ' * (50-done)}] {downloaded}/{total_size} bytes")
31
+ sys.stdout.flush()
32
+
33
+ print("\nDownload complete!")
34
+
35
+
36
+ def extract_archive(archive_path, extract_to):
37
+ """Extract zip or tar archive to the specified directory."""
38
+ print(f"Extracting {archive_path} to {extract_to}")
39
+
40
+ if archive_path.endswith('.zip'):
41
+ with zipfile.ZipFile(archive_path, 'r') as zip_ref:
42
+ zip_ref.extractall(extract_to)
43
+ elif archive_path.endswith(('.tar.gz', '.tgz')):
44
+ with tarfile.open(archive_path, 'r:gz') as tar_ref:
45
+ tar_ref.extractall(extract_to)
46
+ elif archive_path.endswith('.tar'):
47
+ with tarfile.open(archive_path, 'r:') as tar_ref:
48
+ tar_ref.extractall(extract_to)
49
+ else:
50
+ print(f"Unsupported archive format: {archive_path}")
51
+ return False
52
+
53
+ print("Extraction complete!")
54
+ return True
55
+
56
+
57
+ def setup_model():
58
+ """Download and set up the model files."""
59
+ # Create model directory if it doesn't exist
60
+ os.makedirs(MODEL_DIR, exist_ok=True)
61
+
62
+ # Path for model checkpoint
63
+ model_path = os.path.join(MODEL_DIR, MODEL_CHECKPOINT)
64
+ # Path for config
65
+ config_path = os.path.join(MODEL_DIR, CONFIG_FILE)
66
+
67
+ # Check if files already exist
68
+ if os.path.exists(model_path) and os.path.exists(config_path):
69
+ print("Model files already exist. Skipping download.")
70
+ return True
71
+
72
+ # URLs for the model files
73
+ # Note: Replace these with the actual URLs for your model
74
+ model_url = "REPLACE_WITH_ACTUAL_MODEL_URL"
75
+ config_url = "REPLACE_WITH_ACTUAL_CONFIG_URL"
76
+
77
+ # Download and setup instructions
78
+ print("""
79
+ =================================================================
80
+ IMPORTANT: Model files need to be manually added
81
+ =================================================================
82
+
83
+ This demo requires the following files from the Myanmar TTS model:
84
+ 1. The model checkpoint: checkpoint_latest.pth.tar
85
+ 2. The hyperparameters file: hparams.yml
86
+
87
+ Please obtain these files from the model creator and place them in:
88
+ - trained_model/checkpoint_latest.pth.tar
89
+ - trained_model/hparams.yml
90
+
91
+ Alternatively, you can update this script with the correct download URLs.
92
+
93
+ Model repository: https://github.com/hpbyte/myanmar-tts
94
+ =================================================================
95
+ """)
96
+
97
+ # If you have working URLs, uncomment these lines:
98
+ # download_file(model_url, model_path)
99
+ # download_file(config_url, config_path)
100
+
101
+ # Check if we managed to get the files (if using manual instructions)
102
+ if not os.path.exists(model_path) or not os.path.exists(config_path):
103
+ print("Model files are missing. Please add them manually as described above.")
104
+ # Create placeholder files with instructions
105
+ with open(model_path, 'w') as f:
106
+ f.write("This is a placeholder. Replace with actual model file.")
107
+ with open(config_path, 'w') as f:
108
+ f.write("This is a placeholder. Replace with actual hparams.yml file.")
109
+ return False
110
+
111
+ return True
112
+
113
+
114
+ if __name__ == "__main__":
115
+ setup_model()