Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,6 +7,7 @@ import glob
|
|
| 7 |
# Ensure 'checkpoint' directory exists
|
| 8 |
os.makedirs("checkpoint", exist_ok=True)
|
| 9 |
|
|
|
|
| 10 |
# Function to download the model weights from a Google Drive folder
|
| 11 |
def download_weights_from_folder(google_drive_folder_link):
|
| 12 |
# Extract the folder ID from the Google Drive link
|
|
@@ -16,16 +17,24 @@ def download_weights_from_folder(google_drive_folder_link):
|
|
| 16 |
# Download all files in the Google Drive folder
|
| 17 |
gdown_url = f"https://drive.google.com/drive/folders/{folder_id}"
|
| 18 |
try:
|
|
|
|
| 19 |
gdown.download_folder(gdown_url, quiet=False, output=output_folder)
|
| 20 |
|
| 21 |
-
#
|
| 22 |
-
|
| 23 |
-
if
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
else:
|
| 26 |
-
|
| 27 |
except Exception as e:
|
| 28 |
-
|
| 29 |
|
| 30 |
download_weights_from_folder("https://drive.google.com/drive/folders/1Bq0n-w1VT5l99CoaVg02hFpqE5eGLo9O")
|
| 31 |
|
|
|
|
| 7 |
# Ensure 'checkpoint' directory exists
|
| 8 |
os.makedirs("checkpoint", exist_ok=True)
|
| 9 |
|
| 10 |
+
# Function to download the model weights from a Google Drive folder
|
| 11 |
# Function to download the model weights from a Google Drive folder
|
| 12 |
def download_weights_from_folder(google_drive_folder_link):
|
| 13 |
# Extract the folder ID from the Google Drive link
|
|
|
|
| 17 |
# Download all files in the Google Drive folder
|
| 18 |
gdown_url = f"https://drive.google.com/drive/folders/{folder_id}"
|
| 19 |
try:
|
| 20 |
+
# Download the folder contents
|
| 21 |
gdown.download_folder(gdown_url, quiet=False, output=output_folder)
|
| 22 |
|
| 23 |
+
# Ensure the downloaded file is named 'model_state-415001.th'
|
| 24 |
+
downloaded_files = glob.glob(os.path.join(output_folder, "*.th"))
|
| 25 |
+
if downloaded_files:
|
| 26 |
+
downloaded_model_path = downloaded_files[0]
|
| 27 |
+
target_model_path = os.path.join(output_folder, "model_state-415001.th")
|
| 28 |
+
|
| 29 |
+
# Rename if necessary
|
| 30 |
+
if downloaded_model_path != target_model_path:
|
| 31 |
+
os.rename(downloaded_model_path, target_model_path)
|
| 32 |
+
|
| 33 |
+
print(f"Downloaded model weights to {target_model_path}")
|
| 34 |
else:
|
| 35 |
+
print("Model file 'model_state-415001.th' not found in the folder.")
|
| 36 |
except Exception as e:
|
| 37 |
+
print(f"Failed to download weights: {e}")
|
| 38 |
|
| 39 |
download_weights_from_folder("https://drive.google.com/drive/folders/1Bq0n-w1VT5l99CoaVg02hFpqE5eGLo9O")
|
| 40 |
|