leoeric committed
Commit 34395b9 · 1 Parent(s): 5616201

Fix GPU abort error: improve ZeroGPU decorator detection and GPU context handling


- Fix @spaces.GPU decorator application for proper ZeroGPU detection (see the sketch after this list)
- Preserve CUDA_VISIBLE_DEVICES in subprocess calls
- Add GPU availability checks before generation
- Enhance error handling for GPU abort scenarios
- Add GPU status logging for debugging
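
The first bullet is the heart of the commit, so here is the pattern in isolation. ZeroGPU looks for spaces.GPU-wrapped functions when the Space process starts, so the function must exist at module level and be wrapped by reassignment, instead of being defined twice inside an if/else as the old code did. A minimal sketch, assuming only that the spaces package may or may not be importable; generate is a hypothetical stand-in for the app's generate_image/generate_video:

# Minimal, standalone sketch of the module-level decorator pattern.
try:
    import spaces  # provided on Hugging Face Spaces; absent locally
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False

def generate(prompt):
    """Hypothetical stand-in for the app's real entry points."""
    return f"generated: {prompt}"

# Wrap by reassignment at module level so ZeroGPU can detect the function
# at startup; without the spaces package the plain function is used as-is.
if SPACES_AVAILABLE and hasattr(spaces, 'GPU'):
    generate = spaces.GPU(generate)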

Files changed (4)
  1. app.py +87 -20
  2. dataset.py +42 -5
  3. sample.py +13 -6
  4. utils/training.py +3 -1
app.py CHANGED
@@ -76,8 +76,14 @@ def get_checkpoint_path(checkpoint_file, default_local_path, repo_id=None, filen
     # Try downloading from Model Hub if configured
     if repo_id and filename and HF_HUB_AVAILABLE:
         try:
-            # Use /workspace if available (persistent), otherwise /tmp
-            cache_dir = "/workspace/checkpoints" if os.path.exists("/workspace") else "/tmp/checkpoints"
+            # Use /workspace if available (HF Spaces persistent), otherwise use local cache
+            if os.path.exists("/workspace"):
+                cache_dir = "/workspace/checkpoints"
+            elif os.path.exists("/tmp"):
+                cache_dir = "/tmp/checkpoints"
+            else:
+                # Local development: use project directory or user cache
+                cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "starflow")
             os.makedirs(cache_dir, exist_ok=True)
 
             # Check if already downloaded
@@ -144,15 +150,15 @@ else:
     print(f" PyTorch Version: {torch.__version__}")
 
 # Apply @spaces.GPU decorator if available (required for ZeroGPU)
+# IMPORTANT: Decorator must be applied at module level for ZeroGPU to detect it at startup
+def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
+    """Generate image from text prompt."""
+    return _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path)
+
+# Apply decorator if spaces module is available (ZeroGPU detection happens at import time)
 if SPACES_AVAILABLE and hasattr(spaces, 'GPU'):
-    @spaces.GPU
-    def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
-        """Generate image from text prompt."""
-        return _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path)
-else:
-    def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
-        """Generate image from text prompt."""
-        return _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path)
+    generate_image = spaces.GPU(generate_image)
+    print("✅ ZeroGPU decorator applied to generate_image")
 
 def _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
     """Generate image from text prompt (implementation)."""
@@ -222,13 +228,27 @@ def _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, confi
     log_file = output_dir / "generation.log"
     status_msg += f"📋 Logs will be saved to: {log_file}\n"
 
+    # Ensure GPU environment variables are passed to subprocess
+    env = os.environ.copy()
+    # Preserve CUDA_VISIBLE_DEVICES if set (important for ZeroGPU)
+    if 'CUDA_VISIBLE_DEVICES' in env:
+        print(f"✅ CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']}")
+
+    # Verify GPU is available before starting
+    if torch.cuda.is_available():
+        status_msg += f"✅ GPU available: {torch.cuda.get_device_name(0)}\n"
+        status_msg += f" GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB\n"
+    else:
+        status_msg += "⚠️ Warning: CUDA not available, will use CPU (very slow)\n"
+
     # Run with timeout (45 minutes max - allows for download + generation)
     # Capture output and write to log file
     result = subprocess.run(
         cmd,
         capture_output=True,
         text=True,
-        cwd=os.getcwd(),
+        cwd=os.getcwd(),
+        env=env,  # Pass environment variables (including CUDA_VISIBLE_DEVICES)
         timeout=2700
     )
 
@@ -251,6 +271,21 @@ def _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, confi
 
     if result.returncode != 0:
         error_msg = f"❌ Error during generation (return code: {result.returncode})\n\n"
+
+        # Check for GPU abort or CUDA errors
+        error_output = (result.stderr + result.stdout).lower()
+        if "gpu aborted" in error_output or "cuda" in error_output or "out of memory" in error_output:
+            error_msg += "⚠️ GPU ERROR DETECTED\n\n"
+            error_msg += "Possible causes:\n"
+            error_msg += "1. GPU timeout (ZeroGPU may have a 5-10 min limit)\n"
+            error_msg += "2. CUDA out of memory (model too large for GPU)\n"
+            error_msg += "3. GPU allocation failed (ZeroGPU not detected)\n\n"
+            error_msg += "Solutions:\n"
+            error_msg += "- Try again (GPU may have been released)\n"
+            error_msg += "- Check Space logs for detailed error\n"
+            error_msg += "- Ensure @spaces.GPU decorator is applied\n"
+            error_msg += "- Consider using paid GPU tier for longer runs\n\n"
+
         error_msg += f"=== STDERR ===\n{result.stderr}\n\n"
         error_msg += f"=== STDOUT ===\n{result.stdout}\n\n"
         if log_content:
@@ -323,15 +358,15 @@ def _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, confi
 
 
 # Apply @spaces.GPU decorator if available (required for ZeroGPU)
+# IMPORTANT: Decorator must be applied at module level for ZeroGPU to detect it at startup
+def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
+    """Generate video from text prompt."""
+    return _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image)
+
+# Apply decorator if spaces module is available (ZeroGPU detection happens at import time)
 if SPACES_AVAILABLE and hasattr(spaces, 'GPU'):
-    @spaces.GPU
-    def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
-        """Generate video from text prompt."""
-        return _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image)
-else:
-    def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
-        """Generate video from text prompt."""
-        return _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image)
+    generate_video = spaces.GPU(generate_video)
+    print("✅ ZeroGPU decorator applied to generate_video")
 
 def _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
     """Generate video from text prompt (implementation)."""
@@ -396,10 +431,42 @@ def _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpo
     else:
         cmd.extend(["--input_image", "none"])
 
-    result = subprocess.run(cmd, capture_output=True, text=True, cwd=os.getcwd())
+    # Ensure GPU environment variables are passed to subprocess
+    env = os.environ.copy()
+    # Preserve CUDA_VISIBLE_DEVICES if set (important for ZeroGPU)
+    if 'CUDA_VISIBLE_DEVICES' in env:
+        print(f"✅ CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']}")
+
+    # Verify GPU is available before starting
+    if torch.cuda.is_available():
+        print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
+
+    result = subprocess.run(
+        cmd,
+        capture_output=True,
+        text=True,
+        cwd=os.getcwd(),
+        env=env,  # Pass environment variables (including CUDA_VISIBLE_DEVICES)
+        timeout=3600  # 60 minutes for video generation
+    )
 
     if result.returncode != 0:
         error_msg = f"❌ Error during video generation (return code: {result.returncode})\n\n"
+
+        # Check for GPU abort or CUDA errors
+        error_output = (result.stderr + result.stdout).lower()
+        if "gpu aborted" in error_output or "cuda" in error_output or "out of memory" in error_output:
+            error_msg += "⚠️ GPU ERROR DETECTED\n\n"
+            error_msg += "Possible causes:\n"
+            error_msg += "1. GPU timeout (ZeroGPU may have a 5-10 min limit)\n"
+            error_msg += "2. CUDA out of memory (video generation needs more GPU memory)\n"
+            error_msg += "3. GPU allocation failed (ZeroGPU not detected)\n\n"
+            error_msg += "Solutions:\n"
+            error_msg += "- Try again (GPU may have been released)\n"
+            error_msg += "- Check Space logs for detailed error\n"
+            error_msg += "- Ensure @spaces.GPU decorator is applied\n"
+            error_msg += "- Consider using paid GPU tier for longer runs\n\n"
+
         error_msg += f"=== STDERR ===\n{result.stderr}\n\n"
         error_msg += f"=== STDOUT ===\n{result.stdout}\n"
         return None, error_msg
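
For reference, the subprocess pattern above in isolation: subprocess.run inherits the parent environment by default, but copying and passing env explicitly makes the CUDA_VISIBLE_DEVICES handoff from ZeroGPU to the child process visible and auditable. A minimal sketch; nvidia-smi is only an illustrative command, not the app's real cmd:

import os
import subprocess

# Copy the parent environment so the child sees the same GPU visibility
# that ZeroGPU configured for this process.
env = os.environ.copy()
if 'CUDA_VISIBLE_DEVICES' in env:
    print(f"CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']}")

result = subprocess.run(
    ["nvidia-smi"],  # stand-in for the generation command
    capture_output=True,
    text=True,
    env=env,         # explicit environment, including CUDA_VISIBLE_DEVICES
    timeout=60,
)
print(result.stdout or result.stderr)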
dataset.py CHANGED
@@ -24,15 +24,41 @@ import gc
 import threading
 import psutil
 import tempfile
-import decord
-from decord import VideoReader
+# Optional import for video processing (not available on macOS ARM)
+try:
+    import decord
+    from decord import VideoReader
+    DECORD_AVAILABLE = True
+except ImportError:
+    DECORD_AVAILABLE = False
+    print("⚠️ decord not available. Video processing will be disabled.")
 import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor, TimeoutError
 from misc import print, xprint
 from misc.condition_utils import get_camera_condition, get_point_condition, get_wind_condition
 
-# Initialize multiprocessing manager
-manager = torch.multiprocessing.Manager()
+# Lazy initialization of multiprocessing manager (only when needed, not at import time)
+# This avoids issues on macOS which uses 'spawn' instead of 'fork'
+_manager = None
+
+def get_manager():
+    """Get or create the multiprocessing manager lazily."""
+    global _manager
+    if _manager is None:
+        try:
+            # Only create manager when actually needed (not at import time)
+            # This avoids RuntimeError on macOS with spawn method
+            _manager = torch.multiprocessing.Manager()
+        except (RuntimeError, EOFError) as e:
+            # If manager creation fails (e.g., on macOS with spawn), return None
+            # The code already handles None manager gracefully
+            print(f"⚠️ Could not create multiprocessing manager: {e}")
+            print(" Continuing without multiprocessing manager (may affect some features)")
+            _manager = False  # Use False to indicate attempted but failed
+    return _manager if _manager is not False else None
+
+# For backward compatibility, but will be None until get_manager() is called
+manager = None
 
 # ==== helpers ==== #
 
@@ -91,6 +117,8 @@ def sample_clip(
     num_frames: int = 8,
     out_fps: Optional[float] = None,  # ← pass an fps here
 ):
+    if not DECORD_AVAILABLE:
+        raise ImportError("decord is required for video processing but is not available. Install with: pip install decord (Note: not available on macOS ARM)")
     vr = VideoReader(video_path)
     src_fps = vr.get_avg_fps()  # native fps
     total = len(vr)
@@ -353,11 +381,20 @@ class ImageTarDataset(Dataset):
 class OnlineImageTarDataset(ImageTarDataset):
     max_retry_n = 20
     max_read = 4096
-    tar_keys_lock = manager.Lock() if manager is not None else None
+    # tar_keys_lock will be initialized in __init__ to avoid import-time issues
 
     def __init__(self, dataset_tsv, image_size, batch_size=None, **kwargs):
         super().__init__(dataset_tsv, image_size, **kwargs)
 
+        # Initialize manager lazily (only when this class is instantiated)
+        manager = get_manager()
+        # Use threading.Lock as fallback if multiprocessing manager unavailable
+        if manager is not None:
+            self.tar_keys_lock = manager.Lock()
+        else:
+            # Fallback to threading lock for single-process use
+            self.tar_keys_lock = threading.Lock()
+
         self.tar_lists = defaultdict(lambda: [])
         self.tar_image_buckets = defaultdict(lambda: defaultdict(lambda: 0))
         for i, line in enumerate(self.all_lines):
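
A condensed, standalone sketch of the lazy-manager pattern introduced above: the manager is created on first use rather than at import time (which can raise under macOS's 'spawn' start method), and callers fall back to a threading.Lock when no manager is available:

import threading
import torch

_manager = None

def get_manager():
    """Create torch.multiprocessing's Manager on first use, not at import."""
    global _manager
    if _manager is None:
        try:
            _manager = torch.multiprocessing.Manager()
        except (RuntimeError, EOFError):
            _manager = False  # attempted and failed; remember that
    return _manager if _manager is not False else None

# Usage: prefer a cross-process lock; fall back to a thread lock, which is
# sufficient for single-process data loading.
mgr = get_manager()
tar_keys_lock = mgr.Lock() if mgr is not None else threading.Lock()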
sample.py CHANGED
@@ -60,7 +60,9 @@ def setup_model_and_components(args: argparse.Namespace) -> Tuple[torch.nn.Modul
     dist = utils.Distributed()
 
     # If not running with torchrun, initialize single-process group
-    if not dist.distributed and torch.cuda.is_available():
+    # This is needed because the model code uses torch.distributed.all_reduce
+    # Works for both CUDA and CPU modes
+    if not dist.distributed:
         import os
         # Initialize single-process process group for model compatibility
         if not torch.distributed.is_initialized():
@@ -69,13 +71,15 @@ def setup_model_and_components(args: argparse.Namespace) -> Tuple[torch.nn.Modul
             os.environ['RANK'] = '0'
             os.environ['LOCAL_RANK'] = '0'
             os.environ['WORLD_SIZE'] = '1'
+            # Use 'nccl' for CUDA, 'gloo' for CPU
+            backend = 'nccl' if torch.cuda.is_available() else 'gloo'
             torch.distributed.init_process_group(
-                backend='nccl' if torch.cuda.is_available() else 'gloo',
+                backend=backend,
                 init_method='env://',
                 world_size=1,
                 rank=0,
             )
-            print("✅ Initialized single-process distributed group for model compatibility")
+            print(f"✅ Initialized single-process distributed group (backend: {backend}) for model compatibility")
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -268,11 +272,13 @@ def main(args: argparse.Namespace) -> None:
 
     # Start sampling
    print(f'Starting sampling with global batch size {args.sample_batch_size}x{dist.world_size} GPUs')
-    torch.cuda.synchronize()
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
     start_time = time.time()
 
     with torch.no_grad():
-        with torch.autocast(device_type='cuda', dtype=torch.float32):
+        device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
+        with torch.autocast(device_type=device_type, dtype=torch.float32):
             for i in tqdm.tqdm(range(int(np.ceil(num_samples / (args.sample_batch_size * dist.world_size))))):
                 # Determine aspect ratio and image shape
                 x_aspect = args.aspect_ratio if args.mix_aspect else None
@@ -367,7 +373,8 @@ def main(args: argparse.Namespace) -> None:
         )
 
     # Print timing statistics
-    torch.cuda.synchronize()
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
     elapsed_time = time.time() - start_time
     print(f'{model_name} cfg {args.cfg:.2f}, bsz={args.sample_batch_size}x{dist.world_size}, '
           f'time={elapsed_time:.2f}s, speed={num_samples / elapsed_time:.2f} images/s')
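
The sample.py guards above, condensed into a standalone timing sketch: torch.cuda.synchronize() raises on CPU-only machines, so both sync calls branch on availability, and the autocast device type is chosen the same way (keeping the repo's float32 autocast dtype; the matrix multiply is a placeholder workload):

import time
import torch

def sync_if_cuda():
    # torch.cuda.synchronize() raises without CUDA, so guard it
    if torch.cuda.is_available():
        torch.cuda.synchronize()

device_type = 'cuda' if torch.cuda.is_available() else 'cpu'

sync_if_cuda()
start = time.time()
with torch.no_grad():
    with torch.autocast(device_type=device_type, dtype=torch.float32):
        _ = torch.randn(8, 8) @ torch.randn(8, 8)  # placeholder workload
sync_if_cuda()
print(f"elapsed: {time.time() - start:.2f}s")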
utils/training.py CHANGED
@@ -81,7 +81,9 @@ class Distributed:
         else:  # When running with python for debugging
             self.rank, self.local_rank, self.world_size = 0, 0, 1
             self.distributed = False
-        torch.cuda.set_device(self.local_rank)
+        # Only set CUDA device if CUDA is available
+        if torch.cuda.is_available():
+            torch.cuda.set_device(self.local_rank)
         self.barrier()
 
     def barrier(self) -> None: