Fix CUDA out of memory error by clearing cache between model loads and enabling expandable segments
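Both parts of the fix rely on standard PyTorch memory controls: `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` asks the caching allocator to grow existing segments instead of reserving new fixed-size blocks (which reduces fragmentation), and `torch.cuda.empty_cache()` hands cached-but-unused blocks back to the driver between large allocations. The environment variable only takes effect if it is set before the CUDA allocator is initialized, which is why the change puts it at the very top of app.py. A minimal standalone sketch of the pattern (the `load_stage` helper is illustrative, not part of this repo):

import os
# Must be set before torch creates its first CUDA tensor,
# otherwise the allocator config is ignored.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import gc
import torch

def load_stage(build_fn):
    """Build one model stage, then release anything the build left behind
    so the next stage starts from a clean memory pool."""
    module = build_fn()
    gc.collect()               # drop unreferenced Python objects first
    torch.cuda.empty_cache()   # then return cached GPU blocks to the driver
    return module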
app.py
CHANGED
@@ -10,6 +10,8 @@ import os
 os.environ['OMP_NUM_THREADS'] = '1'
 os.environ['MKL_NUM_THREADS'] = '1'
 os.environ['NUMEXPR_NUM_THREADS'] = '1'
+# Fix CUDA memory fragmentation
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
 
 import warnings
 import gradio as gr

@@ -237,17 +239,49 @@ def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path
 
         # Find the generated image
         # The sample.py script saves to logdir/model_name/...
-        #
-
+        # Model name is derived from checkpoint path stem
+        checkpoint_stem = Path(checkpoint_path).stem
+        model_output_dir = output_dir / checkpoint_stem
+
+        status_msg += f"Searching in: {model_output_dir}\n"
+        status_msg += f"Also searching recursively in: {output_dir}\n"
+
+        # Search in model-specific directory first, then recursively
+        search_paths = [model_output_dir, output_dir]
+        output_files = []
+
+        for search_path in search_paths:
+            if search_path.exists():
+                # Look for PNG, JPG, JPEG files
+                found = list(search_path.glob("**/*.png")) + list(search_path.glob("**/*.jpg")) + list(search_path.glob("**/*.jpeg"))
+                output_files.extend(found)
+                status_msg += f"Found {len(found)} files in {search_path}\n"
+
         if output_files:
+            # Get the most recent file
             latest_file = max(output_files, key=lambda p: p.stat().st_mtime)
+            status_msg += f"✅ Found image: {latest_file}\n"
             return str(latest_file), status_msg + "✅ Success! Image generated."
         else:
-
+            # Debug: list what's actually in the directory
+            debug_info = f"\n\nDebug info:\n"
+            debug_info += f"Output dir exists: {output_dir.exists()}\n"
+            if output_dir.exists():
+                debug_info += f"Contents of {output_dir}:\n"
+                for item in output_dir.iterdir():
+                    debug_info += f"  - {item.name} ({'dir' if item.is_dir() else 'file'})\n"
+            if model_output_dir.exists():
+                debug_info += f"\nContents of {model_output_dir}:\n"
+                for item in model_output_dir.iterdir():
+                    debug_info += f"  - {item.name} ({'dir' if item.is_dir() else 'file'})\n"
+
+            error_msg = status_msg + f"Error: Generated image not found.\n"
+            error_msg += f"Searched in: {output_dir} and {model_output_dir}\n"
+            error_msg += debug_info
             if log_content:
-                error_msg += f"\n\n📋 Check log file for details: {log_file}\nLast
+                error_msg += f"\n\n📋 Check log file for details: {log_file}\nLast 2000 chars:\n{log_content[-2000:]}"
             else:
-                error_msg += f"\n\nCheck stdout:\n{result.stdout}"
+                error_msg += f"\n\nCheck stdout:\n{result.stdout[-1000:]}"
             return None, error_msg
 
     except Exception as e:
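The new lookup logic in generate_image boils down to: search the model-specific subdirectory, fall back to a recursive search of the whole output directory, and return the most recently modified image. A minimal self-contained sketch of that idea, assuming nothing about this repo beyond pathlib semantics (the name `find_latest_image` is made up for illustration):

from pathlib import Path
from typing import Optional

def find_latest_image(output_dir: Path, model_name: str) -> Optional[Path]:
    """Return the newest image under output_dir, preferring the
    model-specific subdirectory when it exists."""
    candidates = []
    for root in (output_dir / model_name, output_dir):
        if root.exists():
            for pattern in ("**/*.png", "**/*.jpg", "**/*.jpeg"):
                candidates.extend(root.glob(pattern))
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)

Searching both roots keeps the UI working even if sample.py changes its output layout, at the cost of possibly picking up older files; selecting by modification time mitigates that.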
sample.py
CHANGED
@@ -17,6 +17,7 @@ import argparse
 import copy
 import pathlib
 import time
+import gc
 from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np

@@ -60,28 +61,40 @@ def setup_model_and_components(args: argparse.Namespace) -> Tuple[torch.nn.Modul
     utils.set_random_seed(args.seed + dist.rank)
 
     # Setup text encoder
+    print("Loading text encoder...")
     tokenizer, text_encoder = utils.setup_encoder(args, dist, device)
+    torch.cuda.empty_cache()  # Clear cache after text encoder
 
     # Setup VAE if specified
     vae = None
     if args.vae is not None:
+        print("Loading VAE...")
         vae = utils.setup_vae(args, dist, device)
         args.img_size = args.img_size // vae.downsample_factor
+        torch.cuda.empty_cache()  # Clear cache after VAE
     else:
         args.finetuned_vae = 'none'
 
     # Setup main transformer model
+    print("Loading main transformer model...")
     model = utils.setup_transformer(
         args, dist,
         txt_dim=text_encoder.config.hidden_size,
         use_checkpoint=1
-    )
+    )
 
-    # Load checkpoint
-    print(f"Loading checkpoint from
+    # Load checkpoint to CPU first, then move to GPU
+    print(f"Loading checkpoint from: {args.checkpoint_path}")
     state_dict = torch.load(args.checkpoint_path, map_location='cpu')
     model.load_state_dict(state_dict, strict=False)
-    del state_dict
+    del state_dict
+    gc.collect()  # Force garbage collection
+    torch.cuda.empty_cache()  # Clear any GPU cache
+
+    # Move model to GPU after loading weights
+    print("Moving model to GPU...")
+    model = model.to(device)
+    torch.cuda.empty_cache()  # Clear cache after moving to GPU
 
     # Set model to eval mode and disable gradients
     for p in model.parameters():

@@ -90,6 +103,7 @@ def setup_model_and_components(args: argparse.Namespace) -> Tuple[torch.nn.Modul
 
     # Parallelize model for multi-GPU sampling
     _, model = utils.parallelize_model(args, model, dist, device)
+    torch.cuda.empty_cache()  # Final cache clear
 
     return model, vae, (tokenizer, text_encoder, dist, device)
 
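The sample.py changes follow a common pattern for fitting a large checkpoint on a single GPU: materialize the weights on CPU, copy them into the model, drop the CPU copy, and only then move the model to the device, clearing the allocator cache between steps. A condensed sketch of that sequence under those assumptions (`build_model` and the checkpoint path are placeholders, not this repo's API):

import gc
import torch

def load_to_gpu(build_model, checkpoint_path: str, device: torch.device):
    model = build_model()                              # assumed to construct on CPU
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict, strict=False)
    del state_dict                                     # free the CPU copy of the weights
    gc.collect()
    torch.cuda.empty_cache()

    model = model.to(device)                           # single CPU-to-GPU transfer
    torch.cuda.empty_cache()
    return model.eval().requires_grad_(False)          # inference only, no gradients

Without the intermediate empty_cache() calls, peak usage can include both the blocks cached by earlier loads and the new allocation, which is exactly the out-of-memory pattern the commit message describes.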