leoeric commited on
Commit
447bd94
·
1 Parent(s): d37d8d5

Add Model Hub support for downloading checkpoints (solves 1GB storage limit)

Browse files
Files changed (2) hide show
  1. app.py +65 -24
  2. requirements_hf.txt +1 -0
app.py CHANGED
@@ -12,6 +12,14 @@ import subprocess
12
  import pathlib
13
  from pathlib import Path
14
 
 
 
 
 
 
 
 
 
15
  # Check if running on Hugging Face Spaces
16
  HF_SPACE = os.environ.get("SPACE_ID") is not None
17
 
@@ -19,6 +27,41 @@ HF_SPACE = os.environ.get("SPACE_ID") is not None
19
  DEFAULT_IMAGE_CHECKPOINT = "ckpts/starflow_3B_t2i_256x256.pth"
20
  DEFAULT_VIDEO_CHECKPOINT = "ckpts/starflow-v_7B_t2v_caus_480p_v3.pth"
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # Verify CUDA availability (will be True on HF Spaces with GPU hardware)
23
  if torch.cuda.is_available():
24
  print(f"✅ CUDA available! Device: {torch.cuda.get_device_name(0)}")
@@ -29,19 +72,18 @@ else:
29
 
30
  def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
31
  """Generate image from text prompt."""
32
- # Use uploaded file if provided, otherwise try default persistent checkpoint
33
- if checkpoint_file is None:
34
- # Try to use pre-uploaded checkpoint
35
- checkpoint_path = DEFAULT_IMAGE_CHECKPOINT
36
- if not os.path.exists(checkpoint_path):
37
- return None, f"Error: No checkpoint found. Please upload a checkpoint file or ensure '{DEFAULT_IMAGE_CHECKPOINT}' exists in Space Files."
38
- else:
39
- # Handle Gradio file object (user uploaded)
40
- if hasattr(checkpoint_file, 'name'):
41
- checkpoint_path = checkpoint_file.name
42
- else:
43
- checkpoint_path = str(checkpoint_file)
44
 
 
45
  if not os.path.exists(checkpoint_path):
46
  return None, f"Error: Checkpoint file not found at {checkpoint_path}."
47
 
@@ -91,19 +133,18 @@ def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path
91
 
92
  def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
93
  """Generate video from text prompt."""
94
- # Use uploaded file if provided, otherwise try default persistent checkpoint
95
- if checkpoint_file is None:
96
- # Try to use pre-uploaded checkpoint
97
- checkpoint_path = DEFAULT_VIDEO_CHECKPOINT
98
- if not os.path.exists(checkpoint_path):
99
- return None, f"Error: No checkpoint found. Please upload a checkpoint file or ensure '{DEFAULT_VIDEO_CHECKPOINT}' exists in Space Files."
100
- else:
101
- # Handle Gradio file object (user uploaded)
102
- if hasattr(checkpoint_file, 'name'):
103
- checkpoint_path = checkpoint_file.name
104
- else:
105
- checkpoint_path = str(checkpoint_file)
106
 
 
107
  if not os.path.exists(checkpoint_path):
108
  return None, f"Error: Checkpoint file not found at {checkpoint_path}."
109
 
 
12
  import pathlib
13
  from pathlib import Path
14
 
15
+ # Try to import huggingface_hub for downloading checkpoints
16
+ try:
17
+ from huggingface_hub import hf_hub_download
18
+ HF_HUB_AVAILABLE = True
19
+ except ImportError:
20
+ HF_HUB_AVAILABLE = False
21
+ print("⚠️ huggingface_hub not available. Install with: pip install huggingface_hub")
22
+
23
  # Check if running on Hugging Face Spaces
24
  HF_SPACE = os.environ.get("SPACE_ID") is not None
25
 
 
27
  DEFAULT_IMAGE_CHECKPOINT = "ckpts/starflow_3B_t2i_256x256.pth"
28
  DEFAULT_VIDEO_CHECKPOINT = "ckpts/starflow-v_7B_t2v_caus_480p_v3.pth"
29
 
30
+ # Model Hub repositories (if using Hugging Face Model Hub)
31
+ # Set these to your Model Hub repo IDs if you upload checkpoints there
32
+ IMAGE_CHECKPOINT_REPO = None # e.g., "GlobalStudio/starflow-3b-checkpoint"
33
+ VIDEO_CHECKPOINT_REPO = None # e.g., "GlobalStudio/starflow-v-7b-checkpoint"
34
+
35
+ def get_checkpoint_path(checkpoint_file, default_local_path, repo_id=None, filename=None):
36
+ """Get checkpoint path, downloading from Hub if needed."""
37
+ # If user uploaded a file, use it
38
+ if checkpoint_file is not None:
39
+ if hasattr(checkpoint_file, 'name'):
40
+ return checkpoint_file.name
41
+ return str(checkpoint_file)
42
+
43
+ # Try local path first
44
+ if os.path.exists(default_local_path):
45
+ return default_local_path
46
+
47
+ # Try downloading from Model Hub if configured
48
+ if repo_id and filename and HF_HUB_AVAILABLE:
49
+ try:
50
+ print(f"📥 Downloading checkpoint from {repo_id}...")
51
+ checkpoint_path = hf_hub_download(
52
+ repo_id=repo_id,
53
+ filename=filename,
54
+ cache_dir="/tmp/checkpoints",
55
+ local_files_only=False
56
+ )
57
+ print(f"✅ Checkpoint downloaded to: {checkpoint_path}")
58
+ return checkpoint_path
59
+ except Exception as e:
60
+ return None, f"Error downloading checkpoint: {str(e)}"
61
+
62
+ # No checkpoint found
63
+ return None, f"Checkpoint not found. Please upload a checkpoint file or configure Model Hub repository."
64
+
65
  # Verify CUDA availability (will be True on HF Spaces with GPU hardware)
66
  if torch.cuda.is_available():
67
  print(f"✅ CUDA available! Device: {torch.cuda.get_device_name(0)}")
 
72
 
73
  def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
74
  """Generate image from text prompt."""
75
+ # Get checkpoint path (from upload, local, or Model Hub)
76
+ result = get_checkpoint_path(
77
+ checkpoint_file,
78
+ DEFAULT_IMAGE_CHECKPOINT,
79
+ IMAGE_CHECKPOINT_REPO,
80
+ "starflow_3B_t2i_256x256.pth"
81
+ )
82
+
83
+ if isinstance(result, tuple) and result[0] is None:
84
+ return None, result[1] # Error message
 
 
85
 
86
+ checkpoint_path = result
87
  if not os.path.exists(checkpoint_path):
88
  return None, f"Error: Checkpoint file not found at {checkpoint_path}."
89
 
 
133
 
134
  def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
135
  """Generate video from text prompt."""
136
+ # Get checkpoint path (from upload, local, or Model Hub)
137
+ result = get_checkpoint_path(
138
+ checkpoint_file,
139
+ DEFAULT_VIDEO_CHECKPOINT,
140
+ VIDEO_CHECKPOINT_REPO,
141
+ "starflow-v_7B_t2v_caus_480p_v3.pth"
142
+ )
143
+
144
+ if isinstance(result, tuple) and result[0] is None:
145
+ return None, result[1] # Error message
 
 
146
 
147
+ checkpoint_path = result
148
  if not os.path.exists(checkpoint_path):
149
  return None, f"Error: Checkpoint file not found at {checkpoint_path}."
150
 
requirements_hf.txt CHANGED
@@ -9,6 +9,7 @@ torchvision>=0.15.0
9
  # Core dependencies
10
  transformers>=4.30.0
11
  accelerate>=0.20.0
 
12
  torchinfo
13
  einops
14
  scipy
 
9
  # Core dependencies
10
  transformers>=4.30.0
11
  accelerate>=0.20.0
12
+ huggingface_hub>=0.20.0
13
  torchinfo
14
  einops
15
  scipy