leoeric committed on
Commit
3216b1d
·
1 Parent(s): fcc18a6

Add download timeout and better progress messages

Browse files
Files changed (1) hide show
  1. app.py +13 -3
app.py CHANGED
@@ -56,13 +56,22 @@ def get_checkpoint_path(checkpoint_file, default_local_path, repo_id=None, filen
56
  cache_dir = "/workspace/checkpoints" if os.path.exists("/workspace") else "/tmp/checkpoints"
57
  os.makedirs(cache_dir, exist_ok=True)
58
 
 
 
 
 
 
59
  checkpoint_path = hf_hub_download(
60
  repo_id=repo_id,
61
  filename=filename,
62
  cache_dir=cache_dir,
63
  local_files_only=False,
64
- resume_download=True # Resume if interrupted
 
65
  )
 
 
 
66
  print(f"✅ Checkpoint downloaded to: {checkpoint_path}")
67
  return checkpoint_path
68
  except Exception as e:
@@ -132,9 +141,10 @@ def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path
132
  ]
133
 
134
  status_msg += "Running generation...\n"
 
135
 
136
- # Run with timeout (30 minutes max)
137
- result = subprocess.run(cmd, capture_output=True, text=True, cwd=os.getcwd(), timeout=1800)
138
 
139
  if result.returncode != 0:
140
  error_msg = f"Error during generation:\n{result.stderr}\n\nStdout:\n{result.stdout}"
 
56
  cache_dir = "/workspace/checkpoints" if os.path.exists("/workspace") else "/tmp/checkpoints"
57
  os.makedirs(cache_dir, exist_ok=True)
58
 
59
+ # Add timeout and better error handling
60
+ import time
61
+ start_time = time.time()
62
+ print(f"Starting download from {repo_id}...")
63
+
64
  checkpoint_path = hf_hub_download(
65
  repo_id=repo_id,
66
  filename=filename,
67
  cache_dir=cache_dir,
68
  local_files_only=False,
69
+ resume_download=True, # Resume if interrupted
70
+ timeout=600 # 10 minute timeout per request
71
  )
72
+
73
+ elapsed = time.time() - start_time
74
+ print(f"Download completed in {elapsed:.1f} seconds")
75
  print(f"✅ Checkpoint downloaded to: {checkpoint_path}")
76
  return checkpoint_path
77
  except Exception as e:
 
141
  ]
142
 
143
  status_msg += "Running generation...\n"
144
+ status_msg += "Note: First run includes checkpoint download (~10-20 min) and model loading (~2-5 min).\n"
145
 
146
+ # Run with timeout (45 minutes max - allows for download + generation)
147
+ result = subprocess.run(cmd, capture_output=True, text=True, cwd=os.getcwd(), timeout=2700)
148
 
149
  if result.returncode != 0:
150
  error_msg = f"Error during generation:\n{result.stderr}\n\nStdout:\n{result.stdout}"