Fix GPU abort error: improve ZeroGPU decorator detection and GPU context handling
- Fix @spaces.GPU decorator application for proper ZeroGPU detection
- Preserve CUDA_VISIBLE_DEVICES in subprocess calls
- Add GPU availability checks before generation
- Enhance error handling for GPU abort scenarios
- Add GPU status logging for debugging
- app.py +87 -20
- dataset.py +42 -5
- sample.py +13 -6
- utils/training.py +3 -1
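
The first bullet is the core of the fix: ZeroGPU detects GPU-dependent entry points by looking for spaces.GPU-decorated callables when the app module is imported, so the function must exist and be wrapped at module level instead of only being defined inside a conditional branch. The sketch below distills the pattern the app.py diff applies; the SPACES_AVAILABLE flag and the _generate_image_impl helper are taken from the diff, while the try/except import guard is an assumption about how that flag is set.

    # Sketch of the module-level decorator pattern (mirrors app.py, not a drop-in replacement).
    try:
        import spaces  # present on Hugging Face Spaces; assumed source of SPACES_AVAILABLE
        SPACES_AVAILABLE = True
    except ImportError:
        SPACES_AVAILABLE = False

    def _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
        ...  # the real app launches the generation subprocess here

    def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
        """Generate image from text prompt."""
        return _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path)

    # Wrap the existing function instead of re-defining it under if/else, so the
    # decorated name is visible to ZeroGPU at import time.
    if SPACES_AVAILABLE and hasattr(spaces, 'GPU'):
        generate_image = spaces.GPU(generate_image)
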
app.py
CHANGED
@@ -76,8 +76,14 @@ def get_checkpoint_path(checkpoint_file, default_local_path, repo_id=None, filen
     # Try downloading from Model Hub if configured
     if repo_id and filename and HF_HUB_AVAILABLE:
         try:
-            # Use /workspace if available (persistent), otherwise
-
+            # Use /workspace if available (HF Spaces persistent), otherwise use local cache
+            if os.path.exists("/workspace"):
+                cache_dir = "/workspace/checkpoints"
+            elif os.path.exists("/tmp"):
+                cache_dir = "/tmp/checkpoints"
+            else:
+                # Local development: use project directory or user cache
+                cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "starflow")
             os.makedirs(cache_dir, exist_ok=True)

             # Check if already downloaded
@@ -144,15 +150,15 @@ else:
     print(f" PyTorch Version: {torch.__version__}")

 # Apply @spaces.GPU decorator if available (required for ZeroGPU)
+# IMPORTANT: Decorator must be applied at module level for ZeroGPU to detect it at startup
+def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
+    """Generate image from text prompt."""
+    return _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path)
+
+# Apply decorator if spaces module is available (ZeroGPU detection happens at import time)
 if SPACES_AVAILABLE and hasattr(spaces, 'GPU'):
-
-
-        """Generate image from text prompt."""
-        return _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path)
-else:
-    def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
-        """Generate image from text prompt."""
-        return _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path)
+    generate_image = spaces.GPU(generate_image)
+    print("✅ ZeroGPU decorator applied to generate_image")

 def _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
     """Generate image from text prompt (implementation)."""
@@ -222,13 +228,27 @@ def _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, confi
     log_file = output_dir / "generation.log"
     status_msg += f"📋 Logs will be saved to: {log_file}\n"

+    # Ensure GPU environment variables are passed to subprocess
+    env = os.environ.copy()
+    # Preserve CUDA_VISIBLE_DEVICES if set (important for ZeroGPU)
+    if 'CUDA_VISIBLE_DEVICES' in env:
+        print(f"✅ CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']}")
+
+    # Verify GPU is available before starting
+    if torch.cuda.is_available():
+        status_msg += f"✅ GPU available: {torch.cuda.get_device_name(0)}\n"
+        status_msg += f" GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB\n"
+    else:
+        status_msg += "⚠️ Warning: CUDA not available, will use CPU (very slow)\n"
+
     # Run with timeout (45 minutes max - allows for download + generation)
     # Capture output and write to log file
     result = subprocess.run(
         cmd,
         capture_output=True,
         text=True,
-        cwd=os.getcwd(),
+        cwd=os.getcwd(),
+        env=env,  # Pass environment variables (including CUDA_VISIBLE_DEVICES)
         timeout=2700
     )

@@ -251,6 +271,21 @@ def _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, confi

     if result.returncode != 0:
         error_msg = f"❌ Error during generation (return code: {result.returncode})\n\n"
+
+        # Check for GPU abort or CUDA errors
+        error_output = (result.stderr + result.stdout).lower()
+        if "gpu aborted" in error_output or "cuda" in error_output or "out of memory" in error_output:
+            error_msg += "⚠️ GPU ERROR DETECTED\n\n"
+            error_msg += "Possible causes:\n"
+            error_msg += "1. GPU timeout (ZeroGPU may have a 5-10 min limit)\n"
+            error_msg += "2. CUDA out of memory (model too large for GPU)\n"
+            error_msg += "3. GPU allocation failed (ZeroGPU not detected)\n\n"
+            error_msg += "Solutions:\n"
+            error_msg += "- Try again (GPU may have been released)\n"
+            error_msg += "- Check Space logs for detailed error\n"
+            error_msg += "- Ensure @spaces.GPU decorator is applied\n"
+            error_msg += "- Consider using paid GPU tier for longer runs\n\n"
+
         error_msg += f"=== STDERR ===\n{result.stderr}\n\n"
         error_msg += f"=== STDOUT ===\n{result.stdout}\n\n"
         if log_content:
@@ -323,15 +358,15 @@ def _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, confi


 # Apply @spaces.GPU decorator if available (required for ZeroGPU)
+# IMPORTANT: Decorator must be applied at module level for ZeroGPU to detect it at startup
+def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
+    """Generate video from text prompt."""
+    return _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image)
+
+# Apply decorator if spaces module is available (ZeroGPU detection happens at import time)
 if SPACES_AVAILABLE and hasattr(spaces, 'GPU'):
-
-
-        """Generate video from text prompt."""
-        return _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image)
-else:
-    def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
-        """Generate video from text prompt."""
-        return _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image)
+    generate_video = spaces.GPU(generate_video)
+    print("✅ ZeroGPU decorator applied to generate_video")

 def _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
     """Generate video from text prompt (implementation)."""
@@ -396,10 +431,42 @@ def _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpo
     else:
         cmd.extend(["--input_image", "none"])

-
+    # Ensure GPU environment variables are passed to subprocess
+    env = os.environ.copy()
+    # Preserve CUDA_VISIBLE_DEVICES if set (important for ZeroGPU)
+    if 'CUDA_VISIBLE_DEVICES' in env:
+        print(f"✅ CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']}")
+
+    # Verify GPU is available before starting
+    if torch.cuda.is_available():
+        print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
+
+    result = subprocess.run(
+        cmd,
+        capture_output=True,
+        text=True,
+        cwd=os.getcwd(),
+        env=env,  # Pass environment variables (including CUDA_VISIBLE_DEVICES)
+        timeout=3600  # 60 minutes for video generation
+    )

     if result.returncode != 0:
         error_msg = f"❌ Error during video generation (return code: {result.returncode})\n\n"
+
+        # Check for GPU abort or CUDA errors
+        error_output = (result.stderr + result.stdout).lower()
+        if "gpu aborted" in error_output or "cuda" in error_output or "out of memory" in error_output:
+            error_msg += "⚠️ GPU ERROR DETECTED\n\n"
+            error_msg += "Possible causes:\n"
+            error_msg += "1. GPU timeout (ZeroGPU may have a 5-10 min limit)\n"
+            error_msg += "2. CUDA out of memory (video generation needs more GPU memory)\n"
+            error_msg += "3. GPU allocation failed (ZeroGPU not detected)\n\n"
+            error_msg += "Solutions:\n"
+            error_msg += "- Try again (GPU may have been released)\n"
+            error_msg += "- Check Space logs for detailed error\n"
+            error_msg += "- Ensure @spaces.GPU decorator is applied\n"
+            error_msg += "- Consider using paid GPU tier for longer runs\n\n"
+
         error_msg += f"=== STDERR ===\n{result.stderr}\n\n"
         error_msg += f"=== STDOUT ===\n{result.stdout}\n"
         return None, error_msg
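
For reference, the environment handling added to both subprocess calls above boils down to the sketch below. subprocess.run inherits the parent environment by default when env is omitted; copying os.environ makes that inheritance explicit and gives one place to log or adjust variables such as CUDA_VISIBLE_DEVICES before launching the generation script. The command and timeout here are placeholders.

    import os
    import subprocess

    cmd = ["python", "sample.py", "--prompt", "a photo of a cat"]  # placeholder command

    env = os.environ.copy()  # start from the parent process environment
    if "CUDA_VISIBLE_DEVICES" in env:
        print(f"CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']}")  # log what the child will see

    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        env=env,       # explicit environment for the child process
        timeout=2700,  # 45 minutes, matching the image path above
    )
    print(result.returncode)
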
dataset.py
CHANGED
@@ -24,15 +24,41 @@ import gc
 import threading
 import psutil
 import tempfile
-import
-
+# Optional import for video processing (not available on macOS ARM)
+try:
+    import decord
+    from decord import VideoReader
+    DECORD_AVAILABLE = True
+except ImportError:
+    DECORD_AVAILABLE = False
+    print("⚠️ decord not available. Video processing will be disabled.")
 import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor, TimeoutError
 from misc import print, xprint
 from misc.condition_utils import get_camera_condition, get_point_condition, get_wind_condition

-#
-
+# Lazy initialization of multiprocessing manager (only when needed, not at import time)
+# This avoids issues on macOS which uses 'spawn' instead of 'fork'
+_manager = None
+
+def get_manager():
+    """Get or create the multiprocessing manager lazily."""
+    global _manager
+    if _manager is None:
+        try:
+            # Only create manager when actually needed (not at import time)
+            # This avoids RuntimeError on macOS with spawn method
+            _manager = torch.multiprocessing.Manager()
+        except (RuntimeError, EOFError) as e:
+            # If manager creation fails (e.g., on macOS with spawn), return None
+            # The code already handles None manager gracefully
+            print(f"⚠️ Could not create multiprocessing manager: {e}")
+            print(" Continuing without multiprocessing manager (may affect some features)")
+            _manager = False  # Use False to indicate attempted but failed
+    return _manager if _manager is not False else None
+
+# For backward compatibility, but will be None until get_manager() is called
+manager = None

 # ==== helpers ==== #

@@ -91,6 +117,8 @@ def sample_clip(
     num_frames: int = 8,
     out_fps: Optional[float] = None, # ← pass an fps here
 ):
+    if not DECORD_AVAILABLE:
+        raise ImportError("decord is required for video processing but is not available. Install with: pip install decord (Note: not available on macOS ARM)")
     vr = VideoReader(video_path)
     src_fps = vr.get_avg_fps() # native fps
     total = len(vr)
@@ -353,11 +381,20 @@ class ImageTarDataset(Dataset):
 class OnlineImageTarDataset(ImageTarDataset):
     max_retry_n = 20
     max_read = 4096
-    tar_keys_lock
+    # tar_keys_lock will be initialized in __init__ to avoid import-time issues

     def __init__(self, dataset_tsv, image_size, batch_size=None, **kwargs):
         super().__init__(dataset_tsv, image_size, **kwargs)

+        # Initialize manager lazily (only when this class is instantiated)
+        manager = get_manager()
+        # Use threading.Lock as fallback if multiprocessing manager unavailable
+        if manager is not None:
+            self.tar_keys_lock = manager.Lock()
+        else:
+            # Fallback to threading lock for single-process use
+            self.tar_keys_lock = threading.Lock()
+
         self.tar_lists = defaultdict(lambda: [])
         self.tar_image_buckets = defaultdict(lambda: defaultdict(lambda: 0))
         for i, line in enumerate(self.all_lines):
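
The threading.Lock fallback in the OnlineImageTarDataset hunk works because a manager Lock proxy and a plain threading.Lock expose the same with-statement interface, so code that later takes the lock never needs to know which one it received. A hypothetical usage sketch follows; the class and its record_key method are illustrative, not part of the diff.

    import threading

    class LockedTarIndexSketch:
        """Illustrative stand-in for the locking pattern used by OnlineImageTarDataset."""

        def __init__(self, manager=None):
            # Same fallback as the __init__ change above: manager lock if available, else a thread lock.
            self.tar_keys_lock = manager.Lock() if manager is not None else threading.Lock()

        def record_key(self, key, index):
            # Works identically with a multiprocessing manager lock proxy or threading.Lock.
            with self.tar_keys_lock:
                index[key] = index.get(key, 0) + 1

When get_manager() returns None (for example on macOS with the spawn start method), the dataset silently degrades to a process-local lock, which the diff notes is sufficient for single-process use.
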
sample.py
CHANGED
@@ -60,7 +60,9 @@ def setup_model_and_components(args: argparse.Namespace) -> Tuple[torch.nn.Modul
     dist = utils.Distributed()

     # If not running with torchrun, initialize single-process group
-
+    # This is needed because the model code uses torch.distributed.all_reduce
+    # Works for both CUDA and CPU modes
+    if not dist.distributed:
         import os
         # Initialize single-process process group for model compatibility
         if not torch.distributed.is_initialized():
@@ -69,13 +71,15 @@ def setup_model_and_components(args: argparse.Namespace) -> Tuple[torch.nn.Modul
             os.environ['RANK'] = '0'
             os.environ['LOCAL_RANK'] = '0'
             os.environ['WORLD_SIZE'] = '1'
+            # Use 'nccl' for CUDA, 'gloo' for CPU
+            backend = 'nccl' if torch.cuda.is_available() else 'gloo'
             torch.distributed.init_process_group(
-                backend=
+                backend=backend,
                 init_method='env://',
                 world_size=1,
                 rank=0,
             )
-            print("✅ Initialized single-process distributed group for model compatibility")
+            print(f"✅ Initialized single-process distributed group (backend: {backend}) for model compatibility")

     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -268,11 +272,13 @@ def main(args: argparse.Namespace) -> None:

     # Start sampling
     print(f'Starting sampling with global batch size {args.sample_batch_size}x{dist.world_size} GPUs')
-    torch.cuda.synchronize()
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
     start_time = time.time()

     with torch.no_grad():
-
+        device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
+        with torch.autocast(device_type=device_type, dtype=torch.float32):
             for i in tqdm.tqdm(range(int(np.ceil(num_samples / (args.sample_batch_size * dist.world_size))))):
                 # Determine aspect ratio and image shape
                 x_aspect = args.aspect_ratio if args.mix_aspect else None
@@ -367,7 +373,8 @@ def main(args: argparse.Namespace) -> None:
     )

     # Print timing statistics
-    torch.cuda.synchronize()
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
     elapsed_time = time.time() - start_time
     print(f'{model_name} cfg {args.cfg:.2f}, bsz={args.sample_batch_size}x{dist.world_size}, '
           f'time={elapsed_time:.2f}s, speed={num_samples / elapsed_time:.2f} images/s')
utils/training.py
CHANGED
@@ -81,7 +81,9 @@ class Distributed:
         else: # When running with python for debugging
             self.rank, self.local_rank, self.world_size = 0, 0, 1
             self.distributed = False
-            torch.cuda.set_device(self.local_rank)
+            # Only set CUDA device if CUDA is available
+            if torch.cuda.is_available():
+                torch.cuda.set_device(self.local_rank)
         self.barrier()

     def barrier(self) -> None: