leoeric committed
Commit 6df4481
1 Parent(s): 0e6f557

Ensure train.py, dataset.py, and misc/ are included in Space

Files changed (1)
  1. train.py +555 -0
train.py ADDED
@@ -0,0 +1,555 @@
#!/usr/bin/env python3
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
"""
Scalable Transformer Autoregressive Flow (STARFlow) Training Script

This script provides functionality for training transformer autoregressive flow models
with support for both image and video generation.

Usage:
    python train.py --model_config_path config.yaml --epochs 100
"""

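# An illustrative sketch (editor's assumption, not a file from this commit) of what a
# config.yaml passed via --model_config_path might contain. The keys are assumed to mirror
# the CLI flags defined in get_tarflow_parser() below, since the loaded config is merged
# with command-line overrides through argparse.Namespace:
#
#   dataset: dummy
#   img_size: 256
#   batch_size: 128
#   lr: 1.0e-4
#   epochs: 100
#   vae: my_pretrained_vae        # hypothetical name
#   text: my_text_encoder         # hypothetical name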
import argparse
import builtins
import pathlib
import copy
import torch
import torch.nn.functional as F
import torchinfo
import torch.amp
import torch.utils
import torch.utils.data
import torchvision as tv
import numpy as np
import random
import transformer_flow
import utils
import time
import contextlib
import tqdm
import os
import gc
import sys
import wandb
import yaml
from typing import Dict, List, Optional, Tuple, Union

from dataset import read_tsv, aspect_ratio_to_image_size
from contextlib import nullcontext
from misc import print  # local_rank=0 print
from utils import simple_denoising, save_samples_unified, add_noise, encode_text, drop_label, load_model_config

# Set environment variables for local development
os.environ["PYTHONPATH"] = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + ":" + os.environ.get("PYTHONPATH", "")
WANDB_API_KEY = os.environ.get("WANDB_API_KEY", None)


def self_denoise(model, samples, y, noise_std=0.1, lr=1, steps=1, disable_learnable_denoiser=False):
    if steps == 0:
        return samples

    outputs = []
    x = samples.clone()
    lr = noise_std ** 2 * lr
    with torch.enable_grad():
        x.requires_grad = True
        model.train()
        z, _, _, logdets = model(x, y)
        loss = model.get_loss(z, logdets)['loss'] * 65536
        grad = float(samples.numel()) / 65536 * torch.autograd.grad(loss, [x])[0]
        outputs += [(x - grad * lr).detach()]
    x = torch.cat(outputs, -1)
    return x
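
# Descriptive note on self_denoise (editor's annotation): it takes one gradient step on the
# flow negative log-likelihood, x <- x - noise_std**2 * lr * grad_x NLL(x), which for Gaussian
# noise corruption is effectively a Tweedie-style denoising update x + sigma**2 * grad_x log p(x).
# The *65536 / /65536 pair acts as a fixed loss scale for numerical stability, and the
# samples.numel() factor turns the mean per-dimension loss gradient back into a total
# log-density gradient.
# Illustrative call (variable names assumed): samples = self_denoise(model, samples, text_emb, noise_std=args.noise_std)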


def main(args):
    # Load model configuration if provided
    if hasattr(args, 'model_config_path') and args.model_config_path:
        # Parse sys.argv to see which args were actually provided
        provided_args = set()
        for i, arg in enumerate(sys.argv[1:]):
            if arg.startswith('--'):
                arg_name = arg[2:].replace('-', '_')
                provided_args.add(arg_name)
        trainer_args = load_model_config(args.model_config_path)
        trainer_dict = vars(trainer_args)
        for k, v in vars(args).items():
            if k in provided_args:
                trainer_dict[k] = v
        args = argparse.Namespace(**trainer_dict)
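
    # Precedence (as implemented above): the YAML config supplies defaults, and only flags
    # explicitly passed on the command line override it. Illustrative example:
    #   python train.py --model_config_path config.yaml --lr 3e-5
    # keeps every value from config.yaml except lr.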

    # global setup
    dist = utils.Distributed()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seed = args.train_seed if args.train_seed is not None else time.time_ns() % 2**32
    utils.set_random_seed(seed + dist.rank)
    print(f'set random seed {seed}')

    if dist.rank == 0 and WANDB_API_KEY is not None:
        job_name = f'{args.dataset}'
        if args.wandb_name is not None:
            wandb_names = args.wandb_name.split('+')
            if len(wandb_names) > 1:
                job_name += f'-{wandb_names[0]}-{getattr(args, wandb_names[1])}'
            else:
                job_name += f'-{wandb_names[0]}'

        wandb.login(key=WANDB_API_KEY)
        wandb.init(project="starflow", name=job_name, config=vars(args))
        wandb.run.save()
        wandb.run.log_code(os.path.dirname(os.path.realpath(__file__)))

    if args.use_pretrained_lm is not None:
        args.text = args.use_pretrained_lm  # need to match the text embedder

    print(f'{" Config ":-^80}')
    for k, v in sorted(vars(args).items()):
        print(f'{k:32s}: {v}')

    # dataset
    data_loader = utils.get_data(args, dist)
    total_num_images = data_loader.dataset.total_num_samples
    grad_accum_steps = max(args.acc, 1)
    num_batches_before_acc = len(data_loader)
    num_batches = num_batches_before_acc // grad_accum_steps
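    # Illustrative arithmetic (editor's example): with 1,000 dataloader steps per epoch and
    # --acc 4, num_batches = 1000 // 4 = 250 optimizer updates per epoch; per the prints
    # below, args.batch_size is counted as the global batch size per update, so one epoch
    # covers 250 * args.batch_size images.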

    print(f'{" Dataset Info ":-^80}')
    print(f'{num_batches} batches per epoch ({num_batches_before_acc} dataloader steps with {grad_accum_steps}-step gradient accumulation), global batch size {args.batch_size} for {args.epochs} epochs')
    print(f'That is {num_batches * args.batch_size:,} images per epoch')
    print(f'Target training on {args.batch_size * num_batches * args.epochs:,} images')
    print(f'Total {total_num_images:,} unique training examples')

    assert args.text is not None, "starflow needs text conditioning"

    # text encoder
    tokenizer, text_encoder = utils.setup_encoder(args, dist, device)
    text_encoder.requires_grad_(False)  # freeze text encoder

    # VAE & fixed noise
    if args.vae is not None:
        vae = utils.setup_vae(args, dist, device)
        vae.requires_grad_(False)  # freeze VAE
        args.img_size = args.img_size // vae.downsample_factor

    # main model
    model = utils.setup_transformer(args, dist,
                                    txt_dim=text_encoder.config.hidden_size,
                                    use_checkpoint=args.gradient_checkpoint,
                                    use_checkpoint_mlp=args.gradient_checkpoint_mlp).to(device)
    if dist.local_rank == 0:
        torchinfo.summary(model)

    # Load model before FSDP wrapping to support expansion
    if args.resume_path:
        print(f"Loading checkpoint from local path: {args.resume_path}")
        state_dict = torch.load(args.resume_path, map_location='cpu')
        model.load_state_dict(state_dict, strict=False)
        del state_dict; torch.cuda.empty_cache()
        epoch_start = args.resume_epoch if args.resume_epoch is not None else 0
    else:
        epoch_start = 0

    # setup for training
    model, model_ddp = utils.parallelize_model(args, model, dist, device)
    if args.text and args.fsdp_text_encoder:
        text_encoder = utils.parallelize_model(args, text_encoder, dist, device, [text_encoder.base_block_name])[1]
    trainable_params = [p for k, p in model_ddp.named_parameters() if p.requires_grad and not k.startswith('learnable_self_denoiser')]
    optimizer = torch.optim.AdamW(trainable_params, betas=(0.9, 0.95), lr=args.lr, weight_decay=1e-4)
    warmup_steps = args.warmup_steps if args.warmup_steps is not None else num_batches
    lr_schedule = utils.CosineLRSchedule(
        optimizer, warmup_steps, args.epochs * num_batches, args.min_lr, args.lr)
    if args.learnable_self_denoiser:
        denoiser_optimizer = torch.optim.AdamW(model_ddp.learnable_self_denoiser.parameters(), lr=1e-4, weight_decay=1e-4)
        denoiser_lr_schedule = utils.CosineLRSchedule(
            denoiser_optimizer, warmup_steps, args.epochs * num_batches, 1e-6, 1e-4)
    print('warmup_steps:', warmup_steps, 'num_batches:', num_batches, 'total steps:', args.epochs * num_batches)
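
    # Assumed semantics of utils.CosineLRSchedule, read off its constructor arguments:
    # linear warmup for `warmup_steps` updates, then cosine decay from args.lr down to
    # args.min_lr over the full args.epochs * num_batches updates (e.g. 250 updates/epoch
    # for 100 epochs gives 25,000 scheduled steps).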

    # Adjust learning rate schedule and counters if resuming
    lr_schedule.counter += epoch_start * num_batches
    images_start = epoch_start * num_batches * args.batch_size

    if args.loss_scaling:
        scaler = torch.amp.GradScaler()
        if args.learnable_self_denoiser:
            denoiser_scaler = torch.amp.GradScaler()

    noise_std = args.noise_std
    model_name = f'{args.patch_size}_{args.channels}_{args.blocks}_{args.layers_per_block}_{noise_std:.2f}'
    sample_dir: pathlib.Path = args.logdir / f'{args.dataset}_samples_{model_name}'
    model_ckpt_file = args.logdir / f'{args.dataset}_model_{model_name}.pth'
    opt_ckpt_file = args.logdir / f'{args.dataset}_opt_{model_name}.pth'
    if dist.local_rank == 0:
        sample_dir.mkdir(parents=True, exist_ok=True)

    print(f'{" Training ":-^80}')
    total_steps, total_images, total_training_time = epoch_start * num_batches, images_start, 0
    for epoch in range(epoch_start, args.epochs):
        metrics = utils.Metrics()
        for it, (x, y, meta) in enumerate(data_loader):
            if args.secondary_dataset is not None and random.random() < args.secondary_ratio:
                x, y, meta = next(data_loader.secondary_loader)  # load data from secondary dataset instead
            y_caption = copy.deepcopy(y)
            data_mode = 'image' if (x.dim() == 4) else 'video'

            start_time = time.time()
            x_aspect, video_mode = data_loader.dataset.get_batch_modes(x)
            x = x.to(device)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                # apply VAE over images
                if args.vae is not None:
                    with torch.no_grad():
                        if data_mode == 'video' and args.last_frame_cond:
                            x_last = x[:, -4:-3]  # use the last frame as additional condition
                            x = x[:, :-4]
                            x, x_last = vae.encode(x), vae.encode(x_last)
                            x = torch.cat([x_last, x], 1)
                            y = ["(last) " + desc for desc in y]

                        elif data_mode == 'video' and args.video_to_video:
                            x = torch.cat(x.chunk(2, dim=1)[::-1], 0)  # data is target:source
                            x = vae.encode(x)
                            x = torch.cat(x.chunk(2, dim=0), 1)
                            y = ["(v2v) " + desc for desc in y]

                        else:
                            x = vae.encode(x)

                # add noise to images
                x, _ = add_noise(x, noise_std, args.noise_type)
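                # Note: the flow is trained on noise-augmented inputs (std = noise_std),
                # which is why sampling below follows the reverse pass with a denoising
                # step (simple_denoising / the learnable self-denoiser).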
                if data_mode == 'video' and args.drop_image > 0 and random.random() < args.drop_image:
                    x = x[:, 1:]
                    y = ["(extend) " + desc for desc in y]

                # Enable gradient computation for x
                x.requires_grad_(True)

                # Process labels/text based on model type
                with torch.no_grad():
                    y = encode_text(
                        text_encoder, tokenizer,
                        drop_label(y, args.drop_label),
                        args.txt_size, device,
                        aspect_ratio=x_aspect if args.mix_aspect else None,
                        fps=meta.get('fps', None) if args.fps_cond else None,
                        noise_std=noise_std if args.cond_noise_level else None)

                # main training step
                needs_update = (it + 1) % grad_accum_steps == 0  # update at the end of each accumulation window
                needs_zero_grad = it % grad_accum_steps == 0

                if needs_zero_grad:
                    optimizer.zero_grad()

                # main forward
                z, _, outputs, logdets = model_ddp(x, y)
                weights = noise_std / 0.3 if args.cond_noise_level else None
                loss_dict = model.get_loss(z, logdets, weights)
                loss = loss_dict['loss']
                if args.latent_norm_regularization > 0:
                    loss += args.latent_norm_regularization * sum([z.pow(2).mean() for z in outputs[:-1]])
                loss = loss / grad_accum_steps  # use gradient accumulation

                if dist.gather_concat(loss.view(1)).isnan().any():
                    if dist.local_rank == 0:
                        print('nan detected, skipping step')
                    continue

                with utils.sync_ctx(model_ddp, sync=needs_update) if grad_accum_steps > 1 else contextlib.nullcontext():
                    if args.loss_scaling:
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()

                # Get gradient of x after backward pass
                if args.learnable_self_denoiser:
                    x_grad = x.grad.clone().detach()  # Clone to preserve the gradient
                    scale = (float(x.numel()) / scaler.get_scale()) if args.loss_scaling else float(x.numel())
                    score = x_grad * scale * grad_accum_steps * noise_std  # roughly std=1, similar to diffusion models
                    pred = model_ddp(x, y, denoiser=True)
                    loss_denoiser = F.mse_loss(pred, score, reduction='mean') / grad_accum_steps
                    loss_dict['loss_denoiser'] = loss_denoiser.item()

                    with utils.sync_ctx(model_ddp, sync=needs_update) if grad_accum_steps > 1 else contextlib.nullcontext():
                        if args.loss_scaling:
                            denoiser_scaler.scale(loss_denoiser).backward()
                        else:
                            loss_denoiser.backward()
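
                # Note on the denoiser target (editor's annotation): x.grad holds the gradient of
                # the (scaled) NLL w.r.t. the noisy input, so `score` estimates the negative data
                # score -grad_x log p(x), unscaled from the GradScaler and accumulation factors and
                # multiplied by noise_std to have roughly unit std; the learnable self-denoiser is
                # regressed onto it with MSE, analogous to a diffusion-style denoiser.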

            # accumulate time
            total_training_time = total_training_time + (time.time() - start_time)
            if needs_update:
                # Apply gradient clipping and monitor gradient norm
                grad_norm = None
                denoiser_grad_norm = None
                skip_update = False

                if args.grad_clip > 0:
                    if args.loss_scaling:
                        scaler.unscale_(optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, args.grad_clip)
                    skip_update = grad_norm.item() > args.grad_clip if args.grad_skip and (total_steps > 100) else False
                    if args.learnable_self_denoiser:
                        if args.loss_scaling:
                            denoiser_scaler.unscale_(denoiser_optimizer)
                        denoiser_grad_norm = torch.nn.utils.clip_grad_norm_(model_ddp.learnable_self_denoiser.parameters(), args.grad_clip)
                        skip_update = skip_update or (denoiser_grad_norm.item() > args.grad_clip if args.grad_skip and (total_steps > 100) else False)

                if skip_update:
                    print(f'Skipping update due to large gradient norm {grad_norm.item():.4f} > {args.grad_clip:.4f}')
                    optimizer.zero_grad()
                    if args.learnable_self_denoiser:
                        denoiser_optimizer.zero_grad()

                if args.loss_scaling:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                current_lr = lr_schedule.step()

                if not skip_update:
                    metrics.update(loss_dict)

                if args.learnable_self_denoiser:
                    denoiser_lr = denoiser_lr_schedule.step()
                    if args.loss_scaling:
                        denoiser_scaler.step(denoiser_optimizer)
                        denoiser_scaler.update()
                    else:
                        denoiser_optimizer.step()
                    if not skip_update:
                        metrics.update({'loss_denoiser': loss_denoiser.item()})

                total_steps = total_steps + 1
                total_images = total_images + args.batch_size

            # end of training step
            if (it // grad_accum_steps) % 10 == 9:
                speed = (total_images - images_start) / total_training_time
                print(f"{total_steps:,} steps/{total_images:,} images ({speed:0.2f} samples/sec) - \t" + "\t".join(
                    ["{}: {:.4f}".format(k, v) for k, v in loss_dict.items()]))

                if dist.rank == 0 and WANDB_API_KEY is not None:
                    wandb_dict = {'speed': speed, 'steps': total_steps, 'lr': current_lr}
                    if grad_norm is not None:
                        wandb_dict['grad_norm'] = grad_norm.item()
                    if args.learnable_self_denoiser:
                        wandb_dict['denoiser_lr'] = denoiser_lr
                        if denoiser_grad_norm is not None:
                            wandb_dict['denoiser_grad_norm'] = denoiser_grad_norm.item()
                    loss_dict.update(wandb_dict)
                    wandb.log(loss_dict, step=total_images)

            if args.dry_run:
                break

        metrics_dict = {'lr': current_lr, **metrics.compute(dist)}

        # print metrics
        if False:  # dist.local_rank == 0:
            metrics.print(metrics_dict, epoch + 1)
            print('\tLayer norm', ' '.join([f'{z.pow(2).mean():.4f}' for z in outputs]))
            print('\tLayer stdv', ' '.join([f'{z.std():.4f}' for z in outputs]))
        if dist.rank == 0 and WANDB_API_KEY is not None:
            wandb.log({f'epoch_{k}': v for k, v in metrics_dict.items()}, step=total_images)

        # save model and optimizer state
        if not args.dry_run:
            utils.save_model(args, dist, model, model_ckpt_file)
            if epoch % args.save_every == 0:  # save a named snapshot every args.save_every epochs
                utils.save_model(args, dist, model, str(model_ckpt_file) + f"_epoch{epoch+1:04d}")
            # utils.save_optimizer(args, dist, optimizer, lr_schedule, opt_ckpt_file)
            dist.barrier()

        # periodically sample images / videos
        if args.sample_freq > 0 and (epoch % args.sample_freq == 0 or args.dry_run):
            model.eval()

            # Simple sampling using current batch data
            with torch.no_grad():
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    x_aspect = "16:9"
                    x_shape = aspect_ratio_to_image_size(
                        args.img_size * vae.downsample_factor, x_aspect,
                        multiple=vae.downsample_factor
                    )
                    if x.dim() == 5:
                        x_shape = (x.shape[0], 21, x.shape[2], x_shape[0] // vae.downsample_factor, x_shape[1] // vae.downsample_factor)
                    else:
                        x_shape = (x.shape[0], x.shape[1], x_shape[0] // vae.downsample_factor, x_shape[1] // vae.downsample_factor)
                    noise = torch.randn(*x_shape).to(device)
                    cfg = 3.5
                    y_caption = ["POV from the boat deck looking at a corgi wearing neon-pink sunglasses; wind noise feel, slight horizon bob, water droplets on lens occasionally, sun reflections flicker on the frames; natural lighting"]
                    y_caption = y_caption + [""] * len(y_caption)
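                    # Note: the appended empty strings are the unconditional (null-caption) inputs
                    # for classifier-free guidance with scale `cfg`; the null half is dropped again
                    # below before the denoising step.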
                    sample_y = encode_text(
                        text_encoder, tokenizer, y_caption,
                        args.txt_size, device,
                        aspect_ratio=x_aspect if args.mix_aspect else None,
                        fps=meta.get('fps', [None])[0] if args.fps_cond else None,
                        noise_std=noise_std if args.cond_noise_level else None)

                    # Generate samples
                    samples = model(noise, sample_y, reverse=True, guidance=cfg,
                                    jacobi=1 if noise.dim() == 5 else 0, verbose=True)

                    # Apply self denoising if needed
                    sample_y = sample_y.chunk(2, dim=0)[0]  # Remove null captions for denoising
                    samples = simple_denoising(model, samples, sample_y,
                                               text_encoder, tokenizer, args, noise_std)

                    # Decode with VAE if available
                    if args.vae is not None:
                        samples = vae.decode(samples)

                    # Save samples using unified function
                    save_samples_unified(
                        samples=samples,
                        save_dir=sample_dir,
                        filename_prefix="train_samples",
                        epoch_or_iter=epoch + 1,
                        fps=meta.get('fps', [16])[0],
                        dist=dist,
                        wandb_log=WANDB_API_KEY is not None,
                        wandb_step=total_images,
                        grid_arrangement="grid"  # Use simple grid for training
                    )
            model.train()

        if args.dry_run:
            break

    if dist.rank == 0 and WANDB_API_KEY is not None:
        wandb.finish()


def get_tarflow_parser():
    parser = argparse.ArgumentParser()

    # Model config path (same as sample.py)
    parser.add_argument('--model_config_path', default=None, type=str, help='path to YAML config file')

    # Dataset config
    parser.add_argument('--train_seed', default=None, type=int)
    parser.add_argument('--data', default='data', type=pathlib.Path)
    parser.add_argument('--logdir', default='./logs', type=pathlib.Path)
    parser.add_argument('--dataset', default='dummy', type=str)
    parser.add_argument('--wds', default=0, type=int)
    parser.add_argument('--mix_aspect', default=0, type=int)
    parser.add_argument('--img_size', default=32, type=int)
    parser.add_argument('--vid_size', default=None, type=str, help="num_frames:fps1:fps2 for video datasets. If None, image mode")
    parser.add_argument('--fps_cond', default=0, type=int, help="If 1, use fps from video dataset as condition")
    parser.add_argument('--no_flip', default=0, type=int)
    parser.add_argument('--caption_column', default='syn_detailed_description_w_caption', type=str,
                        help="If given 'folder', then extract caption from file name")

    # Optional secondary dataset config
    parser.add_argument('--secondary_dataset', default=None, type=str, help="secondary dataset for training")
    parser.add_argument('--secondary_img_size', default=32, type=int, help="secondary dataset image size")
    parser.add_argument('--secondary_vid_size', default=None, type=str, help="secondary dataset video size")

    # VAE configuration
    parser.add_argument('--vae', default=None, type=str, help="pretrained VAE name")
    parser.add_argument('--vae_decoder_factor', default=1, type=float, help="VAE decoder scaling factor")
    parser.add_argument('--channel_size', default=3, type=int)
    parser.add_argument('--finetuned_vae', default=None, type=str)

    # Text encoder configuration
    parser.add_argument('--text', default=None, type=str, help="text encoder")
    parser.add_argument('--txt_size', default=0, type=int, help="maximum text length")

    # Model configuration
    parser.add_argument('--sos', default=0, type=int)
    parser.add_argument('--seq_order', default="R2L", type=str, choices=['R2L', 'L2R'])
    parser.add_argument('--patch_size', default=4, type=int)
    parser.add_argument('--channels', default=512, type=int)
    parser.add_argument('--top_block_channels', default=None, type=int)
    parser.add_argument('--blocks', default=4, type=int)
    parser.add_argument('--layers_per_block', default=8, type=int, nargs='*')
    parser.add_argument('--rope', default=0, type=int)
    parser.add_argument('--pt_seq_len', default=None, type=int)
    parser.add_argument('--adaln', default=0, type=int)
    parser.add_argument('--nvp', default=1, type=int)
    parser.add_argument('--use_softplus', default=0, type=int)
    parser.add_argument('--cond_top_only', default=0, type=int)
    parser.add_argument('--head_dim', default=64, type=int)
    parser.add_argument('--num_heads', default=None, type=int)
    parser.add_argument('--num_kv_heads', default=None, type=int)
    parser.add_argument('--use_swiglu', default=0, type=int)
    parser.add_argument('--use_qk_norm', default=0, type=int)
    parser.add_argument('--use_post_norm', default=0, type=int)
    parser.add_argument('--use_final_norm', default=0, type=int)
    parser.add_argument('--use_bias', default=1, type=int)
    parser.add_argument('--norm_type', default='layer_norm', type=str)
    parser.add_argument('--use_pretrained_lm', default=None, type=str, choices=['gemma3_4b', 'gemma3_1b', 'gemma2_2b'])
    parser.add_argument('--use_mm_attn', default=0, type=int)
    parser.add_argument('--soft_clip', default=0, type=float, help="soft clip the output values")
    parser.add_argument('--learnable_self_denoiser', default=0, type=int, help="Whether to use learnable self-denoiser")
    parser.add_argument('--conditional_denoiser', default=0, type=int, help="conditional denoiser")
    parser.add_argument('--noise_embed_denoiser', default=0, type=int, help="add noise embedding to the denoiser")
    parser.add_argument('--temporal_causal', default=0, type=int, help="Whether to use temporal causal model")
    parser.add_argument('--shallow_block_local', default=0, type=int, help="Whether to use local attention in shallow blocks")
    parser.add_argument('--denoiser_window', default=None, type=int, help="local window size for denoiser")
    parser.add_argument('--local_attn_window', default=None, type=int, help="Whether to use local attention")

    # Training configuration
    parser.add_argument('--noise_std', default=0.3, type=float)
    parser.add_argument('--noise_type', default='gaussian', choices=['gaussian', 'uniform'], type=str)
    parser.add_argument('--cond_noise_level', default=0, type=int, help="Whether to sample noise level as in diffusion models")
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--secondary_batch_size', default=128, type=int, help="only for secondary dataset")
    parser.add_argument('--secondary_ratio', default=0, type=float, help="a value between 0-1, ratio of using secondary data.")

    parser.add_argument('--acc', default=1, type=int)
    parser.add_argument('--fp8', default=0, type=int, help='Whether to use FP8 training')
    parser.add_argument('--use_8bit_adam', default=0, type=int, help='Whether to use 8-bit Adam optimizer')
    parser.add_argument('--epochs', default=1000, type=int)
    parser.add_argument('--epoch_length', default=50000, type=int)
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--min_lr', default=1e-6, type=float)
    parser.add_argument('--drop_label', default=0, type=float)

    parser.add_argument('--drop_image', default=0, type=float)
    parser.add_argument('--last_frame_cond', default=0, type=int)
    parser.add_argument('--video_to_video', default=0, type=int)

    parser.add_argument('--resume_path', default=None, type=str)
    parser.add_argument('--resume_epoch', default=0, type=int)
    parser.add_argument('--warmup_steps', default=None, type=int, help='Warmup steps for training')

    parser.add_argument('--fsdp', default=0, type=int)
    parser.add_argument('--fsdp_text_encoder', default=0, type=int)
    parser.add_argument('--gradient_checkpoint', default=0, type=int)
    parser.add_argument('--gradient_checkpoint_mlp', default=None, type=int)
    parser.add_argument('--compile', default=0, type=int, help='Whether to use torch.compile')
    parser.add_argument('--latent_norm_regularization', default=0, type=float, help='Regularization on latent norm, 1e-4 is a good value')
    parser.add_argument('--loss_scaling', default=1, type=int, help='Whether to use loss scaling (torch.amp.GradScaler)')
    parser.add_argument('--grad_clip', default=0, type=float, help='Gradient clipping threshold, 0 to disable')
    parser.add_argument('--grad_skip', default=0, type=int, help='Skip the update step when the gradient norm exceeds grad_clip')
    parser.add_argument('--dry_run', default=0, type=int, help='Dry run for quick tests')
    parser.add_argument('--wandb_name', default=None, type=str, help='Wandb name for the run')
    parser.add_argument('--save_every', default=20, type=int, help='Save model every N epochs')
    parser.add_argument('--sample_freq', default=1, type=int, help="sample every N epochs, 0 to disable")

    # Sampling configuration
    parser.add_argument('--cfg', default=0, type=float, nargs='+')
    parser.add_argument('--num_samples', default=4096, type=int)
    parser.add_argument('--sample_batch_size', default=256, type=int)
    return parser
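

# Illustrative invocations (editor's sketch; values are examples, not prescriptions):
#   image training:   python train.py --model_config_path config.yaml --dataset dummy --img_size 256
#   video training:   python train.py --model_config_path config.yaml --vid_size 21:8:16 --fps_cond 1
#   quick sanity run: add --dry_run 1 to run a single training step and one sampling pass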


if __name__ == '__main__':
    parser = get_tarflow_parser()
    args = parser.parse_args()
    main(args)