Spaces:
Running
on
Zero
Running
on
Zero
Add Model Hub support for downloading checkpoints (solves 1GB storage limit)
Browse files- app.py +65 -24
- requirements_hf.txt +1 -0
app.py
CHANGED
|
@@ -12,6 +12,14 @@ import subprocess
|
|
| 12 |
import pathlib
|
| 13 |
from pathlib import Path
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
# Check if running on Hugging Face Spaces
|
| 16 |
HF_SPACE = os.environ.get("SPACE_ID") is not None
|
| 17 |
|
|
@@ -19,6 +27,41 @@ HF_SPACE = os.environ.get("SPACE_ID") is not None
|
|
| 19 |
DEFAULT_IMAGE_CHECKPOINT = "ckpts/starflow_3B_t2i_256x256.pth"
|
| 20 |
DEFAULT_VIDEO_CHECKPOINT = "ckpts/starflow-v_7B_t2v_caus_480p_v3.pth"
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
# Verify CUDA availability (will be True on HF Spaces with GPU hardware)
|
| 23 |
if torch.cuda.is_available():
|
| 24 |
print(f"✅ CUDA available! Device: {torch.cuda.get_device_name(0)}")
|
|
@@ -29,19 +72,18 @@ else:
|
|
| 29 |
|
| 30 |
def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
|
| 31 |
"""Generate image from text prompt."""
|
| 32 |
-
#
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
else:
|
| 43 |
-
checkpoint_path = str(checkpoint_file)
|
| 44 |
|
|
|
|
| 45 |
if not os.path.exists(checkpoint_path):
|
| 46 |
return None, f"Error: Checkpoint file not found at {checkpoint_path}."
|
| 47 |
|
|
@@ -91,19 +133,18 @@ def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path
|
|
| 91 |
|
| 92 |
def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
|
| 93 |
"""Generate video from text prompt."""
|
| 94 |
-
#
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
else:
|
| 105 |
-
checkpoint_path = str(checkpoint_file)
|
| 106 |
|
|
|
|
| 107 |
if not os.path.exists(checkpoint_path):
|
| 108 |
return None, f"Error: Checkpoint file not found at {checkpoint_path}."
|
| 109 |
|
|
|
|
| 12 |
import pathlib
|
| 13 |
from pathlib import Path
|
| 14 |
|
| 15 |
+
# Try to import huggingface_hub for downloading checkpoints
|
| 16 |
+
try:
|
| 17 |
+
from huggingface_hub import hf_hub_download
|
| 18 |
+
HF_HUB_AVAILABLE = True
|
| 19 |
+
except ImportError:
|
| 20 |
+
HF_HUB_AVAILABLE = False
|
| 21 |
+
print("⚠️ huggingface_hub not available. Install with: pip install huggingface_hub")
|
| 22 |
+
|
| 23 |
# Check if running on Hugging Face Spaces
|
| 24 |
HF_SPACE = os.environ.get("SPACE_ID") is not None
|
| 25 |
|
|
|
|
| 27 |
DEFAULT_IMAGE_CHECKPOINT = "ckpts/starflow_3B_t2i_256x256.pth"
|
| 28 |
DEFAULT_VIDEO_CHECKPOINT = "ckpts/starflow-v_7B_t2v_caus_480p_v3.pth"
|
| 29 |
|
| 30 |
+
# Model Hub repositories (if using Hugging Face Model Hub)
|
| 31 |
+
# Set these to your Model Hub repo IDs if you upload checkpoints there
|
| 32 |
+
IMAGE_CHECKPOINT_REPO = None # e.g., "GlobalStudio/starflow-3b-checkpoint"
|
| 33 |
+
VIDEO_CHECKPOINT_REPO = None # e.g., "GlobalStudio/starflow-v-7b-checkpoint"
|
| 34 |
+
|
| 35 |
+
def get_checkpoint_path(checkpoint_file, default_local_path, repo_id=None, filename=None):
|
| 36 |
+
"""Get checkpoint path, downloading from Hub if needed."""
|
| 37 |
+
# If user uploaded a file, use it
|
| 38 |
+
if checkpoint_file is not None:
|
| 39 |
+
if hasattr(checkpoint_file, 'name'):
|
| 40 |
+
return checkpoint_file.name
|
| 41 |
+
return str(checkpoint_file)
|
| 42 |
+
|
| 43 |
+
# Try local path first
|
| 44 |
+
if os.path.exists(default_local_path):
|
| 45 |
+
return default_local_path
|
| 46 |
+
|
| 47 |
+
# Try downloading from Model Hub if configured
|
| 48 |
+
if repo_id and filename and HF_HUB_AVAILABLE:
|
| 49 |
+
try:
|
| 50 |
+
print(f"📥 Downloading checkpoint from {repo_id}...")
|
| 51 |
+
checkpoint_path = hf_hub_download(
|
| 52 |
+
repo_id=repo_id,
|
| 53 |
+
filename=filename,
|
| 54 |
+
cache_dir="/tmp/checkpoints",
|
| 55 |
+
local_files_only=False
|
| 56 |
+
)
|
| 57 |
+
print(f"✅ Checkpoint downloaded to: {checkpoint_path}")
|
| 58 |
+
return checkpoint_path
|
| 59 |
+
except Exception as e:
|
| 60 |
+
return None, f"Error downloading checkpoint: {str(e)}"
|
| 61 |
+
|
| 62 |
+
# No checkpoint found
|
| 63 |
+
return None, f"Checkpoint not found. Please upload a checkpoint file or configure Model Hub repository."
|
| 64 |
+
|
| 65 |
# Verify CUDA availability (will be True on HF Spaces with GPU hardware)
|
| 66 |
if torch.cuda.is_available():
|
| 67 |
print(f"✅ CUDA available! Device: {torch.cuda.get_device_name(0)}")
|
|
|
|
| 72 |
|
| 73 |
def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
|
| 74 |
"""Generate image from text prompt."""
|
| 75 |
+
# Get checkpoint path (from upload, local, or Model Hub)
|
| 76 |
+
result = get_checkpoint_path(
|
| 77 |
+
checkpoint_file,
|
| 78 |
+
DEFAULT_IMAGE_CHECKPOINT,
|
| 79 |
+
IMAGE_CHECKPOINT_REPO,
|
| 80 |
+
"starflow_3B_t2i_256x256.pth"
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
if isinstance(result, tuple) and result[0] is None:
|
| 84 |
+
return None, result[1] # Error message
|
|
|
|
|
|
|
| 85 |
|
| 86 |
+
checkpoint_path = result
|
| 87 |
if not os.path.exists(checkpoint_path):
|
| 88 |
return None, f"Error: Checkpoint file not found at {checkpoint_path}."
|
| 89 |
|
|
|
|
| 133 |
|
| 134 |
def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
|
| 135 |
"""Generate video from text prompt."""
|
| 136 |
+
# Get checkpoint path (from upload, local, or Model Hub)
|
| 137 |
+
result = get_checkpoint_path(
|
| 138 |
+
checkpoint_file,
|
| 139 |
+
DEFAULT_VIDEO_CHECKPOINT,
|
| 140 |
+
VIDEO_CHECKPOINT_REPO,
|
| 141 |
+
"starflow-v_7B_t2v_caus_480p_v3.pth"
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
if isinstance(result, tuple) and result[0] is None:
|
| 145 |
+
return None, result[1] # Error message
|
|
|
|
|
|
|
| 146 |
|
| 147 |
+
checkpoint_path = result
|
| 148 |
if not os.path.exists(checkpoint_path):
|
| 149 |
return None, f"Error: Checkpoint file not found at {checkpoint_path}."
|
| 150 |
|
requirements_hf.txt
CHANGED
|
@@ -9,6 +9,7 @@ torchvision>=0.15.0
|
|
| 9 |
# Core dependencies
|
| 10 |
transformers>=4.30.0
|
| 11 |
accelerate>=0.20.0
|
|
|
|
| 12 |
torchinfo
|
| 13 |
einops
|
| 14 |
scipy
|
|
|
|
| 9 |
# Core dependencies
|
| 10 |
transformers>=4.30.0
|
| 11 |
accelerate>=0.20.0
|
| 12 |
+
huggingface_hub>=0.20.0
|
| 13 |
torchinfo
|
| 14 |
einops
|
| 15 |
scipy
|