Spaces:
Running
on
Zero
Running
on
Zero
Update app.py to support persistent checkpoints
Browse files
app.py
CHANGED
|
@@ -15,6 +15,10 @@ from pathlib import Path
|
|
| 15 |
# Check if running on Hugging Face Spaces
|
| 16 |
HF_SPACE = os.environ.get("SPACE_ID") is not None
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
# Verify CUDA availability (will be True on HF Spaces with GPU hardware)
|
| 19 |
if torch.cuda.is_available():
|
| 20 |
print(f"✅ CUDA available! Device: {torch.cuda.get_device_name(0)}")
|
|
@@ -25,14 +29,18 @@ else:
|
|
| 25 |
|
| 26 |
def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
|
| 27 |
"""Generate image from text prompt."""
|
|
|
|
| 28 |
if checkpoint_file is None:
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
checkpoint_path = checkpoint_file.name
|
| 34 |
else:
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
if not os.path.exists(checkpoint_path):
|
| 38 |
return None, f"Error: Checkpoint file not found at {checkpoint_path}."
|
|
@@ -83,14 +91,18 @@ def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path
|
|
| 83 |
|
| 84 |
def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
|
| 85 |
"""Generate video from text prompt."""
|
|
|
|
| 86 |
if checkpoint_file is None:
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
checkpoint_path = checkpoint_file.name
|
| 92 |
else:
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
if not os.path.exists(checkpoint_path):
|
| 96 |
return None, f"Error: Checkpoint file not found at {checkpoint_path}."
|
|
@@ -173,7 +185,7 @@ with gr.Blocks(title="STARFlow - Text-to-Image & Video Generation") as demo:
|
|
| 173 |
lines=3
|
| 174 |
)
|
| 175 |
image_checkpoint = gr.File(
|
| 176 |
-
label="Model Checkpoint (.pth file)",
|
| 177 |
file_types=[".pth"]
|
| 178 |
)
|
| 179 |
image_config = gr.Textbox(
|
|
@@ -210,7 +222,7 @@ with gr.Blocks(title="STARFlow - Text-to-Image & Video Generation") as demo:
|
|
| 210 |
lines=3
|
| 211 |
)
|
| 212 |
video_checkpoint = gr.File(
|
| 213 |
-
label="Model Checkpoint (.pth file)",
|
| 214 |
file_types=[".pth"]
|
| 215 |
)
|
| 216 |
video_config = gr.Textbox(
|
|
|
|
| 15 |
# Check if running on Hugging Face Spaces
|
| 16 |
HF_SPACE = os.environ.get("SPACE_ID") is not None
|
| 17 |
|
| 18 |
+
# Default checkpoint paths (if uploaded to Space Files)
|
| 19 |
+
DEFAULT_IMAGE_CHECKPOINT = "ckpts/starflow_3B_t2i_256x256.pth"
|
| 20 |
+
DEFAULT_VIDEO_CHECKPOINT = "ckpts/starflow-v_7B_t2v_caus_480p_v3.pth"
|
| 21 |
+
|
| 22 |
# Verify CUDA availability (will be True on HF Spaces with GPU hardware)
|
| 23 |
if torch.cuda.is_available():
|
| 24 |
print(f"✅ CUDA available! Device: {torch.cuda.get_device_name(0)}")
|
|
|
|
| 29 |
|
| 30 |
def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
|
| 31 |
"""Generate image from text prompt."""
|
| 32 |
+
# Use uploaded file if provided, otherwise try default persistent checkpoint
|
| 33 |
if checkpoint_file is None:
|
| 34 |
+
# Try to use pre-uploaded checkpoint
|
| 35 |
+
checkpoint_path = DEFAULT_IMAGE_CHECKPOINT
|
| 36 |
+
if not os.path.exists(checkpoint_path):
|
| 37 |
+
return None, f"Error: No checkpoint found. Please upload a checkpoint file or ensure '{DEFAULT_IMAGE_CHECKPOINT}' exists in Space Files."
|
|
|
|
| 38 |
else:
|
| 39 |
+
# Handle Gradio file object (user uploaded)
|
| 40 |
+
if hasattr(checkpoint_file, 'name'):
|
| 41 |
+
checkpoint_path = checkpoint_file.name
|
| 42 |
+
else:
|
| 43 |
+
checkpoint_path = str(checkpoint_file)
|
| 44 |
|
| 45 |
if not os.path.exists(checkpoint_path):
|
| 46 |
return None, f"Error: Checkpoint file not found at {checkpoint_path}."
|
|
|
|
| 91 |
|
| 92 |
def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
|
| 93 |
"""Generate video from text prompt."""
|
| 94 |
+
# Use uploaded file if provided, otherwise try default persistent checkpoint
|
| 95 |
if checkpoint_file is None:
|
| 96 |
+
# Try to use pre-uploaded checkpoint
|
| 97 |
+
checkpoint_path = DEFAULT_VIDEO_CHECKPOINT
|
| 98 |
+
if not os.path.exists(checkpoint_path):
|
| 99 |
+
return None, f"Error: No checkpoint found. Please upload a checkpoint file or ensure '{DEFAULT_VIDEO_CHECKPOINT}' exists in Space Files."
|
|
|
|
| 100 |
else:
|
| 101 |
+
# Handle Gradio file object (user uploaded)
|
| 102 |
+
if hasattr(checkpoint_file, 'name'):
|
| 103 |
+
checkpoint_path = checkpoint_file.name
|
| 104 |
+
else:
|
| 105 |
+
checkpoint_path = str(checkpoint_file)
|
| 106 |
|
| 107 |
if not os.path.exists(checkpoint_path):
|
| 108 |
return None, f"Error: Checkpoint file not found at {checkpoint_path}."
|
|
|
|
| 185 |
lines=3
|
| 186 |
)
|
| 187 |
image_checkpoint = gr.File(
|
| 188 |
+
label="Model Checkpoint (.pth file) - Optional if already uploaded to Space",
|
| 189 |
file_types=[".pth"]
|
| 190 |
)
|
| 191 |
image_config = gr.Textbox(
|
|
|
|
| 222 |
lines=3
|
| 223 |
)
|
| 224 |
video_checkpoint = gr.File(
|
| 225 |
+
label="Model Checkpoint (.pth file) - Optional if already uploaded to Space",
|
| 226 |
file_types=[".pth"]
|
| 227 |
)
|
| 228 |
video_config = gr.Textbox(
|