#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
from math import pi, sqrt

import torch
from torch import nn
from einops import rearrange, repeat


def broadcat(tensors, dim=-1):
    # Concatenate tensors along `dim`, broadcasting every other axis first.
    num_tensors = len(tensors)
    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
    assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
    shape_len = list(shape_lens)[0]
    dim = (dim + shape_len) if dim < 0 else dim
    dims = list(zip(*map(lambda t: list(t.shape), tensors)))
    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
    expanded_dims.insert(dim, (dim, dims[dim]))
    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
    return torch.cat(tensors, dim=dim)
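

# Hedged usage sketch (not part of the original file): `broadcat` expands
# every non-concatenation axis to its common broadcast size before
# concatenating, so inputs only need matching or singleton shapes outside
# `dim`. The shapes below are illustrative assumptions.
def _demo_broadcat():
    a = torch.randn(8, 1, 4)
    b = torch.randn(1, 16, 4)
    out = broadcat((a, b), dim=-1)  # both expand to (8, 16, 4), then concatenate
    assert out.shape == (8, 16, 8)
    return out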


def rotate_half(x):
    # Pair adjacent channels (x1, x2) and map them to (-x2, x1), i.e. rotate
    # each 2D channel pair by 90 degrees.
    x = rearrange(x, '... (d r) -> ... d r', r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d r -> ... (d r)')


def apply_rope(t, freqs):
    # Rotary embedding: rotate each channel pair of `t` by its angle in `freqs`.
    return t * freqs.cos() + rotate_half(t) * freqs.sin()
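

# Hedged sketch (not part of the original file): applying RoPE to a toy
# query tensor. The random `freqs` stand in for the per-token angles that
# VisionRotaryEmbeddingFast.forward produces; the shapes are assumptions.
def _demo_apply_rope():
    q = torch.randn(1, 4, 8)   # (batch, seq, head_dim)
    freqs = torch.randn(4, 8)  # per-token, per-channel angles; broadcasts over the batch axis
    return apply_rope(q, freqs)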


def get_positions(h=0, w=0, txt_size=0, pt_seq_len=None, duplicate=0, mode='3d'):
    assert mode in ['1d', '2d', '3d'], "mode must be one of ['1d', '2d', '3d']"
    assert h * w + txt_size > 0, "at least one of h * w or txt_size must be greater than 0"
    mean_len = sqrt(h * w)
    pt_seq_len = pt_seq_len or mean_len
    if mode == '1d':
        pos_txt = torch.arange(txt_size)
        pos_img = torch.arange(h * w)  # / (h * w) * (pt_seq_len ** 2)
        pos = torch.cat([pos_txt, pos_img + txt_size], dim=0).unsqueeze(-1)
    else:
        assert h * w > 0, "2D/3D RoPE requires h * w > 0"
        px = torch.arange(h) / mean_len * pt_seq_len
        py = torch.arange(w) / mean_len * pt_seq_len
        px, py = [p.reshape(-1) for p in torch.meshgrid(px, py, indexing='ij')]
        if mode == '2d':
            assert txt_size == 0, "2D RoPE does not support text conditioning"
            pos = [px, py]
        else:  # mode == '3d'
            if duplicate == 0:
                pos = [px, py, torch.zeros_like(px)]
            else:  # duplicate > 0 adds a temporal axis, e.g. for video data
                pos = [torch.cat([px for _ in range(duplicate)]),
                       torch.cat([py for _ in range(duplicate)]),
                       torch.arange(duplicate).repeat_interleave(h * w)]
            if txt_size > 0:  # text is used as conditioning
                pt = torch.arange(txt_size) / txt_size * pt_seq_len
                pos = [torch.cat([torch.zeros_like(pt), pos[0]]),
                       torch.cat([torch.zeros_like(pt), pos[1]]),
                       torch.cat([pt, pos[2]])]
        pos = torch.stack(pos, dim=-1)
    return pos
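

# Hedged sketch (not part of the original file) of the three position layouts:
# '1d' packs text then image tokens into a single index column; '2d' returns
# (x, y) grid coordinates; '3d' places text on a third axis, with zeros for
# the image tokens' third coordinate.
def _demo_get_positions():
    pos_1d = get_positions(h=2, w=2, txt_size=3, mode='1d')  # (3 + 4, 1)
    pos_2d = get_positions(h=2, w=2, mode='2d')              # (4, 2)
    pos_3d = get_positions(h=2, w=2, txt_size=3, mode='3d')  # (3 + 4, 3)
    assert pos_1d.shape == (7, 1) and pos_2d.shape == (4, 2) and pos_3d.shape == (7, 3)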


class VisionRotaryEmbeddingFast(nn.Module):
    def __init__(
        self,
        dim,                 # half-dim
        pt_seq_len=16,
        ft_seq_len=None,
        latent_len=0,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
        dim_split=None,
        no_buffer=False,
        is_1d=False,
    ):
        super().__init__()
        # length is normalized to pt_seq_len
        if is_1d:  # standard 1D RoPE
            assert freqs_for == 'lang', "1D RoPE requires the language frequency setting"
            dim_split, dim = [dim], 2 * dim
            self.freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        else:
            if ft_seq_len is None:
                ft_seq_len = pt_seq_len
            if latent_len > 0:
                if dim_split is None:
                    dim_split = [dim - 8, 8]
                dim, latent_dim = dim_split
            else:
                dim_split = [dim]
            if custom_freqs is not None:  # `is not None`: truth-testing a tensor is ambiguous
                self.freqs = custom_freqs
            elif freqs_for == 'lang':
                self.freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
            elif freqs_for == 'pixel':
                self.freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
            elif freqs_for == 'constant':
                self.freqs = torch.ones(num_freqs).float()
            else:
                raise ValueError(f'unknown modality {freqs_for}')
            if latent_len > 0:
                self.freqs2 = 1. / (theta ** (torch.arange(0, latent_dim).float() / latent_dim))

        self.is_1d = is_1d
        self.pt_seq_len = pt_seq_len
        self.ft_seq_len = ft_seq_len
        self.latent_len = latent_len

        # NOTE: deprecated (do not touch, will affect old checkpoints) #
        if not no_buffer and pt_seq_len > 0:
            _deprecated = torch.zeros(pt_seq_len ** 2, sum(dim_split) * 2)
            if latent_len > 0:
                _deprecated = torch.cat([torch.zeros(latent_len, sum(dim_split) * 2), _deprecated], dim=0)
            self.register_buffer("freqs_cos", _deprecated)
            self.register_buffer("freqs_sin", _deprecated)
        # ------------------------------------------------------------ #

    def forward(self, pos):
        if not isinstance(pos, torch.Tensor):
            pos = torch.tensor(pos).to(self.freqs_cos.device)  # relies on the (deprecated) buffer for device placement
        if not self.is_1d:  # 2D or 3D RoPE
            assert pos.shape[-1] > 1, "2D/3D RoPE requires multi-dimensional positions"
            freqs_all = [
                torch.einsum('..., f -> ... f', pos[..., 0], self.freqs.to(pos.device)),
                torch.einsum('..., f -> ... f', pos[..., 1], self.freqs.to(pos.device)),
            ]
            if pos.shape[-1] == 3:  # additional latent dimension (maybe text)
                freqs_all.append(torch.einsum('..., f -> ... f', pos[..., 2], self.freqs2.to(pos.device)))
            freqs_all = torch.cat(freqs_all, -1)
        else:
            freqs_all = torch.einsum('..., f -> ... f', pos[..., 0], self.freqs.to(pos.device))
        # duplicate each angle so adjacent channel pairs share a frequency (matches rotate_half)
        freqs_all = repeat(freqs_all, '... n -> ... (n r)', r=2)
        return freqs_all
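

# Hedged end-to-end sketch (not part of the original file): build 3D positions
# for a 4x4 image with 2 conditioning text tokens, turn them into per-token
# rotation angles, and rotate a matching query tensor. All sizes here are
# illustrative assumptions (dim=16 is the half-dim, so head_dim is 32).
def _demo_rope_module():
    rope = VisionRotaryEmbeddingFast(dim=16, pt_seq_len=4, latent_len=2)
    pos = get_positions(h=4, w=4, txt_size=2, mode='3d')  # (2 + 16, 3)
    freqs = rope(pos)                                     # (18, 32) rotation angles
    q = torch.randn(1, 18, 32)                            # (batch, tokens, head_dim)
    return apply_rope(q, freqs)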