#!/usr/bin/env python3
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
"""
Scalable Transformer Autoregressive Flow (STARFlow) Sampling Script

Sampling from trained transformer autoregressive flow models, supporting both
image and video generation with various conditioning options.

Usage:
    python sample.py --model_config_path config.yaml --checkpoint_path model.pth --caption "A cat"
"""
import argparse
import copy
import gc
import pathlib
import time
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchvision as tv
import tqdm
import yaml
from einops import repeat
from PIL import Image

# Local imports
import transformer_flow
import utils
from dataset import aspect_ratio_to_image_size
from train import get_tarflow_parser
from transformer_flow import KVCache
from utils import process_denoising, save_samples_unified, load_model_config, encode_text, add_noise
from misc import print  # note: shadows the builtin print
# Default caption templates for testing and demonstrations.
DEFAULT_CAPTIONS = {
    'template1': "In the image, a corgi dog is wearing a Santa hat and is laying on a fluffy rug. The dog's tongue is sticking out and it appears to be happy. There are two pumpkins and a basket of leaves nearby, indicating that the scene takes place during the fall season. The background features a Christmas tree, further suggesting the holiday atmosphere. The image has a warm and cozy feel to it, with the dog looking adorable in its hat and the pumpkins adding a festive touch.",
    'template2': "A close-up portrait of a cheerful Corgi dog, showcasing its fluffy, sandy-brown fur and perky ears. The dog has a friendly expression with a slight smile, looking directly into the camera. Set against a soft, natural green background, the image is captured in a high-definition, realistic photography style, emphasizing the texture of the fur and the vibrant colors.",
    'template3': "A high-resolution, wide-angle selfie photograph of Albert Einstein in a garden setting. Einstein looks directly into the camera with a gentle, knowing smile. His distinctive wild white hair and bushy mustache frame a face marked by thoughtful wrinkles. He wears a classic tweed jacket over a simple shirt. In the background, lush greenery and flowering bushes under soft daylight create a serene, scholarly atmosphere. Ultra-realistic style, 4K detail.",
    'template4': 'A close-up, high-resolution selfie of a red panda perched on a tree branch, its large dark eyes looking directly into the lens. Rich reddish-orange fur with white facial markings contrasts against the lush green bamboo forest behind. Soft sunlight filters through the leaves, casting a warm, natural glow over the scene. Ultra-realistic detail, digital photograph style, 4K resolution.',
    'template5': "A realistic selfie of a llama standing in front of a classic Ivy League building on the Princeton University campus. He is smiling gently, wearing his iconic wild hair and mustache, dressed in a wool sweater and collared shirt. The photo has a vintage, slightly sepia tone, with soft natural lighting and leafy trees in the background, capturing an academic and historical vibe.",
}
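
# `--caption` accepts one of the template keys above (e.g. "template1"), a
# path to a .txt file with one caption per line, or any other string, which is
# used verbatim (see prepare_captions below).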

def setup_model_and_components(args: argparse.Namespace) -> Tuple[torch.nn.Module, Optional[torch.nn.Module], tuple]:
    """Initialize and load the transformer model, VAE, and text encoder."""
    # Initialize the distributed context. Even for single-GPU inference a
    # process group is required, because the model code calls
    # torch.distributed.all_reduce. This works for both CUDA and CPU modes.
    dist = utils.Distributed()

    # If not launched with torchrun, initialize a single-process group.
    if not dist.distributed:
        import os

        if not torch.distributed.is_initialized():
            os.environ['MASTER_ADDR'] = 'localhost'
            os.environ['MASTER_PORT'] = '12355'
            os.environ['RANK'] = '0'
            os.environ['LOCAL_RANK'] = '0'
            os.environ['WORLD_SIZE'] = '1'
            # Use 'nccl' for CUDA, 'gloo' for CPU.
            backend = 'nccl' if torch.cuda.is_available() else 'gloo'
            torch.distributed.init_process_group(
                backend=backend,
                init_method='env://',
                world_size=1,
                rank=0,
            )
            print(f"✅ Initialized single-process distributed group (backend: {backend}) for model compatibility")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Seed each rank differently so processes draw different noise.
    utils.set_random_seed(args.seed + dist.rank)

    # Set up the text encoder.
    print("Loading text encoder...")
    tokenizer, text_encoder = utils.setup_encoder(args, dist, device)
    torch.cuda.empty_cache()  # clear cache after loading the text encoder

    # Set up the VAE if specified.
    vae = None
    if args.vae is not None:
        print("Loading VAE...")
        vae = utils.setup_vae(args, dist, device)
        args.img_size = args.img_size // vae.downsample_factor
        torch.cuda.empty_cache()  # clear cache after loading the VAE
    else:
        args.finetuned_vae = 'none'

    # Set up the main transformer model.
    print("Loading main transformer model...")
    model = utils.setup_transformer(
        args, dist,
        txt_dim=text_encoder.config.hidden_size,
        use_checkpoint=1
    )

    # Load the checkpoint to CPU first, then move the model to GPU.
    print(f"Loading checkpoint from: {args.checkpoint_path}")
    state_dict = torch.load(args.checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict, strict=False)
    del state_dict
    gc.collect()  # force garbage collection
    torch.cuda.empty_cache()

    print("Moving model to GPU...")
    model = model.to(device)
    torch.cuda.empty_cache()

    # Set the model to eval mode and disable gradients.
    for p in model.parameters():
        p.requires_grad = False
    model.eval()

    # Parallelize for multi-GPU sampling (before the half-precision cast).
    _, model = utils.parallelize_model(args, model, dist, device)
    torch.cuda.empty_cache()

    # Convert the model to half precision for memory efficiency (CUDA only).
    # Do this AFTER parallelization to avoid issues.
    if torch.cuda.is_available():
        # Prefer bfloat16 where supported, otherwise float16.
        if torch.cuda.is_bf16_supported():
            model = model.to(torch.bfloat16)
            print("✅ Converted model to bfloat16 for memory efficiency")
        else:
            model = model.to(torch.float16)
            print("✅ Converted model to float16 for memory efficiency")
        torch.cuda.empty_cache()

        # Report memory usage.
        allocated = torch.cuda.memory_allocated(0) / 1024**3
        reserved = torch.cuda.memory_reserved(0) / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"📊 GPU Memory: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved, {total:.2f} GB total")

    torch.cuda.empty_cache()  # final cache clear
    return model, vae, (tokenizer, text_encoder, dist, device)

def prepare_captions(args: argparse.Namespace, dist) -> Tuple[List[str], List[int], int, str]:
    """Prepare captions for sampling from a text file or a template."""
    if args.caption.endswith('.txt'):
        # One caption per line; shard the lines round-robin across ranks.
        with open(args.caption, 'r') as f:
            lines = [line.strip() for line in f]
        num_samples = len(lines)
        fixed_y = lines[dist.rank:][::dist.world_size]
        fixed_idxs = list(range(len(lines)))[dist.rank:][::dist.world_size]
        caption_name = args.caption.split('/')[-1][:-4]
    else:
        # Treat the argument as a template key, falling back to the raw string.
        caption_text = DEFAULT_CAPTIONS.get(args.caption, args.caption)
        fixed_y = [caption_text] * args.sample_batch_size
        fixed_idxs = []
        num_samples = args.sample_batch_size * dist.world_size
        caption_name = args.caption
    return fixed_y, fixed_idxs, num_samples, caption_name
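
# Sharding example: with world_size=2, lines[0:][::2] gives rank 0 captions
# [0, 2, 4, ...] and lines[1:][::2] gives rank 1 captions [1, 3, 5, ...], so
# every caption in the file is sampled exactly once across processes.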

def get_noise_shape(args: argparse.Namespace, vae) -> callable:
    """Return a function that draws noise tensors of the right shape for sampling."""
    def _get_noise_func(b: int, x_shape: tuple) -> torch.Tensor:
        rand_shape = [args.channel_size, x_shape[0], x_shape[1]]
        if len(x_shape) == 3:
            # Video: prepend the frame axis.
            rand_shape = [x_shape[2]] + rand_shape
        if vae is not None:
            # Sample in the VAE latent space.
            if args.vid_size is not None:
                rand_shape[0] = (rand_shape[0] - 1) // vae.temporal_downsample_factor + 1
            rand_shape[-2] //= vae.downsample_factor
            rand_shape[-1] //= vae.downsample_factor
        return torch.randn(b, *rand_shape)
    return _get_noise_func
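
# Shape example (illustrative values): with channel_size=3, a 256x256 target,
# and a VAE with downsample_factor=8, _get_noise_func(b, (256, 256)) returns
# noise of shape (b, 3, 32, 32). For a 16-frame video with
# temporal_downsample_factor=4 and x_shape=(256, 256, 16), the frame axis
# becomes (16 - 1) // 4 + 1 = 4.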

def prepare_input_image(args: argparse.Namespace, x_shape: tuple, vae, device: torch.device, noise_std: float) -> Optional[torch.Tensor]:
    """Load and preprocess an input image for image-conditioned generation."""
    input_image = Image.open(args.input_image).convert('RGB')

    # Resize so the image covers the target shape, then center-crop.
    scale = max(x_shape[0] / input_image.height, x_shape[1] / input_image.width)
    transform = tv.transforms.Compose([
        tv.transforms.Resize((int(input_image.height * scale), int(input_image.width * scale))),
        tv.transforms.CenterCrop(x_shape[:2]),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize([0.5] * 3, [0.5] * 3)
    ])
    input_image = transform(input_image).unsqueeze(0).to(device)

    # Encode with the VAE if available.
    with torch.no_grad():
        if vae is not None:
            input_image = vae.encode(input_image)

    # Add noise.
    input_image = add_noise(input_image, noise_std)[0]
    return input_image
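
# Scaling example: targeting x_shape=(256, 256), a 1024x768 (WxH) input gets
# scale = max(256/768, 256/1024) = 1/3, is resized to 256x341, and is then
# center-cropped to 256x256, so the shorter side exactly covers the target.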

def build_sampling_kwargs(args: argparse.Namespace, caption_name: str) -> dict:
    """Build sampling keyword arguments based on the configuration."""
    sampling_kwargs = {
        'guidance': args.cfg,
        'guide_top': args.guide_top,
        'verbose': not caption_name.endswith('/'),
        'return_sequence': args.return_sequence,
        'jacobi': args.jacobi,
        'context_length': args.context_length
    }
    if args.jacobi:
        sampling_kwargs.update({
            'jacobi_th': args.jacobi_th,
            'jacobi_block_size': args.jacobi_block_size,
            'jacobi_max_iter': args.jacobi_max_iter
        })
    else:
        sampling_kwargs.update({
            'attn_temp': args.attn_temp,
            'annealed_guidance': False
        })
    return sampling_kwargs
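
# A reading of the two branches, based on the parameter names (the exact
# semantics live in transformer_flow): with --jacobi 1 the flow is inverted
# with block-wise fixed-point iterations of at most jacobi_max_iter steps,
# stopping early once updates fall below jacobi_th; otherwise decoding is
# sequential and attn_temp scales the attention temperature.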

def main(args: argparse.Namespace) -> None:
    """Main sampling function."""
    # Load the model configuration and merge it with the command-line args
    # (command-line values take precedence).
    trainer_args = load_model_config(args.model_config_path)
    trainer_dict = vars(trainer_args)
    trainer_dict.update(vars(args))
    args = argparse.Namespace(**trainer_dict)

    # Handle the target-length configuration for video.
    if args.target_length is not None:
        assert args.vid_size is not None, "target_length requires a video model"
        assert args.jacobi == 1, "target_length is only supported with Jacobi sampling"
        if args.target_length == 1:  # generate a single image
            args.vid_size = None
            args.out_fps = 0
        else:
            args.local_attn_window = (int(args.vid_size.split(':')[0]) - 1) // 4 + 1
            args.vid_size = f"{args.target_length}:16"
            if args.context_length is None:
                args.context_length = args.local_attn_window - 1
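            # Worked example (illustrative numbers): a model trained with
            # vid_size "17:16", extended with --target_length 33, gets
            # local_attn_window = (17 - 1) // 4 + 1 = 5 and vid_size "33:16".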

    # Override some settings for sampling.
    # Disable FSDP: for single-GPU inference it can cause CPU fallback.
    args.fsdp = 0
    if args.use_pretrained_lm is not None:
        args.text = args.use_pretrained_lm

    # Set up the model and its components.
    model, vae, (tokenizer, text_encoder, dist, device) = setup_model_and_components(args)

    # Set up the output directory.
    model_name = pathlib.Path(args.checkpoint_path).stem
    sample_dir: pathlib.Path = args.logdir / f'{model_name}'
    if dist.local_rank == 0:
        sample_dir.mkdir(parents=True, exist_ok=True)
    dist.barrier()
    print(f'{" Load ":-^80} {model_name}')

    # Prepare captions and sampling parameters.
    fixed_y, fixed_idxs, num_samples, caption_name = prepare_captions(args, dist)
    print(f'Sampling {num_samples} samples from {args.caption} on {dist.world_size} GPU(s)')
    get_noise = get_noise_shape(args, vae)
    sampling_kwargs = build_sampling_kwargs(args, caption_name)
    noise_std = args.target_noise_std if args.target_noise_std else args.noise_std

    # Start sampling.
    print(f'Starting sampling with global batch size {args.sample_batch_size} x {dist.world_size} GPU(s)')
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.time()

    with torch.no_grad():
        device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Autocast dtype: bfloat16 on CUDA when supported (memory efficient),
        # float16 otherwise, and float32 on CPU.
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            autocast_dtype = torch.bfloat16
        elif torch.cuda.is_available():
            autocast_dtype = torch.float16
        else:
            autocast_dtype = torch.float32

        with torch.autocast(device_type=device_type, dtype=autocast_dtype):
            for i in tqdm.tqdm(range(int(np.ceil(num_samples / (args.sample_batch_size * dist.world_size))))):
                # Determine the aspect ratio and image shape.
                x_aspect = args.aspect_ratio if args.mix_aspect else None
                if x_aspect == "random":
                    x_aspect = np.random.choice([
                        "1:1", "2:3", "3:2", "16:9", "9:16", "4:5", "5:4", "21:9", "9:21"
                    ])
                x_shape = aspect_ratio_to_image_size(
                    args.img_size * vae.downsample_factor, x_aspect,
                    multiple=vae.downsample_factor * args.patch_size
                )

                # Set up text-encoder kwargs.
                text_encoder_kwargs = dict(
                    aspect_ratio=x_aspect,
                    fps=args.out_fps if args.fps_cond else None,
                    noise_std=noise_std if args.cond_noise_level else None
                )

                # Handle video dimensions.
                if args.vid_size is not None:
                    vid_size = tuple(map(int, args.vid_size.split(':')))
                    out_fps = args.out_fps if args.fps_cond else vid_size[1]
                    num_frames = vid_size[0]
                    x_shape = (x_shape[0], x_shape[1], num_frames)
                else:
                    out_fps = args.out_fps

                # Prepare the batch of captions.
                b = args.sample_batch_size
                y = fixed_y[i * b : (i + 1) * b]
                y_caption = copy.deepcopy(y)

                # Append null captions for classifier-free guidance (CFG):
                # each caption is paired with an empty caption.
                if args.cfg > 0:
                    y += [""] * len(y)

                # Encode text and draw noise.
                y = encode_text(text_encoder, tokenizer, y, args.txt_size, device, **text_encoder_kwargs)
                noise = get_noise(len(y_caption), x_shape).to(device)

                # Prepare the input image for image-conditioned generation.
                if args.input_image is not None:
                    input_image = prepare_input_image(args, x_shape, vae, device, noise_std)
                    input_image = repeat(input_image, '1 c h w -> b c h w', b=b)
                    assert args.cfg > 0, "CFG is required for image-conditioned generation"
                    kv_caches = model(input_image.unsqueeze(1), y, context=True)
                else:
                    input_image, kv_caches = None, None

                # Generate samples by inverting the flow.
                samples = model(noise, y, reverse=True, kv_caches=kv_caches, **sampling_kwargs)
                del kv_caches
                torch.cuda.empty_cache()  # free up memory

                # Apply denoising if enabled.
                samples = process_denoising(
                    samples, y_caption, args, model, text_encoder,
                    tokenizer, text_encoder_kwargs, noise_std
                )
                # Decode with the VAE if available.
                if args.vae is not None:
                    dec_fn = vae.decode
                else:
                    dec_fn = lambda x: x
                if isinstance(samples, list):
                    samples = torch.cat([dec_fn(s) for s in samples], dim=-1)
                else:
                    samples = dec_fn(samples)

                # Save samples with the unified helper.
                print(f' Saving samples ... {sample_dir}')
                # Save individual files when sampling from a caption file;
                # otherwise arrange samples into a grid automatically.
                if args.save_folder and args.caption.endswith('.txt'):
                    grid_mode = "individual"
                else:
                    grid_mode = "auto"
                save_samples_unified(
                    samples=samples,
                    save_dir=sample_dir,
                    filename_prefix=caption_name[:200] if len(caption_name) > 0 else "samples",
                    epoch_or_iter=i,
                    fps=out_fps,
                    dist=dist,
                    wandb_log=False,  # sample.py handles its own wandb logging
                    grid_arrangement=grid_mode
                )

    # Print timing statistics.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed_time = time.time() - start_time
    print(f'{model_name} cfg {args.cfg:.2f}, bsz={args.sample_batch_size}x{dist.world_size}, '
          f'time={elapsed_time:.2f}s, speed={num_samples / elapsed_time:.2f} images/s')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Model config
    parser.add_argument('--model_config_path', required=True, type=str, help='path to a YAML config file or a directory containing one')
    parser.add_argument('--checkpoint_path', required=True, type=str, help='path to a local checkpoint file')
    parser.add_argument('--logdir', default='./logs', type=pathlib.Path, help='output directory for generated samples')
    parser.add_argument('--save_folder', default=0, type=int)

    # Caption and conditioning
    parser.add_argument('--caption', type=str, required=True, help='caption text, a template key, or a path to a .txt file with one caption per line')
    parser.add_argument('--input_image', default=None, type=str, help='path to the input image for image-conditioned generation')
    parser.add_argument('--aspect_ratio', default="1:1", type=str, choices=["random", "1:1", "2:3", "3:2", "16:9", "9:16", "4:5", "5:4", "21:9", "9:21"])
    parser.add_argument('--out_fps', default=8, type=int, help='fps for video outputs; only used when fps_cond is set to 1')

    # Sampling parameters
    parser.add_argument('--seed', default=191, type=int)
    parser.add_argument('--denoising_batch_size', default=1, type=int)
    parser.add_argument('--self_denoising_lr', default=1, type=float)
    parser.add_argument('--disable_learnable_denoiser', default=0, type=int)
    parser.add_argument('--attn_temp', default=1, type=float)
    parser.add_argument('--jacobi_th', default=0.005, type=float)
    parser.add_argument('--jacobi', default=0, type=int)
    parser.add_argument('--jacobi_block_size', default=64, type=int)
    parser.add_argument('--jacobi_max_iter', default=32, type=int)
    parser.add_argument('--num_samples', default=50000, type=int)
    parser.add_argument('--sample_batch_size', default=16, type=int)
    parser.add_argument('--return_sequence', default=0, type=int)
    parser.add_argument('--cfg', default=5, type=float)
    parser.add_argument('--guide_top', default=None, type=int)
    parser.add_argument('--finetuned_vae', default="px82zaheuu", type=str)
    parser.add_argument('--vae_adapter', default=None)
    parser.add_argument('--target_noise_std', default=None, type=float, help='option to use a noise_std different from the config')

    # Video-specific parameters
    parser.add_argument('--target_length', default=None, type=int, help='target length; may be longer than at training time')
    parser.add_argument('--context_length', default=16, type=int, help='context length used for consecutive sampling')

    args = parser.parse_args()
    if args.input_image == 'none':
        args.input_image = None
    main(args)