leoeric committed
Commit 3a6a9cd · 1 Parent(s): bd6dbaf

Fix CUDA out of memory error by clearing cache between model loads and enabling expandable segments
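The commit pairs two standard PyTorch mitigations: opting the CUDA caching allocator into expandable segments (PYTORCH_CUDA_ALLOC_CONF must be set before CUDA is first initialized), and releasing cached allocator blocks between loading the text encoder, VAE, and transformer. A minimal sketch of that pattern follows; load_component and build_fn are illustrative names, not part of the app.

# Minimal sketch of the pattern used in this commit; load_component/build_fn are illustrative.
import os

# Must run before torch initializes CUDA, hence at the very top of the entry script.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import gc
import torch

def load_component(build_fn):
    """Build one model component, then release memory held only by temporaries."""
    component = build_fn()
    gc.collect()               # drop unreferenced Python objects (e.g. a consumed state_dict)
    torch.cuda.empty_cache()   # return unused cached blocks to the CUDA driver
    return component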

Files changed (2)
  1. app.py +39 -5
  2. sample.py +18 -4
app.py CHANGED
@@ -10,6 +10,8 @@ import os
 os.environ['OMP_NUM_THREADS'] = '1'
 os.environ['MKL_NUM_THREADS'] = '1'
 os.environ['NUMEXPR_NUM_THREADS'] = '1'
+# Fix CUDA memory fragmentation
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
 
 import warnings
 import gradio as gr
@@ -237,17 +239,49 @@ def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path
 
         # Find the generated image
         # The sample.py script saves to logdir/model_name/...
-        # We need to find the most recent output
-        output_files = list(output_dir.glob("**/*.png")) + list(output_dir.glob("**/*.jpg"))
+        # Model name is derived from checkpoint path stem
+        checkpoint_stem = Path(checkpoint_path).stem
+        model_output_dir = output_dir / checkpoint_stem
+
+        status_msg += f"Searching in: {model_output_dir}\n"
+        status_msg += f"Also searching recursively in: {output_dir}\n"
+
+        # Search in model-specific directory first, then recursively
+        search_paths = [model_output_dir, output_dir]
+        output_files = []
+
+        for search_path in search_paths:
+            if search_path.exists():
+                # Look for PNG, JPG, JPEG files
+                found = list(search_path.glob("**/*.png")) + list(search_path.glob("**/*.jpg")) + list(search_path.glob("**/*.jpeg"))
+                output_files.extend(found)
+                status_msg += f"Found {len(found)} files in {search_path}\n"
+
         if output_files:
+            # Get the most recent file
             latest_file = max(output_files, key=lambda p: p.stat().st_mtime)
+            status_msg += f"✅ Found image: {latest_file}\n"
             return str(latest_file), status_msg + "✅ Success! Image generated."
         else:
-            error_msg = status_msg + f"Error: Generated image not found in {output_dir}."
+            # Debug: list what's actually in the directory
+            debug_info = f"\n\nDebug info:\n"
+            debug_info += f"Output dir exists: {output_dir.exists()}\n"
+            if output_dir.exists():
+                debug_info += f"Contents of {output_dir}:\n"
+                for item in output_dir.iterdir():
+                    debug_info += f" - {item.name} ({'dir' if item.is_dir() else 'file'})\n"
+            if model_output_dir.exists():
+                debug_info += f"\nContents of {model_output_dir}:\n"
+                for item in model_output_dir.iterdir():
+                    debug_info += f" - {item.name} ({'dir' if item.is_dir() else 'file'})\n"
+
+            error_msg = status_msg + f"Error: Generated image not found.\n"
+            error_msg += f"Searched in: {output_dir} and {model_output_dir}\n"
+            error_msg += debug_info
             if log_content:
-                error_msg += f"\n\n📋 Check log file for details: {log_file}\nLast 1000 chars:\n{log_content[-1000:]}"
+                error_msg += f"\n\n📋 Check log file for details: {log_file}\nLast 2000 chars:\n{log_content[-2000:]}"
             else:
-                error_msg += f"\n\nCheck stdout:\n{result.stdout}"
+                error_msg += f"\n\nCheck stdout:\n{result.stdout[-1000:]}"
             return None, error_msg
 
     except Exception as e:
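For readers who only want the new lookup behaviour, the added block in generate_image amounts to the following standalone helper. The name find_latest_image and its arguments are illustrative; the search order (checkpoint-named subdirectory first, then the whole output tree) matches the diff above.

# Standalone sketch of the new image lookup; find_latest_image is an illustrative name.
from pathlib import Path
from typing import Optional

def find_latest_image(output_dir: Path, checkpoint_path: str) -> Optional[Path]:
    """Search the checkpoint-named subdirectory first, then the whole tree,
    and return the most recently modified PNG/JPG/JPEG, if any."""
    model_output_dir = output_dir / Path(checkpoint_path).stem
    candidates = []
    for search_path in (model_output_dir, output_dir):
        if search_path.exists():
            for pattern in ("**/*.png", "**/*.jpg", "**/*.jpeg"):
                candidates.extend(search_path.glob(pattern))
    return max(candidates, key=lambda p: p.stat().st_mtime) if candidates else None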
sample.py CHANGED
@@ -17,6 +17,7 @@ import argparse
 import copy
 import pathlib
 import time
+import gc
 from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -60,28 +61,40 @@ def setup_model_and_components(args: argparse.Namespace) -> Tuple[torch.nn.Modul
     utils.set_random_seed(args.seed + dist.rank)
 
     # Setup text encoder
+    print("Loading text encoder...")
     tokenizer, text_encoder = utils.setup_encoder(args, dist, device)
+    torch.cuda.empty_cache()  # Clear cache after text encoder
 
     # Setup VAE if specified
     vae = None
     if args.vae is not None:
+        print("Loading VAE...")
         vae = utils.setup_vae(args, dist, device)
        args.img_size = args.img_size // vae.downsample_factor
+        torch.cuda.empty_cache()  # Clear cache after VAE
     else:
         args.finetuned_vae = 'none'
 
     # Setup main transformer model
+    print("Loading main transformer model...")
     model = utils.setup_transformer(
         args, dist,
         txt_dim=text_encoder.config.hidden_size,
         use_checkpoint=1
-    ).to(device)
+    )
 
-    # Load checkpoint
-    print(f"Loading checkpoint from local path: {args.checkpoint_path}")
+    # Load checkpoint to CPU first, then move to GPU
+    print(f"Loading checkpoint from: {args.checkpoint_path}")
     state_dict = torch.load(args.checkpoint_path, map_location='cpu')
     model.load_state_dict(state_dict, strict=False)
-    del state_dict; torch.cuda.empty_cache()
+    del state_dict
+    gc.collect()  # Force garbage collection
+    torch.cuda.empty_cache()  # Clear any GPU cache
+
+    # Move model to GPU after loading weights
+    print("Moving model to GPU...")
+    model = model.to(device)
+    torch.cuda.empty_cache()  # Clear cache after moving to GPU
 
     # Set model to eval mode and disable gradients
     for p in model.parameters():
@@ -90,6 +103,7 @@ def setup_model_and_components(args: argparse.Namespace) -> Tuple[torch.nn.Modul
 
     # Parallelize model for multi-GPU sampling
     _, model = utils.parallelize_model(args, model, dist, device)
+    torch.cuda.empty_cache()  # Final cache clear
 
     return model, vae, (tokenizer, text_encoder, dist, device)
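The sample.py change reorders checkpoint loading so the transformer's weights are assembled entirely on the CPU and moved to the GPU only once, after the state_dict has been freed and the allocator cache cleared. A condensed sketch of that ordering, with build_model standing in for utils.setup_transformer:

# Condensed sketch of the new load order; build_model is a placeholder for utils.setup_transformer.
import gc
import torch

def load_checkpoint_low_peak(build_model, checkpoint_path, device):
    model = build_model()                                    # construct the model on the CPU
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict, strict=False)
    del state_dict                                           # free the CPU copy before touching the GPU
    gc.collect()
    torch.cuda.empty_cache()
    model = model.to(device)                                 # single transfer of the final weights
    torch.cuda.empty_cache()
    return model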