Spaces:
Running
on
Zero
Running
on
Zero
Fix checkpoint download logic to properly handle Model Hub downloads
Browse files
app.py
CHANGED
|
@@ -45,10 +45,18 @@ VIDEO_CHECKPOINT_REPO = "GlobalStudio/starflow-v-7b-checkpoint" # Update this a
|
|
| 45 |
def get_checkpoint_path(checkpoint_file, default_local_path, repo_id=None, filename=None):
|
| 46 |
"""Get checkpoint path, downloading from Hub if needed."""
|
| 47 |
# If user uploaded a file, use it
|
| 48 |
-
if checkpoint_file is not None:
|
| 49 |
if hasattr(checkpoint_file, 'name'):
|
| 50 |
return checkpoint_file.name
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
# Try local path first
|
| 54 |
if os.path.exists(default_local_path):
|
|
@@ -108,6 +116,13 @@ else:
|
|
| 108 |
def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
|
| 109 |
"""Generate image from text prompt."""
|
| 110 |
# Get checkpoint path (from upload, local, or Model Hub)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
result = get_checkpoint_path(
|
| 112 |
checkpoint_file,
|
| 113 |
DEFAULT_IMAGE_CHECKPOINT,
|
|
@@ -121,10 +136,10 @@ def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path
|
|
| 121 |
checkpoint_path = result
|
| 122 |
|
| 123 |
# Show status
|
| 124 |
-
status_msg
|
| 125 |
|
| 126 |
if not os.path.exists(checkpoint_path):
|
| 127 |
-
return None, f"Error: Checkpoint file not found at {checkpoint_path}.\n\nPlease verify:\n1. Model Hub repo exists: {IMAGE_CHECKPOINT_REPO}\n2. File name matches: starflow_3B_t2i_256x256.pth\n3. Repo is Public (not Private)"
|
| 128 |
|
| 129 |
if not config_path or not os.path.exists(config_path):
|
| 130 |
return None, "Error: Config file not found. Please ensure config file exists."
|
|
@@ -219,6 +234,11 @@ def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path
|
|
| 219 |
|
| 220 |
def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
|
| 221 |
"""Generate video from text prompt."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
# Get checkpoint path (from upload, local, or Model Hub)
|
| 223 |
result = get_checkpoint_path(
|
| 224 |
checkpoint_file,
|
|
|
|
| 45 |
def get_checkpoint_path(checkpoint_file, default_local_path, repo_id=None, filename=None):
|
| 46 |
"""Get checkpoint path, downloading from Hub if needed."""
|
| 47 |
# If user uploaded a file, use it
|
| 48 |
+
if checkpoint_file is not None and checkpoint_file != "":
|
| 49 |
if hasattr(checkpoint_file, 'name'):
|
| 50 |
return checkpoint_file.name
|
| 51 |
+
checkpoint_str = str(checkpoint_file)
|
| 52 |
+
# If it's a file path that exists, use it
|
| 53 |
+
if os.path.exists(checkpoint_str):
|
| 54 |
+
return checkpoint_str
|
| 55 |
+
# If it's the default path but doesn't exist, continue to download
|
| 56 |
+
if checkpoint_str == default_local_path and not os.path.exists(checkpoint_str):
|
| 57 |
+
pass # Continue to download logic below
|
| 58 |
+
elif checkpoint_str != default_local_path:
|
| 59 |
+
return checkpoint_str
|
| 60 |
|
| 61 |
# Try local path first
|
| 62 |
if os.path.exists(default_local_path):
|
|
|
|
| 116 |
def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
|
| 117 |
"""Generate image from text prompt."""
|
| 118 |
# Get checkpoint path (from upload, local, or Model Hub)
|
| 119 |
+
status_msg = ""
|
| 120 |
+
|
| 121 |
+
# Handle checkpoint file (might be string from hidden Textbox)
|
| 122 |
+
if checkpoint_file == DEFAULT_IMAGE_CHECKPOINT or checkpoint_file == "" or checkpoint_file is None:
|
| 123 |
+
# Use Model Hub download
|
| 124 |
+
checkpoint_file = None
|
| 125 |
+
|
| 126 |
result = get_checkpoint_path(
|
| 127 |
checkpoint_file,
|
| 128 |
DEFAULT_IMAGE_CHECKPOINT,
|
|
|
|
| 136 |
checkpoint_path = result
|
| 137 |
|
| 138 |
# Show status
|
| 139 |
+
status_msg += f"Using checkpoint: {checkpoint_path}\n"
|
| 140 |
|
| 141 |
if not os.path.exists(checkpoint_path):
|
| 142 |
+
return None, f"Error: Checkpoint file not found at {checkpoint_path}.\n\nPlease verify:\n1. Model Hub repo exists: {IMAGE_CHECKPOINT_REPO}\n2. File name matches: starflow_3B_t2i_256x256.pth\n3. Repo is Public (not Private)\n\nAttempting to download from Model Hub..."
|
| 143 |
|
| 144 |
if not config_path or not os.path.exists(config_path):
|
| 145 |
return None, "Error: Config file not found. Please ensure config file exists."
|
|
|
|
| 234 |
|
| 235 |
def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
|
| 236 |
"""Generate video from text prompt."""
|
| 237 |
+
# Handle checkpoint file (might be string from hidden Textbox)
|
| 238 |
+
if checkpoint_file == DEFAULT_VIDEO_CHECKPOINT or checkpoint_file == "" or checkpoint_file is None:
|
| 239 |
+
# Use Model Hub download
|
| 240 |
+
checkpoint_file = None
|
| 241 |
+
|
| 242 |
# Get checkpoint path (from upload, local, or Model Hub)
|
| 243 |
result = get_checkpoint_path(
|
| 244 |
checkpoint_file,
|