leoeric commited on
Commit
3edcad7
·
1 Parent(s): d43b33a

Update app.py to support persistent checkpoints

Browse files
Files changed (1) hide show
  1. app.py +26 -14
app.py CHANGED
@@ -15,6 +15,10 @@ from pathlib import Path
15
  # Check if running on Hugging Face Spaces
16
  HF_SPACE = os.environ.get("SPACE_ID") is not None
17
 
 
 
 
 
18
  # Verify CUDA availability (will be True on HF Spaces with GPU hardware)
19
  if torch.cuda.is_available():
20
  print(f"✅ CUDA available! Device: {torch.cuda.get_device_name(0)}")
@@ -25,14 +29,18 @@ else:
25
 
26
  def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
27
  """Generate image from text prompt."""
 
28
  if checkpoint_file is None:
29
- return None, "Error: Please upload a checkpoint file."
30
-
31
- # Handle Gradio file object
32
- if hasattr(checkpoint_file, 'name'):
33
- checkpoint_path = checkpoint_file.name
34
  else:
35
- checkpoint_path = str(checkpoint_file)
 
 
 
 
36
 
37
  if not os.path.exists(checkpoint_path):
38
  return None, f"Error: Checkpoint file not found at {checkpoint_path}."
@@ -83,14 +91,18 @@ def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path
83
 
84
  def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
85
  """Generate video from text prompt."""
 
86
  if checkpoint_file is None:
87
- return None, "Error: Please upload a checkpoint file."
88
-
89
- # Handle Gradio file object
90
- if hasattr(checkpoint_file, 'name'):
91
- checkpoint_path = checkpoint_file.name
92
  else:
93
- checkpoint_path = str(checkpoint_file)
 
 
 
 
94
 
95
  if not os.path.exists(checkpoint_path):
96
  return None, f"Error: Checkpoint file not found at {checkpoint_path}."
@@ -173,7 +185,7 @@ with gr.Blocks(title="STARFlow - Text-to-Image & Video Generation") as demo:
173
  lines=3
174
  )
175
  image_checkpoint = gr.File(
176
- label="Model Checkpoint (.pth file)",
177
  file_types=[".pth"]
178
  )
179
  image_config = gr.Textbox(
@@ -210,7 +222,7 @@ with gr.Blocks(title="STARFlow - Text-to-Image & Video Generation") as demo:
210
  lines=3
211
  )
212
  video_checkpoint = gr.File(
213
- label="Model Checkpoint (.pth file)",
214
  file_types=[".pth"]
215
  )
216
  video_config = gr.Textbox(
 
15
  # Check if running on Hugging Face Spaces
16
  HF_SPACE = os.environ.get("SPACE_ID") is not None
17
 
18
+ # Default checkpoint paths (if uploaded to Space Files)
19
+ DEFAULT_IMAGE_CHECKPOINT = "ckpts/starflow_3B_t2i_256x256.pth"
20
+ DEFAULT_VIDEO_CHECKPOINT = "ckpts/starflow-v_7B_t2v_caus_480p_v3.pth"
21
+
22
  # Verify CUDA availability (will be True on HF Spaces with GPU hardware)
23
  if torch.cuda.is_available():
24
  print(f"✅ CUDA available! Device: {torch.cuda.get_device_name(0)}")
 
29
 
30
  def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
31
  """Generate image from text prompt."""
32
+ # Use uploaded file if provided, otherwise try default persistent checkpoint
33
  if checkpoint_file is None:
34
+ # Try to use pre-uploaded checkpoint
35
+ checkpoint_path = DEFAULT_IMAGE_CHECKPOINT
36
+ if not os.path.exists(checkpoint_path):
37
+ return None, f"Error: No checkpoint found. Please upload a checkpoint file or ensure '{DEFAULT_IMAGE_CHECKPOINT}' exists in Space Files."
 
38
  else:
39
+ # Handle Gradio file object (user uploaded)
40
+ if hasattr(checkpoint_file, 'name'):
41
+ checkpoint_path = checkpoint_file.name
42
+ else:
43
+ checkpoint_path = str(checkpoint_file)
44
 
45
  if not os.path.exists(checkpoint_path):
46
  return None, f"Error: Checkpoint file not found at {checkpoint_path}."
 
91
 
92
  def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
93
  """Generate video from text prompt."""
94
+ # Use uploaded file if provided, otherwise try default persistent checkpoint
95
  if checkpoint_file is None:
96
+ # Try to use pre-uploaded checkpoint
97
+ checkpoint_path = DEFAULT_VIDEO_CHECKPOINT
98
+ if not os.path.exists(checkpoint_path):
99
+ return None, f"Error: No checkpoint found. Please upload a checkpoint file or ensure '{DEFAULT_VIDEO_CHECKPOINT}' exists in Space Files."
 
100
  else:
101
+ # Handle Gradio file object (user uploaded)
102
+ if hasattr(checkpoint_file, 'name'):
103
+ checkpoint_path = checkpoint_file.name
104
+ else:
105
+ checkpoint_path = str(checkpoint_file)
106
 
107
  if not os.path.exists(checkpoint_path):
108
  return None, f"Error: Checkpoint file not found at {checkpoint_path}."
 
185
  lines=3
186
  )
187
  image_checkpoint = gr.File(
188
+ label="Model Checkpoint (.pth file) - Optional if already uploaded to Space",
189
  file_types=[".pth"]
190
  )
191
  image_config = gr.Textbox(
 
222
  lines=3
223
  )
224
  video_checkpoint = gr.File(
225
+ label="Model Checkpoint (.pth file) - Optional if already uploaded to Space",
226
  file_types=[".pth"]
227
  )
228
  video_config = gr.Textbox(