leoeric commited on
Commit
86fdec4
·
1 Parent(s): 06f3d78

Fix checkpoint download logic to properly handle Model Hub downloads

Browse files
Files changed (1) hide show
  1. app.py +24 -4
app.py CHANGED
@@ -45,10 +45,18 @@ VIDEO_CHECKPOINT_REPO = "GlobalStudio/starflow-v-7b-checkpoint" # Update this a
45
  def get_checkpoint_path(checkpoint_file, default_local_path, repo_id=None, filename=None):
46
  """Get checkpoint path, downloading from Hub if needed."""
47
  # If user uploaded a file, use it
48
- if checkpoint_file is not None:
49
  if hasattr(checkpoint_file, 'name'):
50
  return checkpoint_file.name
51
- return str(checkpoint_file)
 
 
 
 
 
 
 
 
52
 
53
  # Try local path first
54
  if os.path.exists(default_local_path):
@@ -108,6 +116,13 @@ else:
108
  def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
109
  """Generate image from text prompt."""
110
  # Get checkpoint path (from upload, local, or Model Hub)
 
 
 
 
 
 
 
111
  result = get_checkpoint_path(
112
  checkpoint_file,
113
  DEFAULT_IMAGE_CHECKPOINT,
@@ -121,10 +136,10 @@ def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path
121
  checkpoint_path = result
122
 
123
  # Show status
124
- status_msg = f"Using checkpoint: {checkpoint_path}\n"
125
 
126
  if not os.path.exists(checkpoint_path):
127
- return None, f"Error: Checkpoint file not found at {checkpoint_path}.\n\nPlease verify:\n1. Model Hub repo exists: {IMAGE_CHECKPOINT_REPO}\n2. File name matches: starflow_3B_t2i_256x256.pth\n3. Repo is Public (not Private)"
128
 
129
  if not config_path or not os.path.exists(config_path):
130
  return None, "Error: Config file not found. Please ensure config file exists."
@@ -219,6 +234,11 @@ def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path
219
 
220
  def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
221
  """Generate video from text prompt."""
 
 
 
 
 
222
  # Get checkpoint path (from upload, local, or Model Hub)
223
  result = get_checkpoint_path(
224
  checkpoint_file,
 
45
  def get_checkpoint_path(checkpoint_file, default_local_path, repo_id=None, filename=None):
46
  """Get checkpoint path, downloading from Hub if needed."""
47
  # If user uploaded a file, use it
48
+ if checkpoint_file is not None and checkpoint_file != "":
49
  if hasattr(checkpoint_file, 'name'):
50
  return checkpoint_file.name
51
+ checkpoint_str = str(checkpoint_file)
52
+ # If it's a file path that exists, use it
53
+ if os.path.exists(checkpoint_str):
54
+ return checkpoint_str
55
+ # If it's the default path but doesn't exist, continue to download
56
+ if checkpoint_str == default_local_path and not os.path.exists(checkpoint_str):
57
+ pass # Continue to download logic below
58
+ elif checkpoint_str != default_local_path:
59
+ return checkpoint_str
60
 
61
  # Try local path first
62
  if os.path.exists(default_local_path):
 
116
  def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
117
  """Generate image from text prompt."""
118
  # Get checkpoint path (from upload, local, or Model Hub)
119
+ status_msg = ""
120
+
121
+ # Handle checkpoint file (might be string from hidden Textbox)
122
+ if checkpoint_file == DEFAULT_IMAGE_CHECKPOINT or checkpoint_file == "" or checkpoint_file is None:
123
+ # Use Model Hub download
124
+ checkpoint_file = None
125
+
126
  result = get_checkpoint_path(
127
  checkpoint_file,
128
  DEFAULT_IMAGE_CHECKPOINT,
 
136
  checkpoint_path = result
137
 
138
  # Show status
139
+ status_msg += f"Using checkpoint: {checkpoint_path}\n"
140
 
141
  if not os.path.exists(checkpoint_path):
142
+ return None, f"Error: Checkpoint file not found at {checkpoint_path}.\n\nPlease verify:\n1. Model Hub repo exists: {IMAGE_CHECKPOINT_REPO}\n2. File name matches: starflow_3B_t2i_256x256.pth\n3. Repo is Public (not Private)\n\nAttempting to download from Model Hub..."
143
 
144
  if not config_path or not os.path.exists(config_path):
145
  return None, "Error: Config file not found. Please ensure config file exists."
 
234
 
235
  def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
236
  """Generate video from text prompt."""
237
+ # Handle checkpoint file (might be string from hidden Textbox)
238
+ if checkpoint_file == DEFAULT_VIDEO_CHECKPOINT or checkpoint_file == "" or checkpoint_file is None:
239
+ # Use Model Hub download
240
+ checkpoint_file = None
241
+
242
  # Get checkpoint path (from upload, local, or Model Hub)
243
  result = get_checkpoint_path(
244
  checkpoint_file,