#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
import copy
import tqdm
import numpy as np
import torch
import torch.nn.functional as F

from typing import List, Tuple
from misc.pe import VisionRotaryEmbeddingFast, apply_rope, get_positions
from misc import print
from functools import partial
from einops import rearrange, repeat
from torch.utils.checkpoint import checkpoint

INV_SOFTPLUS_1 = 0.541324854612918


def modulate(x, shift, scale):
    if shift is None:
        return x * (1 + scale)
    return x * (1 + scale) + shift


def stable_neg_log_softplus(x):
    return torch.where(
        x > 20,                 # softplus(x) ≈ x → log ≈ log(x)
        -x.log(),               # so -log(softplus(x)) ≈ -log(x)
        -F.softplus(x).log()
    )


class KVCache:
    def __init__(self):
        self._is_empty = True
        self.prefix_cache = None
        self.meta_data = {}

    def initialize(self, num_blocks, *size):
        self._is_empty = False
        self.num_blocks = num_blocks
        self.size = size
        self.kv_caches = [torch.zeros(2, *size) for _ in range(num_blocks)]
        self.kv_index = [0] * num_blocks

    def register_prefix_cache(self, prefix_cache):
        self.prefix_cache = prefix_cache

    @property
    def is_empty(self):
        return self._is_empty

    @property
    def is_full(self):
        if self.is_empty:
            return False
        return all(index == self.size[2] for index in self.kv_index)

    def delete(self):
        if not self.is_empty:
            self._is_empty = True
            del self.kv_caches
            del self.kv_index

    def to(self, device, dtype=torch.bfloat16):
        for i in range(self.num_blocks):
            self.kv_caches[i] = self.kv_caches[i].to(device=device, dtype=dtype)

    def extend_length(self, length):
        assert not self.is_empty, "KVCache is empty, cannot extend length"
        self.size = (self.size[0], self.size[1], self.size[2] + length, self.size[3])
        for i in range(self.num_blocks):
            pad = self.kv_caches[i].new_zeros((2, *self.size))
            pad[:, :, :, :self.kv_caches[i].size(3)] = self.kv_caches[i]
            self.kv_caches[i] = pad

    def expand_batch(self, ratio=2):
        self.size = (self.size[0] * ratio, *self.size[1:])
        for i in range(self.num_blocks):
            self.kv_caches[i] = torch.cat([self.kv_caches[i] for _ in range(ratio)], dim=1)

    def remove_negative_cache(self):
        self.size = (self.size[0] // 2, *self.size[1:])
        for i in range(self.num_blocks):
            self.kv_caches[i] = self.kv_caches[i].chunk(2, dim=1)[0]

    def backward_in_time(self, l):
        for i in range(self.num_blocks):
            self.kv_index[i] = max(0, self.kv_index[i] - l)

    def reset_kv_index(self):
        for i in range(self.num_blocks):
            self.kv_index[i] = 0

    def __call__(self, block_idx, k, v):
        assert block_idx < self.num_blocks, f'block_idx {block_idx} out of range {self.num_blocks}'
        # write cache
        l = k.size(2)
        kv_index = self.kv_index[block_idx]
        if kv_index + l > self.size[2]:
            raise NotImplementedError("Overflow mode is not implemented")
        self.kv_caches[block_idx][0][:, :, kv_index: kv_index+l] = k
        self.kv_caches[block_idx][1][:, :, kv_index: kv_index+l] = v
        self.kv_index[block_idx] = kv_index + l

        # read cache
        kv_index = self.kv_index[block_idx]
        return self.kv_caches[block_idx][0][:, :, :kv_index], self.kv_caches[block_idx][1][:, :, :kv_index]

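
# --- Illustrative sketch (not used by the model): how KVCache is written and read. ---
# Shapes follow `initialize(num_blocks, B, H, T_max, D)`; each call appends the new
# keys/values for one block and returns everything cached so far for that block.
def _kv_cache_usage_sketch():
    cache = KVCache()
    cache.initialize(1, 2, 4, 16, 8)           # 1 block, B=2, H=4, T_max=16, D=8
    k_new = torch.randn(2, 4, 3, 8)            # 3 new timesteps of keys
    v_new = torch.randn(2, 4, 3, 8)
    k_all, v_all = cache(0, k_new, v_new)      # write, then read the full prefix
    assert k_all.shape == (2, 4, 3, 8) and cache.kv_index[0] == 3
    return k_all, v_all
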
class Permutation(torch.nn.Module):
    def __init__(self, seq_length: int):
        super().__init__()
        self.seq_length = seq_length
        self.input_shape = None

    def forward(self, x: torch.Tensor | List[torch.Tensor], dim: int = 1, inverse: bool = False):
        if not inverse:
            self.input_shape = x.shape
            x = rearrange(x, 'b t h w c -> b (t h w) c' if x.dim() == 5 else 'b h w c -> b (h w) c')
            x = self.permute(x, dim, self.input_shape, inverse=False)
        else:
            x = self.permute(x, dim, self.input_shape, inverse=True)
            x = x.reshape(-1, *self.input_shape[1:])
        return x

    def permute(self, x: torch.Tensor, dim: int = 1, shape=None, inverse: bool = False) -> torch.Tensor:
        raise NotImplementedError('Overload me')


class PermutationIdentity(Permutation):
    def permute(self, x: torch.Tensor, dim: int = 1, shape=None, inverse: bool = False) -> torch.Tensor:
        return x.clone()


class PermutationFlip(Permutation):
    def permute(self, x: torch.Tensor, dim: int = 1, shape=None, inverse: bool = False) -> torch.Tensor:
        return x.flip(dims=[dim])


class PermutationFlipInBlock(Permutation):
    def permute(self, x: torch.Tensor, dim: int = 1, shape=None, inverse: bool = False) -> torch.Tensor:
        assert shape is not None, "shape must be provided for PermutationFlipInBlock"
        if len(shape) == 5:
            assert dim == 1, "dim must be 1 for 5D tensor in PermutationFlipInBlock"
            # flip the tensor within blocks of size `block_size`, globally still in the same order
            x = x.view(x.size(0), shape[1], -1, x.size(-1)).flip(dims=[2]).view_as(x)
        else:
            x = x.flip(dims=[dim])
        return x


class RMSNorm(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = True,
    ):
        super().__init__()
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        self.weight = torch.nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        output = self._norm(x.float())
        if self.add_unit_offset:
            output = output * (1 + self.weight.float())
        else:
            output = output * self.weight.float()
        return output.type_as(x)

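
# --- Illustrative sketch (not used by the model): forward/inverse permutation round-trip. ---
# The forward pass flattens the spatial grid into a sequence and permutes it; the inverse
# call undoes the permutation and restores the original layout, which the flow blocks rely on.
def _permutation_roundtrip_sketch():
    perm = PermutationFlip(seq_length=16)
    x = torch.randn(2, 4, 4, 8)                # (b, h, w, c)
    z = perm(x)                                # (b, 16, 8), sequence order reversed
    x_back = perm(z, inverse=True)             # back to (b, 4, 4, 8)
    assert torch.allclose(x, x_back)
    return x_back
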
class Attention(torch.nn.Module):
    def __init__(self, in_channels: int, head_channels: int, norm_type: str = "layer_norm",
                 num_heads=None, num_kv_heads=None, use_qk_norm=False, use_post_norm=False,
                 use_bias=True, hf_style_rope=False, non_causal=False):
        super().__init__()
        if norm_type == "layer_norm":
            self.norm = torch.nn.LayerNorm(in_channels)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(in_channels)
        else:
            self.norm = torch.nn.Identity()
        self.head_channels = head_channels
        self.num_heads = num_heads if num_heads is not None else in_channels // head_channels
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else self.num_heads  # GQA
        self.q_size = self.num_heads * head_channels
        self.kv_size = self.num_kv_heads * head_channels
        self.qkv = torch.nn.Linear(in_channels, self.q_size + 2 * self.kv_size, bias=use_bias)
        self.proj = torch.nn.Linear(self.q_size, in_channels, bias=use_bias)
        self.query_norm = (RMSNorm(self.head_channels) if use_qk_norm else None)
        self.key_norm = (RMSNorm(self.head_channels) if use_qk_norm else None)
        self.post_norm = (RMSNorm(in_channels) if use_post_norm else None)
        self.sqrt_scale = head_channels ** (-0.25)
        self.hf_style_rope = hf_style_rope
        self.non_causal = non_causal

    def apply_rope(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        if self.hf_style_rope:
            return rearrange(apply_rope(rearrange(x, '... (u d) -> ... (d u)', u=2), freqs_cis), '... (d u) -> ... (u d)', u=2)
        return apply_rope(x, freqs_cis)

    def prepare_for_attention(self, x: torch.Tensor, freqs_cis=None, kv_cache=None):
        B, T, _ = x.size()
        q, k, v = self.qkv(self.norm(x)).split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = q.view(B, T, self.num_heads, self.head_channels).transpose(1, 2)     # (b, h, t, d)
        k = k.view(B, T, self.num_kv_heads, self.head_channels).transpose(1, 2)  # (b, h, t, d)
        v = v.view(B, T, self.num_kv_heads, self.head_channels).transpose(1, 2)  # (b, h, t, d)

        if self.query_norm is not None and self.key_norm is not None:
            q, k = self.query_norm(q), self.key_norm(k)
        if kv_cache is not None:
            k, v = kv_cache(k, v)
        if freqs_cis is not None:
            lq, lk = q.size(2), k.size(2)
            q, k = self.apply_rope(q, freqs_cis[lk-lq:lk]), self.apply_rope(k, freqs_cis[:lk])
        if self.num_kv_heads != self.num_heads:  # GQA (b, h, t, d)
            k = torch.repeat_interleave(k, self.num_heads // self.num_kv_heads, dim=1)
            v = torch.repeat_interleave(v, self.num_heads // self.num_kv_heads, dim=1)
        return q.to(x.dtype), k.to(x.dtype), v.to(x.dtype)

    def output_after_attention(self, x: torch.Tensor):
        B, _, T, _ = x.shape
        x = x.transpose(1, 2).reshape(B, T, self.q_size)
        x = self.proj(x)
        if self.post_norm is not None:
            x = self.post_norm(x)
        return x

    def apply_attention(self, q, k, v, mask=None, temp=1.0):
        scale = self.sqrt_scale**2 / temp
        is_causal = not self.non_causal
        if is_causal and q.size(2) < k.size(2) and mask is None:
            prefix_len = k.size(2) - q.size(2)
            mask = torch.tril(torch.ones(q.size(2), k.size(2), device=q.device, dtype=torch.bool), diagonal=prefix_len)
        if mask is not None:
            mask = mask.bool()
            is_causal = False
        # sdpa
        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, is_causal=is_causal, scale=scale)
        return x

    def forward(self,
                x: torch.Tensor,
                mask: torch.Tensor | None = None,
                temp: float = 1.0,
                freqs_cis=None,
                kv_cache=None,
                ) -> torch.Tensor:
        q, k, v = self.prepare_for_attention(x, freqs_cis, kv_cache)
        x = self.apply_attention(q, k, v, mask, temp)
        x = self.output_after_attention(x)
        return x


class MLP(torch.nn.Module):
    def __init__(self, channels: int, expansion: float, use_swiglu=False, norm_type="layer_norm",
                 use_post_norm=False, use_bias=True):
        super().__init__()
        if norm_type == "layer_norm":
            self.norm = torch.nn.LayerNorm(channels)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(channels)
        else:
            self.norm = torch.nn.Identity()
        self.post_norm = (RMSNorm(channels) if use_post_norm else None)
        self.use_swiglu = use_swiglu
        intermediate_channels = int(channels * expansion)
        if use_swiglu:
            self.gate_proj = torch.nn.Linear(channels, intermediate_channels, bias=use_bias)
            self.up_proj = torch.nn.Linear(channels, intermediate_channels, bias=use_bias)
            self.down_proj = torch.nn.Linear(intermediate_channels, channels, bias=use_bias)
        else:
            self.main = torch.nn.Sequential(
                torch.nn.Linear(channels, intermediate_channels, bias=use_bias),
                torch.nn.GELU(),
                torch.nn.Linear(intermediate_channels, channels, bias=use_bias)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_swiglu:
            x = self.norm(x)
            x = self.down_proj(F.gelu(self.gate_proj(x), approximate='tanh') * self.up_proj(x))
        else:
            x = self.main(self.norm(x))
        return self.post_norm(x) if self.post_norm is not None else x

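
# --- Illustrative sketch (not used by the model): grouped-query attention (GQA) expansion. ---
# With fewer KV heads than query heads, each cached K/V head is shared by a group of query
# heads; `repeat_interleave` along the head dimension makes the shapes match before SDPA.
def _gqa_expansion_sketch():
    attn = Attention(in_channels=64, head_channels=16, num_heads=4, num_kv_heads=2)
    x = torch.randn(2, 5, 64)                  # (b, t, c)
    q, k, v = attn.prepare_for_attention(x)    # k, v expanded from 2 to 4 heads
    assert q.shape == k.shape == v.shape == (2, 4, 5, 16)
    return attn.apply_attention(q, k, v)
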
class AttentionBlock(torch.nn.Module):
    def __init__(self, channels: int, head_channels: int, expansion: float = 4, use_adaln: bool = False,
                 use_swiglu=False, norm_type="layer_norm", num_heads=None, num_kv_heads=None,
                 use_qk_norm=False, use_post_norm=False, use_bias=True, hf_style_rope=False, non_causal=False):
        super().__init__()
        if use_adaln:
            self.adaLN_modulation = torch.nn.Sequential(
                torch.nn.SiLU(),
                torch.nn.Linear(channels, 4 * channels, bias=True),
            )
            self.norm1 = torch.nn.LayerNorm(channels, elementwise_affine=False, eps=1e-6)
            self.norm2 = torch.nn.LayerNorm(channels, elementwise_affine=False, eps=1e-6)
            torch.nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
            torch.nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
            # Hard-coded norm_type=="none" for adaLN
            norm_type = 'none'
        else:
            self.adaLN_modulation = None
        self.attention = Attention(channels, head_channels, norm_type, num_heads, num_kv_heads,
                                   use_qk_norm, use_post_norm, use_bias, hf_style_rope, non_causal)
        self.mlp = MLP(channels, expansion, use_swiglu, norm_type, use_post_norm, use_bias)

    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor | None = None,
        attn_mask: torch.Tensor | None = None,
        attn_temp: float = 1.0,
        c=None,
        freqs_cis=None,
        kv_cache=None,
        checkpoint_attn: bool = False,
        checkpoint_mlp: bool = False
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        assert (x is not None) or (y is not None), "x or y must be provided"
        z = torch.cat([y, x], 1) if (x is not None) and (y is not None) else x if x is not None else y
        if self.adaLN_modulation is not None and c is not None:
            shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(c).chunk(4, dim=-1)
            z = z + self._forward_attention(z, attn_mask, attn_temp, freqs_cis, kv_cache, checkpoint_attn, shift_msa, scale_msa)
            z = z + self._forward_mlp(z, checkpoint_mlp, shift_mlp, scale_mlp)
        else:
            z = z + self._forward_attention(z, attn_mask, attn_temp, freqs_cis, kv_cache, checkpoint_attn)
            z = z + self._forward_mlp(z, checkpoint_mlp)
        x, y = (z[:, y.size(1):], z[:, :y.size(1)]) if (x is not None) and (y is not None) \
            else (z, None) if x is not None else (None, z)
        return x, y

    def _forward_attention(self, z, attn_mask, attn_temp, freqs_cis, kv_cache, checkpoint_attn, shift=None, scale=None):
        def attn_fn(z_in):
            if shift is not None and scale is not None:
                z_in = modulate(self.norm1(z_in), shift, scale)
            return self.attention(z_in, attn_mask, attn_temp, freqs_cis, kv_cache)
        return checkpoint(attn_fn, z, use_reentrant=False) if checkpoint_attn and self.training else attn_fn(z)

    def _forward_mlp(self, z, checkpoint_mlp, shift=None, scale=None):
        def mlp_fn(z_in):
            if shift is not None and scale is not None:
                z_in = modulate(self.norm2(z_in), shift, scale)
            return self.mlp(z_in)
        return checkpoint(mlp_fn, z, use_reentrant=False) if checkpoint_mlp and self.training else mlp_fn(z)

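
# --- Illustrative sketch (not used by the model): joint text/image attention in AttentionBlock. ---
# Text tokens `y` are prepended to image tokens `x`, both streams share one attention + MLP
# residual update, and the block splits them back so `y` can keep conditioning later blocks.
def _attention_block_sketch():
    block = AttentionBlock(channels=64, head_channels=16)
    x = torch.randn(2, 10, 64)                 # image tokens
    y = torch.randn(2, 3, 64)                  # text tokens
    x_out, y_out = block(x, y)
    assert x_out.shape == (2, 10, 64) and y_out.shape == (2, 3, 64)
    return x_out, y_out
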
class MetaBlock(torch.nn.Module):
    attn_mask: torch.Tensor

    def __init__(
        self,
        in_channels: int,
        channels: int,
        img_size: int,
        permutation: Permutation,
        pt_seq_len: int | None = None,
        num_layers: int = 1,
        head_dim: int = 64,
        num_heads: None | int = None,
        num_kv_heads: None | int = None,
        txt_size: int = 0,
        txt_dim: int = 0,
        expansion: float = 4,
        use_rope: bool = False,
        use_sos: bool = False,
        use_softplus: bool = False,
        use_swiglu: bool = False,
        use_qk_norm: bool = False,
        use_post_norm: bool = False,
        use_final_norm: bool = False,
        use_bias: bool = True,
        use_proj_txt: bool = True,
        hf_style_rope: bool = False,
        norm_type: str = "layer_norm",
        use_mm_attn: bool = False,
        use_checkpoint: int = False,
        use_checkpoint_mlp: int = None,
        soft_clip: float = 0,
        local_attn_window: int = None,
    ):
        super().__init__()
        out_channels = in_channels * 2
        self.proj_in = torch.nn.Linear(in_channels, channels)
        self.proj_out = torch.nn.Linear(channels, out_channels)
        if use_sos:
            self.sos_embed = torch.nn.Parameter(torch.randn(1, 1, in_channels))
        torch.nn.init.constant_(self.proj_out.weight, 0)
        self.txt_size = txt_size
        self.img_size = img_size
        self.txt_dim = txt_dim
        self.pt_seq_len = pt_seq_len or img_size

        # KV cache configurations
        num_kv_heads = num_kv_heads or (num_heads or channels // head_dim)
        self.kv_cache_size = [num_kv_heads, head_dim]

        if not use_rope:
            self.pos_embed = torch.nn.Parameter(torch.randn(img_size ** 2, channels) * 1e-2)
        else:
            self.pos_embed = None
        if txt_dim > 0:
            self.proj_txt = torch.nn.Linear(txt_dim, channels) if use_proj_txt else torch.nn.Identity()
            assert use_proj_txt or (txt_dim == channels), 'text dimension must equal channels when not using projection'
        self.attn_blocks = torch.nn.ModuleList(
            [AttentionBlock(channels, head_dim, expansion, False, use_swiglu, norm_type,
                            num_heads, num_kv_heads, use_qk_norm, use_post_norm, use_bias, hf_style_rope)
             for _ in range(num_layers)])
        self.use_final_norm = use_final_norm
        if use_final_norm:
            self.final_norm = RMSNorm(channels)
        self.use_softplus = use_softplus
        self.permutation = permutation
        self.use_checkpoint = use_checkpoint
        self.use_checkpoint_mlp = use_checkpoint_mlp
        self.use_sos = use_sos
        self.soft_clip = soft_clip
        self.local_attn_window = local_attn_window
        self.block_masks = {}  # for local attention

        # ---- DEPRECATED: do not pass mask to enable flash attention ----- For compatibility -----
        # self.register_buffer('attn_mask', torch.tril(torch.ones(pt_seq_len ** 2 + txt_size, pt_seq_len ** 2 + txt_size)))

    def get_freqs_cis(self, x, y, rope):
        # get the input shape
        h, w = x.size(-3), x.size(-2)
        d = x.size(1) if x.dim() == 5 else 0
        txt_size = y.size(1) if self.txt_size > 0 and y is not None else 0
        if not rope.is_1d:  # prepare 2D RoPE
            if self.txt_size > 0 or d > 0:  # prepare 3D RoPE
                if self.txt_dim > 0:  # text is conditioned
                    pos = get_positions(h, w, txt_size, rope.pt_seq_len, d, mode='3d')
                else:  # text is not conditioned
                    pos = get_positions(h, w, 0, rope.pt_seq_len, d, mode='3d')
            else:
                pos = get_positions(h, w, 0, rope.pt_seq_len, mode='2d')
        else:  # prepare 1D RoPE
            pos = get_positions(h, w, txt_size, rope.pt_seq_len, mode='1d')
        return rope(pos.type_as(x))

    def get_sos_embed(self, x):
        sos_embed = self.sos_embed.expand(x.size(0), -1, -1)
        return sos_embed

    def get_prepared(self, x):
        # input, output, freqs_cis
        x_in = x.clone()
        if self.use_sos:
            # add SOS token, predict the first token sos->x_in[0]
            x = torch.cat([self.get_sos_embed(x), x[:, :-1]], dim=1)
        return x_in, x

    def get_proj_in(self, x):
        x = self.proj_in(x)
        return x

    def get_proj_out(self, x):
        x = self.proj_out(x)
        if hasattr(self, "soft_clip") and self.soft_clip > 0:
            x = self.soft_clip * torch.tanh(x / self.soft_clip)
        return x

    def get_local_window_mask(self, x, y):
        _, T, H, W, _ = x.shape
        L = y.size(1) if y is not None else 0
        B = H * W
        N = T * B
        S = L + N
        G = self.local_attn_window

        def mask(q, k):
            return (k <= q) & ((k < L) | ((k - L) // B > (q - L) // B - G))
        return mask(torch.arange(S, device=x.device)[:, None], torch.arange(S, device=x.device)[None, :])
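    # Note on get_local_window_mask (sketch of the intended semantics): with L text tokens,
    # B tokens per frame, and window G, a query token attends causally (k <= q) to all text
    # tokens and to image tokens whose frame index satisfies frame(k) > frame(q) - G, i.e.
    # only the last G frames, keeping attention cost roughly linear in the number of frames.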
    def initialize_kv_cache(self, kv_cache, x, freqs_cis, reuse_kv_cache=False):
        if self.local_attn_window is not None and self.local_attn_window > 0:
            video_frame_size = x.size(-3) * x.size(-2)
            kv_cache_length = self.local_attn_window * video_frame_size
            kv_cache_length += self.txt_size if self.txt_dim > 0 else 0
            kv_cache.meta_data.update(
                {"frame_size": video_frame_size, "txt_size": self.txt_size + 1 if self.txt_dim > 0 else 0})
        else:
            kv_cache_length = freqs_cis.size(0)
        kv_cache_size = (x.size(0), self.kv_cache_size[0], kv_cache_length, self.kv_cache_size[1])
        if kv_cache.is_empty:
            kv_cache.initialize(len(self.attn_blocks), *kv_cache_size)
            kv_cache.to(x.device, x.dtype)
        else:
            target_size = kv_cache_size[-2]
            if reuse_kv_cache:
                target_size = target_size - kv_cache.kv_index[0]
            kv_cache.extend_length(target_size)
        return kv_cache

    def forward(self, x: torch.Tensor | List[torch.Tensor], y: torch.Tensor | None = None, rope=None, kv_cache=None, guidance=None):
        freqs_cis = self.get_freqs_cis(x, y, rope) if rope is not None else None
        attn_mask = None
        if kv_cache is not None:
            kv_cache = self.initialize_kv_cache(kv_cache, x, freqs_cis)

        x = self.permutation(x)
        pos_embed = self.permutation(self.pos_embed, dim=0) if self.pos_embed is not None else None

        # prepare input
        x_in, x = self.get_prepared(x)
        if kv_cache is not None:
            kv_cache.register_prefix_cache(x_in)

        # input projection
        x = self.get_proj_in(x)
        if pos_embed is not None:
            x = x + pos_embed

        # conditioning
        if self.txt_dim > 0:
            y = self.proj_txt(y)
        else:
            y = None

        # main block
        for it, block in enumerate(self.attn_blocks):
            _kv_cache = partial(kv_cache, it) if kv_cache is not None else None
            # Frequency-based checkpointing strategy:
            # - Checkpoint attention every use_checkpoint blocks (if use_checkpoint > 0)
            # - Checkpoint MLP every use_checkpoint_mlp blocks (if provided), otherwise every use_checkpoint blocks
            checkpoint_attn = self.training and self.use_checkpoint > 0 and ((it + 1) % self.use_checkpoint == 0)
            if self.use_checkpoint_mlp is not None:
                checkpoint_mlp = self.training and self.use_checkpoint_mlp > 0 and ((it + 1) % self.use_checkpoint_mlp == 0)
            else:
                checkpoint_mlp = self.training and self.use_checkpoint > 0 and ((it + 1) % self.use_checkpoint == 0)
            x, y = block(x, y, attn_mask, 1.0, None, freqs_cis, _kv_cache,
                         checkpoint_attn=checkpoint_attn, checkpoint_mlp=checkpoint_mlp)

        # final norm
        if self.use_final_norm:
            x, y = self.final_norm(x), self.final_norm(y) if y is not None else None

        x = self.get_proj_out(x)
        if not self.use_sos:
            # no SOS token, we need to shift the sequence
            x = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        xa, xb = x.chunk(2, dim=-1)

        # Store original dtype for output conversion
        original_dtype = xa.dtype
        # Convert to fp32 for numerical stability
        xa, xb, x_in = xa.float(), xb.float(), x_in.float()

        if not self.use_softplus:
            xa = xa.exp()
        else:
            xa = F.softplus(xa + INV_SOFTPLUS_1)
        if guidance is not None and guidance > 0:
            xb, xa = self.guidance(xa, xb, guidance, 1.0, 'ab')

        # NOTE: this "scale" is in fact 1/sigma, not sigma
        x = self.permutation((x_in - xb) / xa, inverse=True)
        logdet = -torch.log(xa)  # keep all the dimensions

        # Convert back to original precision
        x = x.to(original_dtype)
        return x, y, logdet

    def guidance(self, za, zb, guidance, r=1.0, guide_what='ab'):
        za, za_u = [torch.cat([a, a]) for a in za.chunk(2, dim=0)]
        zb, zb_u = [torch.cat([a, a]) for a in zb.chunk(2, dim=0)]
        g = r * guidance

        def logits_guided(mu_c, sigma_c, mu_u, sigma_u, w):
            # inspired by: (1+w) * logP_cond - w * logP_uncond
            # sigma_c = torch.minimum(sigma_c, sigma_u)
            s = (sigma_c / sigma_u).clip(max=1.0).square()
            sigma_eff = sigma_c / (1 + w - w * s).sqrt()
            mu_eff = ((1 + w) * mu_c - (w * s) * mu_u) / (1 + w - w * s)
            return mu_eff, sigma_eff

        def original_guidance(mu_c, sigma_c, mu_u, sigma_u, w):
            if 'a' in guide_what:
                sigma_c = sigma_c + g * (sigma_c - sigma_u)
            if 'b' in guide_what:
                mu_c = mu_c + g * (mu_c - mu_u)
            return mu_c, sigma_c

        # zb, za = original_guidance(zb, za, zb_u, za_u, guidance)
        zb, za = logits_guided(zb, za, zb_u, za_u, guidance)
        return zb, za
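    # Note on `guidance` (derivation sketch): `logits_guided` samples from the tilted density
    # p_c(x)^{1+w} / p_u(x)^{w}. For Gaussians N(mu_c, sigma_c^2) and N(mu_u, sigma_u^2) this is
    # again Gaussian with, writing s = (sigma_c / sigma_u)^2 (clipped to <= 1 for stability):
    #     sigma_eff = sigma_c / sqrt(1 + w - w * s)
    #     mu_eff    = ((1 + w) * mu_c - w * s * mu_u) / (1 + w - w * s)
    # which is what the code computes; `original_guidance` is the simpler linear extrapolation
    # variant kept for reference.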
    def reverse_step(
        self,
        x: torch.Tensor,
        t: int,
        kv_cache: KVCache,
        pos_embed: torch.Tensor | None = None,
        y: torch.Tensor | None = None,
        attn_temp: float = 1.0,
        freqs_cis=None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Store original dtype for sampling tensor
        original_dtype = x.dtype
        if self.use_sos:
            # get i-th patch but keep the sequence dimension
            x_in = self.get_sos_embed(x[:, :1]) if t == 0 else x[:, t - 1 : t]
        else:
            x_in = x[:, t : t + 1]

        # Convert to model's dtype for neural network computation
        if hasattr(self.proj_in, 'weight'):
            target_dtype = self.proj_in.weight.dtype
            x_in = x_in.to(target_dtype)

        x = self.get_proj_in(x_in)
        # if positional embedding
        if pos_embed is not None:
            x = x + pos_embed[t: t+1]

        # main block
        for i, block in enumerate(self.attn_blocks):
            x, _ = block(x, None, attn_temp=attn_temp, freqs_cis=freqs_cis, kv_cache=partial(kv_cache, i))

        # final norm
        if self.use_final_norm:
            x = self.final_norm(x)

        x = self.get_proj_out(x)
        xa, xb = x.chunk(2, dim=-1)
        # Convert back to original dtype for sampling computations
        return xa.to(original_dtype), xb.to(original_dtype)

    def reverse_step_condition(self, y, kv_cache, pos_embed=None, attn_temp: float = 1.0, freqs_cis=None):
        # Convert to model's dtype for neural network computation
        if hasattr(self.proj_txt, 'weight'):
            target_dtype = self.proj_txt.weight.dtype
            y = y.to(target_dtype)
        y = self.proj_txt(y)
        for i, block in enumerate(self.attn_blocks):
            _, y = block(None, y, attn_temp=attn_temp, freqs_cis=freqs_cis, kv_cache=partial(kv_cache, i))
        return y
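    # Note on `reverse` (sketch): sampling inverts the forward map token by token in the
    # permuted order. Given noise z and the (mu_t, sigma_t) predicted from already-generated
    # tokens (via the KV cache), the next token is x_t = sigma_t * z_t + mu_t, which is the
    # inverse of the forward update z_t = (x_t - mu_t) / sigma_t.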
    def reverse(
        self,
        z: torch.Tensor,
        y: torch.Tensor | None = None,
        guidance: float = 0,
        guide_what: str = 'ab',
        attn_temp: float = 1.0,
        annealed_guidance: bool = False,
        rope=None,
        verbose=False,
        kv_cache: KVCache = KVCache(),
        **unused_kwargs
    ) -> torch.Tensor:
        # Ensure sampling tensors are in float32 for numerical stability
        original_dtype = z.dtype
        z = z.float()

        freqs_cis = self.get_freqs_cis(z, y, rope) if rope is not None else None
        if guidance > 0:
            z = torch.cat([z, z], 0)

        # kv cache
        reuse_kv_cache = kv_cache.prefix_cache is not None and kv_cache.kv_index[0] > 0
        kv_cache = self.initialize_kv_cache(kv_cache, z, freqs_cis, reuse_kv_cache)

        # permute the input
        z = self.permutation(z)
        pos_embed = self.permutation(self.pos_embed, dim=0) if self.pos_embed is not None else None

        # run additional text condition, results will be used in KV cache.
        if self.txt_dim > 0:
            if not reuse_kv_cache:
                self.reverse_step_condition(y, kv_cache, pos_embed, attn_temp, freqs_cis)
        txt_size = y.size(1) if self.txt_dim > 0 else 0

        # run the reverse process
        x = z.clone()
        if reuse_kv_cache:
            x[:, :kv_cache.prefix_cache.size(1)] = kv_cache.prefix_cache  # fill the prefix cache
        T = x.size(1) - 1 if not self.use_sos else x.size(1)
        for t in tqdm.trange(T, disable=not verbose, desc='Sub-flow Sampling', leave=False):
            if reuse_kv_cache and kv_cache.kv_index[0] > t + txt_size:
                continue
            za, zb = self.reverse_step(x, t, kv_cache, pos_embed, y, attn_temp, freqs_cis)
            # Ensure sampling computations stay in float32
            za, zb = za.float(), zb.float()
            if not self.use_softplus:
                za, zb = za.exp().squeeze(1), zb.squeeze(1)
            else:
                za, zb = F.softplus(za + INV_SOFTPLUS_1).squeeze(1), zb.squeeze(1)
            if guidance > 0 and guide_what:
                r = (t + 1) / T if annealed_guidance else 1.0
                zb, za = self.guidance(za, zb, guidance, r, guide_what)
            if self.use_sos:
                x[:, t] = z[:, t] * za + zb
            else:
                x[:, t + 1] = z[:, t + 1] * za + zb

        if guidance > 0:
            x = x.chunk(2, dim=0)[0]
            kv_cache.remove_negative_cache()  # remove the second half of the cache
        x = self.permutation(x, inverse=True)
        # Convert back to original dtype if needed
        return x.to(original_dtype)
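    # Note on `jacobi` (sketch): instead of strictly sequential decoding, a block of tokens is
    # updated in parallel with x_t = sigma_t * z_t + mu_t, and the update is iterated until the
    # relative change falls below `jacobi_th` (or `jacobi_max_iter` is reached). Because each
    # token only depends on earlier tokens, the converged fixed point matches sequential decoding.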
    def jacobi(self, z: torch.Tensor, y: torch.Tensor | None = None, guidance: float = 0,
               rope=None, kv_cache=None, verbose=False,
               jacobi_block_size: int = 32, jacobi_max_iter: int = 32, jacobi_th: float = 0.001,
               context_length: int = None, **unused_kwargs) -> torch.Tensor:
        assert self.use_sos, "Jacobi iteration requires SOS token to be used"
        assert self.pos_embed is None, "Jacobi iteration does not support positional embedding"

        # Ensure sampling tensors are in float32 for numerical stability
        original_dtype = z.dtype
        z = z.float()

        freqs_cis = self.get_freqs_cis(z, y, rope) if rope is not None else None
        if guidance > 0:
            z = torch.cat([z, z], 0)

        # kv cache
        reuse_kv_cache = kv_cache.prefix_cache is not None and kv_cache.kv_index[0] > 0
        kv_cache = self.initialize_kv_cache(kv_cache, z, freqs_cis, reuse_kv_cache)
        video_length = z.size(1) if z.dim() == 5 else 1

        # permute the input
        z = self.permutation(z)

        # prepare input
        x_full = torch.cat([self.get_sos_embed(z), z.clone()], dim=1)
        if reuse_kv_cache:
            x_full[:, 1: kv_cache.prefix_cache.size(1) + 1] = kv_cache.prefix_cache  # fill the prefix cache

        # conditioning
        if self.txt_dim > 0:
            if not reuse_kv_cache:
                self.reverse_step_condition(y, kv_cache, freqs_cis=freqs_cis)
        txt_size = y.size(1) if self.txt_dim > 0 else 0
        video_frame_size = z.size(1) // video_length

        start_idx = 0
        if reuse_kv_cache:
            start_idx = kv_cache.kv_index[0] - txt_size  # start from the last cached index
        prog_bar = tqdm.tqdm(total=z.size(1), disable=not verbose, desc='Block-wise Jacobi Iteration', leave=False)
        prog_bar.update(start_idx)

        local_attn_window = self.local_attn_window * video_frame_size if self.local_attn_window is not None else None
        target_frame_size = z.size(1) if local_attn_window is None else min(z.size(1), local_attn_window)
        context_size = None if local_attn_window is None else context_length * video_frame_size

        while target_frame_size <= z.size(1):
            while start_idx < target_frame_size:
                chunk_size = jacobi_block_size if start_idx <= video_frame_size else jacobi_block_size * 4
                local_done = torch.zeros((), dtype=torch.bool, device=x_full.device)
                for i in tqdm.tqdm(range(jacobi_max_iter), disable=True, desc='Jacobi Iteration', leave=False):
                    if start_idx + chunk_size >= target_frame_size:
                        chunk_size = target_frame_size - start_idx
                    if i == 0 and start_idx > video_frame_size:
                        # optional to use past frame to initialize the current frame
                        x = x_full[:, start_idx - video_frame_size: start_idx + chunk_size - video_frame_size]
                    else:
                        x = x_full[:, start_idx: start_idx + chunk_size]

                    # main forward - convert to model dtype for neural network computation
                    if hasattr(self.proj_in, 'weight'):
                        target_dtype = self.proj_in.weight.dtype
                        x = x.to(target_dtype)
                    x = self.get_proj_in(x)
                    for it, block in enumerate(self.attn_blocks):
                        _kv_cache = partial(kv_cache, it) if kv_cache is not None else None
                        x = block(x, None, freqs_cis=freqs_cis, kv_cache=_kv_cache)[0]
                    if self.use_final_norm:
                        x = self.final_norm(x)
                    x = self.get_proj_out(x)
                    xa, xb = x.chunk(2, dim=-1)

                    # Convert back to float32 for sampling computations
                    xa, xb = xa.float(), xb.float()
                    if not self.use_softplus:
                        xa = xa.exp()
                    else:
                        xa = F.softplus(xa + INV_SOFTPLUS_1)
                    if guidance > 0:
                        xb, xa = self.guidance(xa, xb, guidance, 1.0, 'ab')

                    # compute the Jacobi Iteration - all in float32
                    new_x = xb + xa * z[:, start_idx: start_idx+chunk_size]
                    diff = ((new_x - x_full[:, start_idx+1: start_idx+1+chunk_size]) ** 2).mean() / (new_x ** 2).mean()
                    x_full[:, start_idx+1: start_idx+1+chunk_size] = new_x

                    if diff < jacobi_th or i == jacobi_max_iter - 1:
                        # do not clean the cache on the last iteration
                        local_done.fill_(1)
                    global_done = local_done.clone()
                    torch.distributed.all_reduce(global_done, op=torch.distributed.ReduceOp.MIN)
                    if int(global_done.item()) == 1:
                        break
                    kv_cache.backward_in_time(chunk_size)

                start_idx += chunk_size
                prog_bar.update(chunk_size)

            if target_frame_size >= z.size(1):
                break

            target_frame_size += local_attn_window - context_size if local_attn_window is not None else video_frame_size
            target_frame_size = min(target_frame_size, z.size(1))

            # re-encode the context with attention blocks
            print(f're-encoding the context {start_idx+1-context_size}:{start_idx+1}')
            kv_cache.reset_kv_index()
            if self.txt_dim > 0:
                self.reverse_step_condition(y, kv_cache, freqs_cis=freqs_cis)
            x_context = x_full[:, start_idx+1-context_size: start_idx+1]
            x_context_in, x_context = self.get_prepared(x_context)
            x_context = self.get_proj_in(x_context)
            for it, block in enumerate(self.attn_blocks):
                _kv_cache = partial(kv_cache, it) if kv_cache is not None else None
                x_context = block(x_context, None, freqs_cis=freqs_cis, kv_cache=_kv_cache)[0]

        x = x_full[:, 1:]  # remove SOS token
        if guidance > 0:
            x = x.chunk(2, dim=0)[0]
        x = self.permutation(x, inverse=True)
        # Convert back to original dtype if needed
        return x.to(original_dtype)


class IdentityBlock(MetaBlock):
    def __init__(self, *args, **kwargs):
        super(MetaBlock, self).__init__()

    def forward(self, x, y=None, rope=None, **unused):
        return x, y, x.new_zeros(x.size(0))

    def reverse(self, z: torch.Tensor, y: torch.Tensor | None = None, guidance: float = 0,
                guide_what: str = 'ab', attn_temp: float = 1.0, annealed_guidance: bool = False,
                rope=None, verbose=False, kv_cache: KVCache = KVCache(), **unused):
        # Preserve original dtype
        return z

    def jacobi(self, z: torch.Tensor, y: torch.Tensor | None = None, guidance: float = 0,
               rope=None, kv_cache=None, verbose=False,
               jacobi_block_size: int = 64, jacobi_th: float = 0.005, **unused_kwargs) -> torch.Tensor:
        return z

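
# --- Illustrative sketch (not used by the model): the forward pass of one MetaBlock. ---
# A MetaBlock applies one autoregressive affine update per token and returns the transformed
# latents together with the per-dimension log-determinant. RoPE mode is used here so no
# positional-embedding table is created; the shapes below are made up for illustration.
def _metablock_forward_sketch():
    block = MetaBlock(in_channels=8, channels=32, img_size=4,
                      permutation=PermutationFlip(16), num_layers=1, head_dim=16,
                      use_rope=True)
    block.eval()
    x = torch.randn(2, 4, 4, 8)                # (b, h, w, c)
    with torch.no_grad():
        z, _, logdet = block(x)                # one invertible affine-AR update
    assert z.shape == x.shape and logdet.shape == (2, 16, 8)
    return z, logdet
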
class NonCausalBlock(MetaBlock):
    def __init__(
        self,
        in_channels: int,
        channels: int,
        img_size: int,
        pt_seq_len: int | None = None,
        num_layers: int = 8,
        head_dim: int = 64,
        num_heads: None | int = None,
        num_kv_heads: None | int = None,
        txt_size: int = 0,
        txt_dim: int = 0,
        expansion: float = 4,
        use_rope: bool = False,
        use_swiglu: bool = False,
        use_qk_norm: bool = False,
        use_post_norm: bool = False,
        use_final_norm: bool = False,
        use_bias: bool = True,
        hf_style_rope: bool = False,
        norm_type: str = "layer_norm",
        use_checkpoint: int = False,
        use_checkpoint_mlp: int = None,
        block_causal: int = 0,
        window: int = None,
        **unused_kwargs,
    ):
        super(MetaBlock, self).__init__()
        out_channels = in_channels
        self.proj_in = torch.nn.Linear(in_channels, channels)
        self.proj_out = torch.nn.Linear(channels, out_channels)
        torch.nn.init.constant_(self.proj_out.weight, 0)
        self.txt_size = txt_size
        self.img_size = img_size
        self.txt_dim = txt_dim
        self.pt_seq_len = pt_seq_len or img_size
        self.block_causal = block_causal
        self.window = window

        # KV cache configurations
        num_kv_heads = num_kv_heads or (num_heads or channels // head_dim)
        self.kv_cache_size = [num_kv_heads, head_dim]

        if txt_dim > 0:
            self.proj_txt = torch.nn.Linear(txt_dim, channels)
        self.attn_blocks = torch.nn.ModuleList(
            [AttentionBlock(channels, head_dim, expansion, False, use_swiglu, norm_type,
                            num_heads, num_kv_heads, use_qk_norm, use_post_norm, use_bias, hf_style_rope,
                            non_causal=True)
             for _ in range(num_layers)])
        self.use_final_norm = use_final_norm
        if use_final_norm:
            self.final_norm = RMSNorm(channels)
        self.use_checkpoint = use_checkpoint
        self.use_checkpoint_mlp = use_checkpoint_mlp
        self.block_masks = {}  # for local attention

    def get_local_window_mask(self, x, y):
        _, T, H, W, _ = x.shape
        L = y.size(1) if y is not None else 0
        B = H * W
        N = T * B
        S = L + N
        A = self.block_causal
        G = self.window if self.window is not None else 10000

        def mask(q, k):
            return (k < L) | (
                ((k - L) // B >= (q - L) // B + A - 1 - G) &
                ((k - L) // B <= torch.relu(q - L) // B + A - 1)
            )
        return mask(torch.arange(S, device=x.device)[:, None], torch.arange(S, device=x.device)[None, :])

    def forward(self, x, y, rope, **unused):
        freqs_cis = self.get_freqs_cis(x, y, rope) if rope is not None else None
        if self.block_causal > 0 and x.dim() == 5:
            attn_mask = self.get_local_window_mask(x, y if self.txt_dim > 0 else None)
        else:
            attn_mask = None

        if x.dim() == 5:  # video input
            N, H, W, x = x.size(1), x.size(2), x.size(3), rearrange(x, 'b t h w c -> b (t h w) c')  # flatten x
        else:
            N, H, W, x = 0, x.size(1), x.size(2), rearrange(x, 'b h w c -> b (h w) c')  # flatten x

        x = self.get_proj_in(x)
        y = self.proj_txt(y) if self.txt_dim > 0 else None
        for it, block in enumerate(self.attn_blocks):
            # Frequency-based checkpointing strategy:
            # - Checkpoint attention every use_checkpoint blocks (if use_checkpoint > 0)
            # - Checkpoint MLP every use_checkpoint_mlp blocks (if provided), otherwise every use_checkpoint blocks
            checkpoint_attn = self.training and self.use_checkpoint > 0 and ((it + 1) % self.use_checkpoint == 0)
            if self.use_checkpoint_mlp is not None:
                checkpoint_mlp = self.training and self.use_checkpoint_mlp > 0 and ((it + 1) % self.use_checkpoint_mlp == 0)
            else:
                checkpoint_mlp = self.training and self.use_checkpoint > 0 and ((it + 1) % self.use_checkpoint == 0)
            x, y = block(x, y, attn_mask, 1.0, None, freqs_cis,
                         checkpoint_attn=checkpoint_attn, checkpoint_mlp=checkpoint_mlp)
        if self.use_final_norm:
            x = self.final_norm(x)
        x = self.get_proj_out(x)
        if N > 0:
            x = rearrange(x, 'b (t h w) d -> b t h w d', t=N, h=H, w=W)
        else:
            x = rearrange(x, 'b (h w) d -> b h w d', h=H, w=W)
        return x

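
# --- Illustrative sketch (not used by the model): the non-causal (denoiser) block. ---
# Unlike MetaBlock, NonCausalBlock is not an invertible flow step: it attends bidirectionally
# and simply maps patch features to patch features of the same shape.
def _noncausal_block_sketch():
    denoiser = NonCausalBlock(in_channels=8, channels=32, img_size=4, num_layers=1, head_dim=16)
    denoiser.eval()
    x = torch.randn(2, 4, 4, 8)                # (b, h, w, c) image patches
    with torch.no_grad():
        out = denoiser(x, None, None)          # rope=None -> plain bidirectional attention
    assert out.shape == x.shape
    return out
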
class Model(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        img_size: int,
        patch_size: int,
        channels: int,
        num_blocks: int,
        layers_per_block: List[int],
        head_dim: int = 64,
        num_heads: None | int = None,
        num_kv_heads: None | int = None,
        rope: bool = False,
        pt_seq_len: None | int = None,
        sos: bool = False,
        txt_size: int = 0,
        txt_dim: int = 0,
        cond_top_only: bool = False,
        use_softplus: bool = False,
        use_swiglu: bool = False,
        use_bias: bool = True,
        use_qk_norm: bool = False,
        use_post_norm: bool = False,
        use_final_norm: bool = False,
        hf_style_rope: bool = False,
        norm_type: str = "layer_norm",
        use_checkpoint: int = 0,
        use_checkpoint_mlp: int = None,
        use_pretrained_lm: str | None = None,
        use_mm_attn: bool = False,
        soft_clip: float = 0,
        seq_order: str = "R2L",
        learnable_self_denoiser: bool = False,
        conditional_denoiser: bool = False,
        temporal_causal: int = 0,
        top_block_channels: int = None,     # If specified, top block uses different size
        shallow_block_local: bool = False,  # If True, shallow blocks only constrained within a frame
        denoiser_window: int = None,        # If specified, use local attention in the denoiser with given window size
        local_attn_window: int = None,      # If specified, use local attention in all blocks with given window size
        **unused_kwargs,
    ):
        super().__init__()
        self.img_size = img_size
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.pt_seq_len = pt_seq_len or img_size // patch_size
        self.num_patches = self.pt_seq_len ** 2
        self.use_rope = rope
        self.use_sos = sos
        self.use_softplus = use_softplus
        self.cond_top_only = cond_top_only
        self.seq_order = seq_order
        self.temporal_causal = temporal_causal
        self.top_block_channels = top_block_channels or channels
        self.shallow_block_local = shallow_block_local
        self.expansion_init_std = 0.02

        assert (not local_attn_window) or shallow_block_local, 'local_attn_window requires shallow_block_local'
        assert (not shallow_block_local) or self.cond_top_only, 'shallow_block_local requires cond_top_only'
        assert (not self.cond_top_only) or (txt_size > 0), 'cond_top_only requires txt_size > 0'
        assert (seq_order == 'L2R') or (temporal_causal == 0), 'seq_order must be L2R if temporal causal is True'

        permutations = [PermutationIdentity(self.num_patches), PermutationFlip(self.num_patches)] if temporal_causal == 0 else \
            [PermutationIdentity(self.num_patches), PermutationFlipInBlock(self.num_patches)]
        blocks = []
        if len(layers_per_block) == 1:
            layers_per_block = [layers_per_block[0]] * num_blocks

        base_kwargs = dict(
            in_channels=in_channels * patch_size**2,
            channels=channels,
            img_size=img_size // patch_size,
            pt_seq_len=self.pt_seq_len,
            txt_size=txt_size,
            use_rope=self.use_rope,
            hf_style_rope=hf_style_rope,
            use_sos=self.use_sos,
            use_softplus=self.use_softplus,
            use_swiglu=use_swiglu,
            use_qk_norm=use_qk_norm,
            use_post_norm=use_post_norm,
            use_final_norm=use_final_norm,
            use_bias=use_bias,
            norm_type=norm_type,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            use_checkpoint=use_checkpoint,
            use_checkpoint_mlp=use_checkpoint_mlp,
            soft_clip=soft_clip,
        )

        # bottom blocks
        for i in range(num_blocks-1):
            permutation = permutations[i % 2] if seq_order == 'R2L' else permutations[(i+1) % 2]
            Block = IdentityBlock if layers_per_block[i] == 0 else MetaBlock
            blocks.append(Block(permutation=permutation, num_layers=layers_per_block[i],
                                txt_dim=0 if cond_top_only else txt_dim, **base_kwargs))

        # top block
        gen_kwargs = copy.deepcopy(base_kwargs)
        if self.top_block_channels != channels:
            gen_kwargs['channels'] = self.top_block_channels
            if num_heads is None:
                gen_kwargs['num_heads'] = self.top_block_channels // head_dim
        if use_pretrained_lm is not None:
            gen_kwargs.update(eval(f"{use_pretrained_lm}_kwargs"))
            if use_mm_attn:
                gen_kwargs.update({"use_mm_attn": True})  # only top block will receive this
        else:
            gen_kwargs.update({"num_layers": layers_per_block[-1]})
        permutation = permutations[(num_blocks-1) % 2] if seq_order == 'R2L' else permutations[(num_blocks) % 2]
        top_block = MetaBlock(permutation=permutation, txt_dim=txt_dim, local_attn_window=local_attn_window, **gen_kwargs)
        blocks.append(top_block)

        # put together
        self.blocks = torch.nn.ModuleList(blocks)

        # Self-denoiser
        if learnable_self_denoiser:
            self.learnable_self_denoiser = NonCausalBlock(
                num_layers=8, block_causal=temporal_causal, window=denoiser_window,
                txt_dim=0 if not conditional_denoiser else txt_dim, **base_kwargs)

        # setup rotary embeddings
        if self.use_rope:
            self.feat_rope = VisionRotaryEmbeddingFast(
                dim=base_kwargs['head_dim'] // 2, pt_seq_len=base_kwargs['pt_seq_len'], latent_len=txt_size)
            if use_pretrained_lm is not None:
                # using standard 1D RoPE
                self.feat_rope_gen = VisionRotaryEmbeddingFast(
                    dim=gen_kwargs['head_dim'] // 2, pt_seq_len=gen_kwargs['pt_seq_len'], no_buffer=True, is_1d=True)
            else:
                self.feat_rope_gen = VisionRotaryEmbeddingFast(
                    dim=gen_kwargs['head_dim'] // 2, pt_seq_len=gen_kwargs['pt_seq_len'], latent_len=txt_size, no_buffer=True)
        else:
            self.feat_rope = self.feat_rope_gen = None

        # ----- DEPRECATED: not useful -------
        self.register_buffer('var', torch.ones(self.num_patches, in_channels * patch_size**2))

    def patchify(self, x: List[torch.Tensor] | torch.Tensor, p: int | None = None) -> torch.Tensor:
        """Convert an image (N,C',H,W) to a sequence of patches (N,T,C')"""
        if len(x.shape) < 4:
            return x  # no need to patchify
        H, W = x.shape[-2], x.shape[-1]
        p = self.patch_size * p if p is not None else self.patch_size
        assert H % p == 0 and W % p == 0, "H and W must be divisible by patch_size"
        x = rearrange(x, '... c (h p1) (w p2) -> ... h w (p1 p2 c)', p1=p, p2=p)
        return x
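    # Example of the patchify/unpatchify shape math (illustrative): with patch_size=2, an image
    # of shape (N, 3, 8, 8) becomes a (N, 4, 4, 12) grid of patches, where 12 = 2 * 2 * 3, and
    # `unpatchify` applies the inverse rearrangement.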
    def unpatchify(self, x: List[torch.Tensor] | torch.Tensor, p: int | None = None) -> torch.Tensor:
        """Convert a sequence of patches (N,T,C) to an image (N,C',H,W)"""
        if len(x.shape) < 4:
            return x  # no need to unpatchify
        p = self.patch_size * p if p is not None else self.patch_size
        H, W = x.shape[-3], x.shape[-2]
        return rearrange(x, '... h w (p1 p2 c) -> ... c (h p1) (w p2)', h=H, w=W, p1=p, p2=p)

    def get_loss(self, z: torch.Tensor | List[torch.Tensor], logdets: torch.Tensor | List[torch.Tensor],
                 weights: torch.Tensor | None = None, drop_first=False) -> dict[str, torch.Tensor]:
        if drop_first:
            z, logdets = z[:, 1:], [logdet[:, 1:] for logdet in logdets]
        loss_z = 0.5 * z.pow(2).mean(dim=tuple(range(1, z.dim())))
        loss_logdet = -sum([logdet.mean(dim=tuple(range(1, logdet.dim()))) for logdet in logdets])
        loss = loss_z + loss_logdet
        if weights is not None:
            loss = loss * weights
        loss = loss.mean()
        return {'loss': loss, 'loss_z': loss_z.detach().mean(), 'loss_logdet': loss_logdet.detach().mean()}

    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor | None = None,
        reverse=False,
        kv_caches=None,
        denoiser=False,
        context=False,
        **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor], list[torch.Tensor]]:
        if context:
            return self.forward_context(x, y, kv_caches=kv_caches, **kwargs)

        if reverse:
            # inference mode
            return self.reverse(x, y, kv_caches=kv_caches, **kwargs)

        if denoiser:
            # forward with self-denoiser
            x = self.patchify(x)
            x = self.learnable_self_denoiser(x, y, self.feat_rope, **kwargs)
            return self.unpatchify(x)

        logdets, outputs = [], []
        guidance = kwargs.get('guidance', 0)

        # Bottom blocks
        x = self.patchify(x)
        outputs += [x]
        for it, block in enumerate(self.blocks[:-1]):
            if self.shallow_block_local and x.dim() == 5:  # video input
                x = rearrange(x, 'b t h w c -> (b t) 1 h w c')
            x, _, logdet = block(x, y.chunk(2, dim=0)[0] if self.cond_top_only and guidance > 0 else y,
                                 self.feat_rope, kv_cache=kv_caches[-(it+1)] if kv_caches is not None else None)
            if self.shallow_block_local and x.dim() == 5:  # video input
                x = rearrange(x, '(b t) 1 h w c -> b t h w c', b=outputs[0].size(0), t=outputs[0].size(1))
                logdet = rearrange(logdet, '(b t) l c -> b t l c', b=outputs[0].size(0), t=outputs[0].size(1))
            logdets += [logdet]
            outputs += x if isinstance(x, list) else [x]

        # Top block
        x, y, logdet = self.blocks[-1](x, y, self.feat_rope_gen,
                                       kv_cache=kv_caches[0] if kv_caches is not None else None, guidance=guidance)
        outputs += [x]
        x = self.unpatchify(x)
        logdets += [logdet]
        return x, y, outputs, logdets

    def forward_context(self, x: torch.Tensor, y: torch.Tensor | None = None,
                        kv_caches: List[KVCache] | None = None, **kwargs):
        if kv_caches is None:
            kv_caches = [KVCache() for _ in range(len(self.blocks))]
        use_cfg = (x.size(0) * 2 == y.size(0)) if (y is not None and self.cond_top_only) else False
        if use_cfg:
            x = torch.cat([x, x], 0)  # duplicate for classifier-free guidance generation
        self.forward(x, y, kv_caches=kv_caches, **kwargs)  # run once to fill the cache
        if use_cfg:
            for kv in kv_caches[1:]:
                kv.remove_negative_cache()  # remove negative cache except for the first block
                kv.prefix_cache = kv.prefix_cache.chunk(2, dim=0)[0] if kv.prefix_cache is not None else None
        return kv_caches

    def reverse_deep(self, x: List[torch.Tensor] | torch.Tensor, y: torch.Tensor | None = None,
                     guidance: float = 0, verbose: bool = False, kv_caches: List[KVCache] | None = None,
                     jacobi: bool = False, need_caches: bool = False, seq: List[torch.Tensor] = [],
                     **sampling_kwargs,):
        x = self.patchify(x)
        x = (self.blocks[-1].jacobi if jacobi else self.blocks[-1].reverse)(
            x, y, guidance, rope=self.feat_rope_gen, kv_cache=kv_caches[0], verbose=verbose, **sampling_kwargs)
        x = self.unpatchify(x)
        if not need_caches:
            kv_caches[0].delete()
        seq.append(x)
        return x
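    # Note on sampling order (sketch): generation inverts the forward composition. `reverse_deep`
    # runs the top (deepest) block first to turn noise into its latents, then `reverse_shallow`
    # walks the remaining blocks in reversed order, each undoing its own affine autoregressive
    # update, until the bottom block emits the final patches.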
    def reverse_shallow(self, x: List[torch.Tensor] | torch.Tensor, y: torch.Tensor | None = None,
                        guidance: float = 0, verbose: bool = False, kv_caches: List[KVCache] | None = None,
                        jacobi: bool = False, need_caches: bool = False, seq: List[torch.Tensor] = [],
                        **sampling_kwargs,):
        x = self.patchify(x)
        for it, block in enumerate(reversed(self.blocks[:-1])):
            if self.shallow_block_local and x.dim() == 5:  # video input
                x = rearrange(x, 'b t h w c -> (b t) 1 h w c')
                kv_caches[it+1]._is_empty = True
                kv_caches[it+1].prefix_cache = None
            x = (block.jacobi if jacobi else block.reverse)(
                x, y, guidance, rope=self.feat_rope, kv_cache=kv_caches[it+1], verbose=verbose, **sampling_kwargs)
            if self.shallow_block_local and x.dim() == 5:  # video input
                x = rearrange(x, '(b t) 1 h w c -> b t h w c', b=seq[0].size(0), t=seq[0].size(1))
            seq.append(self.unpatchify(x))
            if not need_caches:
                kv_caches[it+1].delete()
        x = self.unpatchify(x)
        return x

    def reverse(
        self,
        x: List[torch.Tensor] | torch.Tensor,
        y: torch.Tensor | None = None,
        guidance: float = 0,
        guide_top: int | None = None,
        return_sequence: bool = False,
        verbose: bool = False,
        kv_caches: List[KVCache] | None = None,
        jacobi: bool = False,
        **sampling_kwargs,
    ) -> torch.Tensor | list[torch.Tensor]:
        seq, need_caches, kv_caches = [x], (kv_caches is not None), kv_caches or [KVCache() for _ in range(len(self.blocks))]

        # run the deep block first
        x = self.reverse_deep(x, y, guidance, verbose, kv_caches, jacobi, need_caches, seq, **sampling_kwargs)

        # remove guidance if bottom is unconditional
        if (guide_top is not None or self.cond_top_only) and guidance > 0:
            guidance, y = 0, y.chunk(2, dim=0)[0]

        # run the shallow blocks
        x = self.reverse_shallow(x, y, guidance, verbose, kv_caches, jacobi, need_caches, seq, **sampling_kwargs)
        return seq if return_sequence else x


#################################################################################
#                               TarFlow Configs                                 #
#################################################################################

def TarFlow_XL_1(**kwargs):
    return Model(num_blocks=6, layers_per_block=[2,2,2,2,10,10], channels=2048, patch_size=1, head_dim=64, rope=1, **kwargs)

def TarFlow_XL_2(**kwargs):
    return Model(num_blocks=6, layers_per_block=[2,2,2,2,10,10], channels=2048, patch_size=2, head_dim=64, rope=1, **kwargs)

def TarFlow_XXL_1(**kwargs):
    return Model(num_blocks=6, layers_per_block=[2,2,2,2,13,13], channels=3072, patch_size=1, head_dim=64, rope=1, **kwargs)

def TarFlow_XLv2_1(**kwargs):   # 1.4B
    return Model(num_blocks=6, layers_per_block=[2,2,2,2,2,18], channels=2048, patch_size=1, head_dim=64, rope=1, **kwargs)

def TarFlow_XXLv2_1(**kwargs):  # 4B
    return Model(num_blocks=6, layers_per_block=[2,2,2,2,2,24], channels=3072, patch_size=1, head_dim=64, rope=1, **kwargs)

def TarFlow_Gemma2B(**kwargs):  # 2B
    return Model(num_blocks=6, layers_per_block=[2,2,2,2,2,26], channels=2304, patch_size=1, rope=1,
                 use_rope=True, hf_style_rope=True, use_adaln=False, use_swiglu=True, use_qk_norm=False,
                 use_post_norm=True, use_final_norm=True, use_bias=False, norm_type="rms_norm",
                 num_heads=8, num_kv_heads=4, head_dim=256, **kwargs)


# Pre-trained model configs
pre_model_configs = {
    "TarFlow_XL_1": TarFlow_XL_1,
    "TarFlow_XLv2_1": TarFlow_XLv2_1,
    "TarFlow_XL_2": TarFlow_XL_2,
    "TarFlow_XXL_1": TarFlow_XXL_1,
    "TarFlow_XXLv2_1": TarFlow_XXLv2_1,
}

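
# --- Illustrative sketch (not part of the configs): constructing a model from a config. ---
# The factory functions only fix the architecture; data-dependent arguments such as the latent
# channel count, resolution, and text-conditioning sizes (the values below are made up for
# illustration) are forwarded through **kwargs to Model.__init__.
def _build_model_sketch():
    model = TarFlow_XLv2_1(in_channels=16, img_size=32, sos=True,
                           txt_size=128, txt_dim=1024, cond_top_only=True)
    return model
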
#################################################################################
#                                Pretrained LLMs                                #
#################################################################################

gemma3_4b_kwargs = dict(
    use_rope=True, hf_style_rope=True, use_adaln=False, use_swiglu=True, use_qk_norm=True,
    use_post_norm=True, use_final_norm=True, use_bias=False, norm_type="rms_norm",
    num_heads=8, num_kv_heads=4, head_dim=256, channels=2560, num_layers=34, use_proj_txt=False)

gemma3_1b_kwargs = dict(
    use_rope=True, hf_style_rope=True, use_adaln=False, use_swiglu=True, use_qk_norm=True,
    use_post_norm=True, use_final_norm=True, use_bias=False, norm_type="rms_norm",
    num_heads=4, num_kv_heads=1, head_dim=256, channels=1152, expansion=6, num_layers=26, use_proj_txt=False)

gemma2_2b_kwargs = dict(
    use_rope=True, hf_style_rope=True, use_adaln=False, use_swiglu=True, use_qk_norm=False,
    use_post_norm=True, use_final_norm=True, use_bias=False, norm_type="rms_norm",
    num_heads=8, num_kv_heads=4, head_dim=256, channels=2304, num_layers=26, use_proj_txt=False)
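
# Note: passing `use_pretrained_lm="gemma3_4b"` (or "gemma3_1b" / "gemma2_2b") to Model resolves
# to the corresponding `*_kwargs` dict above via `eval(f"{use_pretrained_lm}_kwargs")`, so the
# top block matches the pretrained LM's architecture (heads, head_dim, channels, layer count).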