Ensure train.py, dataset.py, and misc/ are included in Space
train.py
ADDED
@@ -0,0 +1,555 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
#!/usr/bin/env python3
"""
Scalable Transformer Autoregressive Flow (STARFlow) Training Script

This script provides functionality for training transformer autoregressive flow models
with support for both image and video generation.

Usage:
    python train.py --model_config_path config.yaml --epochs 100
"""

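# A config passed via --model_config_path is a YAML file whose keys mirror the CLI
# flags defined in get_tarflow_parser() below. Illustrative sketch only (values are
# placeholders; the exact schema is whatever utils.load_model_config expects):
#
#   dataset: dummy
#   img_size: 256
#   patch_size: 2
#   channels: 1024
#   blocks: 4
#   layers_per_block: 8
#   vae: <pretrained VAE name>
#   text: <text encoder name>
#   lr: 1.0e-4
#   batch_size: 128
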
import argparse
import builtins
import pathlib
import copy
import torch
import torch.nn.functional as F
import torchinfo
import torch.amp
import torch.utils
import torch.utils.data
import torchvision as tv
import numpy as np
import random
import transformer_flow
import utils
import time
import contextlib
import tqdm
import os
import gc
import sys
import wandb
import yaml
from typing import Dict, List, Optional, Tuple, Union

from dataset import read_tsv, aspect_ratio_to_image_size
from contextlib import nullcontext
from misc import print  # local_rank=0 print
from utils import simple_denoising, save_samples_unified, add_noise, encode_text, drop_label, load_model_config

# Set environment variables for local development
os.environ["PYTHONPATH"] = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + ":" + os.environ.get("PYTHONPATH", "")
WANDB_API_KEY = os.environ.get("WANDB_API_KEY", None)

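# self_denoise: one gradient step on the flow's negative log-likelihood to clean up
# generated samples. For data perturbed by Gaussian noise of std sigma, Tweedie's
# formula gives the denoised estimate x + sigma^2 * grad_x log p(x); assuming
# get_loss returns the per-element mean NLL (the 65536 factor is manual loss scaling
# that the gradient term divides back out), x - lr * sigma^2 * grad_x loss below is
# effectively one such denoising step.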
def self_denoise(model, samples, y, noise_std=0.1, lr=1, steps=1, disable_learnable_denoiser=False):
    if steps == 0:
        return samples

    outputs = []
    x = samples.clone()
    lr = noise_std ** 2 * lr
    with torch.enable_grad():
        x.requires_grad = True
        model.train()
        z, _, _, logdets = model(x, y)
        loss = model.get_loss(z, logdets)['loss'] * 65536
        grad = float(samples.numel()) / 65536 * torch.autograd.grad(loss, [x])[0]
        outputs += [(x - grad * lr).detach()]
    x = torch.cat(outputs, -1)
    return x

def main(args):
    # Load model configuration if provided
    if hasattr(args, 'model_config_path') and args.model_config_path:
        # Parse sys.argv to see which args were actually provided
        provided_args = set()
        for i, arg in enumerate(sys.argv[1:]):
            if arg.startswith('--'):
                arg_name = arg[2:].replace('-', '_')
                provided_args.add(arg_name)
        trainer_args = load_model_config(args.model_config_path)
        trainer_dict = vars(trainer_args)
        for k, v in vars(args).items():
            if k in provided_args:
                trainer_dict[k] = v
        args = argparse.Namespace(**trainer_dict)
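        # Merge semantics (example): with `lr: 3e-4` and `epochs: 200` in the YAML,
        # `python train.py --model_config_path config.yaml --epochs 100` runs for 100
        # epochs (the explicitly passed flag wins) with lr 3e-4 taken from the config.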

    # global setup
    dist = utils.Distributed()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seed = args.train_seed if args.train_seed is not None else time.time_ns() % 2**32
    utils.set_random_seed(seed + dist.rank)
    print(f'set random seed {seed}')

    if dist.rank == 0 and WANDB_API_KEY is not None:
        job_name = f'{args.dataset}'
        if args.wandb_name is not None:
            wandb_names = args.wandb_name.split('+')
            if len(wandb_names) > 1:
                job_name += f'-{wandb_names[0]}-{getattr(args, wandb_names[1])}'
            else:
                job_name += f'-{wandb_names[0]}'

        wandb.login(key=WANDB_API_KEY)
        wandb.init(project="starflow", name=job_name, config=vars(args))
        wandb.run.save()
        wandb.run.log_code(os.path.dirname(os.path.realpath(__file__)))

    if args.use_pretrained_lm is not None:
        args.text = args.use_pretrained_lm  # need to match the text embedder

    print(f'{" Config ":-^80}')
    for k, v in sorted(vars(args).items()):
        print(f'{k:32s}: {v}')

    # dataset
    data_loader = utils.get_data(args, dist)
    total_num_images = data_loader.dataset.total_num_samples
    grad_accum_steps = max(args.acc, 1)
    num_batches_before_acc = len(data_loader)
    num_batches = num_batches_before_acc // grad_accum_steps

    print(f'{" Dataset Info ":-^80}')
    print(f'{num_batches} batches per epoch ({num_batches_before_acc} steps if consider {grad_accum_steps} accumulation), global batch size {args.batch_size} for {args.epochs} epochs')
    print(f'So it is {num_batches * args.batch_size:,} images per epoch')
    print(f'Target training on {args.batch_size * num_batches * args.epochs:,} images')
    print(f'Total {total_num_images:,} unique training examples')

    assert args.text is not None, "starflow needs text conditioning"

    # text encoder
    tokenizer, text_encoder = utils.setup_encoder(args, dist, device)
    text_encoder.requires_grad_(False)  # freeze text encoder

    # VAE & fixed noise
    if args.vae is not None:
        vae = utils.setup_vae(args, dist, device)
        vae.requires_grad_(False)  # freeze VAE
        args.img_size = args.img_size // vae.downsample_factor

    # main model
    model = utils.setup_transformer(args, dist,
                                    txt_dim=text_encoder.config.hidden_size,
                                    use_checkpoint=args.gradient_checkpoint,
                                    use_checkpoint_mlp=args.gradient_checkpoint_mlp).to(device)
    if dist.local_rank == 0:
        torchinfo.summary(model)

    # Load model before FSDP wrapping to support expansion
    if args.resume_path:
        print(f"Loading checkpoint from local path: {args.resume_path}")
        state_dict = torch.load(args.resume_path, map_location='cpu')
        model.load_state_dict(state_dict, strict=False)
        del state_dict; torch.cuda.empty_cache()
        epoch_start = args.resume_epoch if args.resume_epoch is not None else 0
    else:
        epoch_start = 0

    # setup for training
    model, model_ddp = utils.parallelize_model(args, model, dist, device)
    if args.text and args.fsdp_text_encoder:
        text_encoder = utils.parallelize_model(args, text_encoder, dist, device, [text_encoder.base_block_name])[1]
    trainable_params = [p for k, p in model_ddp.named_parameters() if p.requires_grad and not k.startswith('learnable_self_denoiser')]
    optimizer = torch.optim.AdamW(trainable_params, betas=(0.9, 0.95), lr=args.lr, weight_decay=1e-4)
    warmup_steps = args.warmup_steps if args.warmup_steps is not None else num_batches
    lr_schedule = utils.CosineLRSchedule(
        optimizer, warmup_steps, args.epochs * num_batches, args.min_lr, args.lr)
    if args.learnable_self_denoiser:
        denoiser_optimizer = torch.optim.AdamW(model_ddp.learnable_self_denoiser.parameters(), lr=1e-4, weight_decay=1e-4)
        denoiser_lr_schedule = utils.CosineLRSchedule(
            denoiser_optimizer, warmup_steps, args.epochs * num_batches, 1e-6, 1e-4)
    print('warmup_steps:', warmup_steps, 'num_batches:', num_batches, 'total steps:', args.epochs * num_batches)

    # Adjust learning rate schedule and counters if resuming
    lr_schedule.counter += epoch_start * num_batches
    images_start = epoch_start * num_batches * args.batch_size

    if args.loss_scaling:
        scaler = torch.amp.GradScaler()
        if args.learnable_self_denoiser:
            denoiser_scaler = torch.amp.GradScaler()

    noise_std = args.noise_std
    model_name = f'{args.patch_size}_{args.channels}_{args.blocks}_{args.layers_per_block}_{noise_std:.2f}'
    sample_dir: pathlib.Path = args.logdir / f'{args.dataset}_samples_{model_name}'
    model_ckpt_file = args.logdir / f'{args.dataset}_model_{model_name}.pth'
    opt_ckpt_file = args.logdir / f'{args.dataset}_opt_{model_name}.pth'
    if dist.local_rank == 0:
        sample_dir.mkdir(parents=True, exist_ok=True)

    print(f'{" Training ":-^80}')
    total_steps, total_images, total_training_time = epoch_start * num_batches, images_start, 0
    for epoch in range(epoch_start, args.epochs):
        metrics = utils.Metrics()
        for it, (x, y, meta) in enumerate(data_loader):
            if args.secondary_dataset is not None and random.random() < args.secondary_ratio:
                x, y, meta = next(data_loader.secondary_loader)  # load data from secondary dataset instead
            y_caption = copy.deepcopy(y)
            data_mode = 'image' if (x.dim() == 4) else 'video'

            start_time = time.time()
            x_aspect, video_mode = data_loader.dataset.get_batch_modes(x)
            x = x.to(device)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                # apply VAE over images
                if args.vae is not None:
                    with torch.no_grad():
                        if data_mode == 'video' and args.last_frame_cond:
                            x_last = x[:, -4:-3]  # use the last frame as additional condition
                            x = x[:, :-4]
                            x, x_last = vae.encode(x), vae.encode(x_last)
                            x = torch.cat([x_last, x], 1)
                            y = ["(last) " + desc for desc in y]

                        elif data_mode == 'video' and args.video_to_video:
                            x = torch.cat(x.chunk(2, dim=1)[::-1], 0)  # data is target:source
                            x = vae.encode(x)
                            x = torch.cat(x.chunk(2, dim=0), 1)
                            y = ["(v2v) " + desc for desc in y]

                        else:
                            x = vae.encode(x)

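                # Noise augmentation: the (latent) inputs are perturbed with noise of std
                # `noise_std`, so the flow is trained on a smoothed data distribution;
                # generated samples are pushed back toward the clean distribution by the
                # denoising applied after sampling (see simple_denoising / self_denoise).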
                # add noise to images
                x, _ = add_noise(x, noise_std, args.noise_type)
                if data_mode == 'video' and args.drop_image > 0 and random.random() < args.drop_image:
                    x = x[:, 1:]
                    y = ["(extend) " + desc for desc in y]

                # Enable gradient computation for x
                x.requires_grad_(True)

                # Process labels/text based on model type
                with torch.no_grad():
                    y = encode_text(
                        text_encoder, tokenizer,
                        drop_label(y, args.drop_label),
                        args.txt_size, device,
                        aspect_ratio=x_aspect if args.mix_aspect else None,
                        fps=meta.get('fps', None) if args.fps_cond else None,
                        noise_std=noise_std if args.cond_noise_level else None)

                # main training step
                needs_update = (it + 1) % grad_accum_steps == 0
                needs_zero_grad = it % grad_accum_steps == 0
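                # With --acc N, gradients are accumulated over N consecutive batches:
                # e.g. N=4 zeroes grads at it = 0, 4, 8, ... and steps the optimizer at
                # it = 3, 7, 11, ...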

                if needs_zero_grad:
                    optimizer.zero_grad()

                # main forward
                z, _, outputs, logdets = model_ddp(x, y)
                weights = noise_std / 0.3 if args.cond_noise_level else None
                loss_dict = model.get_loss(z, logdets, weights)
                loss = loss_dict['loss']
                if args.latent_norm_regularization > 0:
                    loss += args.latent_norm_regularization * sum([z.pow(2).mean() for z in outputs[:-1]])
                loss = loss / grad_accum_steps  # use gradient accumulation

                if dist.gather_concat(loss.view(1)).isnan().any():
                    if dist.local_rank == 0:
                        print('nan detected, skipping step')
                    continue

                with utils.sync_ctx(model_ddp, sync=needs_update) if grad_accum_steps > 1 else contextlib.nullcontext():
                    if args.loss_scaling:
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()

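                # Denoiser target (next block): x.grad is the gradient of a loss that was
                # averaged per element, divided by grad_accum_steps, and (with AMP) multiplied
                # by the loss scale. Multiplying by numel, grad_accum_steps and 1/get_scale()
                # undoes those factors, and the extra noise_std factor brings the score target
                # to roughly unit variance, analogous to epsilon targets in diffusion models.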
                # Get gradient of x after backward pass
                if args.learnable_self_denoiser:
                    x_grad = x.grad.clone().detach()  # Clone to preserve the gradient
                    scale = (float(x.numel()) / scaler.get_scale()) if args.loss_scaling else float(x.numel())
                    score = x_grad * scale * grad_accum_steps * noise_std  # roughly std=1, similar to diffusion models
                    pred = model_ddp(x, y, denoiser=True)
                    loss_denoiser = F.mse_loss(pred, score, reduction='mean') / grad_accum_steps
                    loss_dict['loss_denoiser'] = loss_denoiser.item()

                    with utils.sync_ctx(model_ddp, sync=needs_update) if grad_accum_steps > 1 else contextlib.nullcontext():
                        if args.loss_scaling:
                            denoiser_scaler.scale(loss_denoiser).backward()
                        else:
                            loss_denoiser.backward()

            # accumulate time
            total_training_time = total_training_time + (time.time() - start_time)
            if needs_update:
                # Apply gradient clipping and monitor gradient norm
                grad_norm = None
                denoiser_grad_norm = None
                skip_update = False

                if args.grad_clip > 0:
                    if args.loss_scaling:
                        scaler.unscale_(optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, args.grad_clip)
                    skip_update = grad_norm.item() > args.grad_clip if args.grad_skip and (total_steps > 100) else False
                    if args.learnable_self_denoiser:
                        if args.loss_scaling:
                            denoiser_scaler.unscale_(denoiser_optimizer)
                        denoiser_grad_norm = torch.nn.utils.clip_grad_norm_(model_ddp.learnable_self_denoiser.parameters(), args.grad_clip)
                        skip_update = skip_update or (denoiser_grad_norm.item() > args.grad_clip if args.grad_skip and (total_steps > 100) else False)

                if skip_update:
                    print(f'Skipping update due to large gradient norm {grad_norm.item():.4f} > {args.grad_clip:.4f}')
                    optimizer.zero_grad()
                    if args.learnable_self_denoiser:
                        denoiser_optimizer.zero_grad()

                if args.loss_scaling:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                current_lr = lr_schedule.step()

                if not skip_update:
                    metrics.update(loss_dict)

                if args.learnable_self_denoiser:
                    denoiser_lr = denoiser_lr_schedule.step()
                    if args.loss_scaling:
                        denoiser_scaler.step(denoiser_optimizer)
                        denoiser_scaler.update()
                    else:
                        denoiser_optimizer.step()
                    if not skip_update:
                        metrics.update({'loss_denoiser': loss_denoiser.item()})

                total_steps = total_steps + 1
                total_images = total_images + args.batch_size

            # end of training step
            if (it // grad_accum_steps) % 10 == 9:
                speed = (total_images - images_start) / total_training_time
                print(f"{total_steps:,} steps/{total_images:,} images ({speed:0.2f} samples/sec) - \t" + "\t".join(
                    ["{}: {:.4f}".format(k, v) for k, v in loss_dict.items()]))

                if dist.rank == 0 and WANDB_API_KEY is not None:
                    wandb_dict = {'speed': speed, 'steps': total_steps, 'lr': current_lr}
                    if grad_norm is not None:
                        wandb_dict['grad_norm'] = grad_norm.item()
                    if args.learnable_self_denoiser:
                        wandb_dict['denoiser_lr'] = denoiser_lr
                        if denoiser_grad_norm is not None:
                            wandb_dict['denoiser_grad_norm'] = denoiser_grad_norm.item()
                    loss_dict.update(wandb_dict)
                    wandb.log(loss_dict, step=total_images)

            if args.dry_run:
                break

        # metrics_dict = {'lr': current_lr, **metrics.compute(dist)}

        # print metrics
        if False:  # dist.local_rank == 0:
            metrics.print(metrics_dict, epoch + 1)
            print('\tLayer norm', ' '.join([f'{z.pow(2).mean():.4f}' for z in outputs]))
            print('\tLayer stdv', ' '.join([f'{z.std():.4f}' for z in outputs]))
            if dist.rank == 0 and WANDB_API_KEY is not None:
                wandb.log({f'epoch_{k}': v for k, v in metrics_dict.items()}, step=total_images)

        # save model and optimizer state
        if not args.dry_run:
            utils.save_model(args, dist, model, model_ckpt_file)
            if epoch % args.save_every == 0:  # save every 20 epochs
                utils.save_model(args, dist, model, str(model_ckpt_file) + f"_epoch{epoch+1:04d}")
            # utils.save_optimizer(args, dist, optimizer, lr_schedule, opt_ckpt_file)
            dist.barrier()

        # sample images (should i sample?)
        if args.sample_freq > 0 and (epoch % args.sample_freq == 0 or args.dry_run):
            model.eval()

            # Simple sampling using current batch data
            with torch.no_grad():
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    x_aspect = "16:9"
                    x_shape = aspect_ratio_to_image_size(
                        args.img_size * vae.downsample_factor, x_aspect,
                        multiple=vae.downsample_factor
                    )
                    if x.dim() == 5:
                        x_shape = (x.shape[0], 21, x.shape[2], x_shape[0] // vae.downsample_factor, x_shape[1] // vae.downsample_factor)
                    else:
                        x_shape = (x.shape[0], x.shape[1], x_shape[0] // vae.downsample_factor, x_shape[1] // vae.downsample_factor)
                    noise = torch.randn(*x_shape).to(device)
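                    # Classifier-free guidance: the caption list is doubled with empty strings
                    # below, so half of the batch is unconditional; guidance=cfg then blends the
                    # conditional and unconditional paths during reverse sampling.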
                    cfg = 3.5
                    y_caption = ["POV from the boat deck looking at a corgi wearing neon-pink sunglasses; wind noise feel, slight horizon bob, water droplets on lens occasionally, sun reflections flicker on the frames; natural lighting"]
                    y_caption = y_caption + [""] * len(y_caption)
                    sample_y = encode_text(
                        text_encoder, tokenizer, y_caption,
                        args.txt_size, device,
                        aspect_ratio=x_aspect if args.mix_aspect else None,
                        fps=meta.get('fps', [None])[0] if args.fps_cond else None,
                        noise_std=noise_std if args.cond_noise_level else None)

                    # Generate samples
                    samples = model(noise, sample_y, reverse=True, guidance=cfg,
                                    jacobi=1 if noise.dim() == 5 else 0, verbose=True)

                    # Apply self denoising if needed
                    sample_y = sample_y.chunk(2, dim=0)[0]  # Remove null captions for denoising
                    samples = simple_denoising(model, samples, sample_y,
                                               text_encoder, tokenizer, args, noise_std)

                    # Decode with VAE if available
                    if args.vae is not None:
                        samples = vae.decode(samples)

                    # Save samples using unified function
                    save_samples_unified(
                        samples=samples,
                        save_dir=sample_dir,
                        filename_prefix="train_samples",
                        epoch_or_iter=epoch+1,
                        fps=meta.get('fps', [16])[0],
                        dist=dist,
                        wandb_log=WANDB_API_KEY is not None,
                        wandb_step=total_images,
                        grid_arrangement="grid"  # Use simple grid for training
                    )
            model.train()

        if args.dry_run:
            break

    if dist.rank == 0 and WANDB_API_KEY is not None:
        wandb.finish()

def get_tarflow_parser():
    parser = argparse.ArgumentParser()

    # Model config path (same as sample.py)
    parser.add_argument('--model_config_path', default=None, type=str, help='path to YAML config file')

    # Dataset config
    parser.add_argument('--train_seed', default=None, type=int)
    parser.add_argument('--data', default='data', type=pathlib.Path)
    parser.add_argument('--logdir', default='./logs', type=pathlib.Path)
    parser.add_argument('--dataset', default='dummy', type=str)
    parser.add_argument('--wds', default=0, type=int)
    parser.add_argument('--mix_aspect', default=0, type=int)
    parser.add_argument('--img_size', default=32, type=int)
    parser.add_argument('--vid_size', default=None, type=str, help="num_frames:fps1:fps2 for video datasets. If None, image mode")
    parser.add_argument('--fps_cond', default=0, type=int, help="If 1, use fps from video dataset as condition")
    parser.add_argument('--no_flip', default=0, type=int)
    parser.add_argument('--caption_column', default='syn_detailed_description_w_caption', type=str,
                        help="If given 'folder', then extract caption from file name")

    # Optional secondary dataset config
    parser.add_argument('--secondary_dataset', default=None, type=str, help="secondary dataset for training")
    parser.add_argument('--secondary_img_size', default=32, type=int, help="secondary dataset image size")
    parser.add_argument('--secondary_vid_size', default=None, type=str, help="secondary dataset video size")

    # VAE configuration
    parser.add_argument('--vae', default=None, type=str, help="pretrained VAE name")
    parser.add_argument('--vae_decoder_factor', default=1, type=float, help="VAE decoder scaling factor")
    parser.add_argument('--channel_size', default=3, type=int)
    parser.add_argument('--finetuned_vae', default=None, type=str)

    # Text encoder configuration
    parser.add_argument('--text', default=None, type=str, help="text encoder")
    parser.add_argument('--txt_size', default=0, type=int, help="maximum text length")

    # Model configuration
    parser.add_argument('--sos', default=0, type=int)
    parser.add_argument('--seq_order', default="R2L", type=str, choices=['R2L', 'L2R'])
    parser.add_argument('--patch_size', default=4, type=int)
    parser.add_argument('--channels', default=512, type=int)
    parser.add_argument('--top_block_channels', default=None, type=int)
    parser.add_argument('--blocks', default=4, type=int)
    parser.add_argument('--layers_per_block', default=8, type=int, nargs='*')
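    # nargs='*' lets --layers_per_block take either a single depth shared by all blocks
    # or, presumably, one depth per block, e.g. `--blocks 4 --layers_per_block 2 2 8 8`.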
    parser.add_argument('--rope', default=0, type=int)
    parser.add_argument('--pt_seq_len', default=None, type=int)
    parser.add_argument('--adaln', default=0, type=int)
    parser.add_argument('--nvp', default=1, type=int)
    parser.add_argument('--use_softplus', default=0, type=int)
    parser.add_argument('--cond_top_only', default=0, type=int)
    parser.add_argument('--head_dim', default=64, type=int)
    parser.add_argument('--num_heads', default=None, type=int)
    parser.add_argument('--num_kv_heads', default=None, type=int)
    parser.add_argument('--use_swiglu', default=0, type=int)
    parser.add_argument('--use_qk_norm', default=0, type=int)
    parser.add_argument('--use_post_norm', default=0, type=int)
    parser.add_argument('--use_final_norm', default=0, type=int)
    parser.add_argument('--use_bias', default=1, type=int)
    parser.add_argument('--norm_type', default='layer_norm', type=str)
    parser.add_argument('--use_pretrained_lm', default=None, type=str, choices=['gemma3_4b', 'gemma3_1b', 'gemma2_2b'])
    parser.add_argument('--use_mm_attn', default=0, type=int)
    parser.add_argument('--soft_clip', default=0, type=float, help="soft clip the output values")
    parser.add_argument('--learnable_self_denoiser', default=0, type=int, help="Whether to use learnable self-denoiser")
    parser.add_argument('--conditional_denoiser', default=0, type=int, help="conditional denoiser")
    parser.add_argument('--noise_embed_denoiser', default=0, type=int, help="add noise embedding to the denoiser")
    parser.add_argument('--temporal_causal', default=0, type=int, help="Whether to use temporal causal model")
    parser.add_argument('--shallow_block_local', default=0, type=int, help="Whether to use local attention in shallow blocks")
    parser.add_argument('--denoiser_window', default=None, type=int, help="local window size for denoiser")
    parser.add_argument('--local_attn_window', default=None, type=int, help="local attention window size; None disables local attention")

    # Training configuration
    parser.add_argument('--noise_std', default=0.3, type=float)
    parser.add_argument('--noise_type', default='gaussian', choices=['gaussian', 'uniform'], type=str)
    parser.add_argument('--cond_noise_level', default=0, type=int, help="Whether to sample noise level as in diffusion models")
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--secondary_batch_size', default=128, type=int, help="only for secondary dataset")
    parser.add_argument('--secondary_ratio', default=0, type=float, help="a value between 0-1, ratio of using secondary data.")

    parser.add_argument('--acc', default=1, type=int)
    parser.add_argument('--fp8', default=0, type=int, help='Whether to use FP8 training')
    parser.add_argument('--use_8bit_adam', default=0, type=int, help='Whether to use 8-bit Adam optimizer')
    parser.add_argument('--epochs', default=1000, type=int)
    parser.add_argument('--epoch_length', default=50000, type=int)
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--min_lr', default=1e-6, type=float)
    parser.add_argument('--drop_label', default=0, type=float)

    parser.add_argument('--drop_image', default=0, type=float)
    parser.add_argument('--last_frame_cond', default=0, type=int)
    parser.add_argument('--video_to_video', default=0, type=int)

    parser.add_argument('--resume_path', default=None, type=str)
    parser.add_argument('--resume_epoch', default=0, type=int)
    parser.add_argument('--warmup_steps', default=None, type=int, help='Warmup steps for training')

    parser.add_argument('--fsdp', default=0, type=int)
    parser.add_argument('--fsdp_text_encoder', default=0, type=int)
    parser.add_argument('--gradient_checkpoint', default=0, type=int)
    parser.add_argument('--gradient_checkpoint_mlp', default=None, type=int)
    parser.add_argument('--compile', default=0, type=int, help='Whether to use torch.compile')
    parser.add_argument('--latent_norm_regularization', default=0, type=float, help='Regularization on latent norm, 1e-4 is a good value')
    parser.add_argument('--loss_scaling', default=1, type=int, help='Whether to use loss scaling (GradScaler) with AMP')
    parser.add_argument('--grad_clip', default=0, type=float, help='Gradient clipping threshold, 0 to disable')
    parser.add_argument('--grad_skip', default=0, type=int, help='Skip the optimizer update when the gradient norm exceeds grad_clip')
    parser.add_argument('--dry_run', default=0, type=int, help='Dry run for quick tests')
    parser.add_argument('--wandb_name', default=None, type=str, help='Wandb name for the run')
    parser.add_argument('--save_every', default=20, type=int, help='Save model every N epochs')
    parser.add_argument('--sample_freq', default=1, type=int, help="sample every N epochs, 0 to disable")

    # Sampling configuration
    parser.add_argument('--cfg', default=0, type=float, nargs='+')
    parser.add_argument('--num_samples', default=4096, type=int)
    parser.add_argument('--sample_batch_size', default=256, type=int)
    return parser


if __name__ == '__main__':
    parser = get_tarflow_parser()
    args = parser.parse_args()
    main(args)