# starflow/utils/inference.py
# Uploaded by leoeric: "Initial commit for HF Space - code files only" (commit 0b4562b)
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
"""
Inference utilities for STARFlow.
"""
import torch
import datetime
from typing import List
from torchmetrics.image.fid import FrechetInceptionDistance, _compute_fid
from torchmetrics.image.inception import InceptionScore
from torchmetrics.multimodal.clip_score import CLIPScore
from torchmetrics.utilities.data import dim_zero_cat
# Import Distributed from training module
from .training import Distributed
# ==== Metrics ====
class FID(FrechetInceptionDistance):
    """Frechet Inception Distance whose running statistics are reduced by hand.

    Every metric state is registered as a plain module buffer (see
    ``add_state``), and ``manual_compute`` combines the per-process statistics
    through the ``dist`` helper passed in by the caller.
    """

    def __init__(self, feature=2048, reset_real_features=True, normalize=False, input_img_size=..., **kwargs):
        # NOTE(review): the default for ``input_img_size`` is the literal
        # Ellipsis and is forwarded unchanged to the parent constructor.
        super().__init__(feature, reset_real_features, normalize, input_img_size, **kwargs)
        self.reset_real_features = reset_real_features

    def add_state(self, name, default, *args, **kwargs):
        """Register ``default`` as a buffer called ``name``; extra arguments are ignored."""
        self.register_buffer(name, default)

    def manual_compute(self, dist):
        """Reduce the feature statistics via ``dist`` and return the FID score.

        Args:
            dist: helper exposing ``reduce(tensor)``, ``rank`` and ``barrier()``.
                Presumably ``reduce`` sums across processes -- confirm against
                the ``Distributed`` implementation.

        Returns:
            The FID tensor on rank 0; a zero tensor of the same dtype and
            device on every other rank.
        """
        fake_stats = ('fake_features_num_samples', 'fake_features_sum', 'fake_features_cov_sum')
        real_stats = ('real_features_num_samples', 'real_features_sum', 'real_features_cov_sum')
        # Real statistics are only reduced (and later cleared) when they are
        # configured to be reset.
        tracked = fake_stats + real_stats if self.reset_real_features else fake_stats

        # Reduce in a fixed order: fake statistics first, then real ones.
        for stat in tracked:
            setattr(self, stat, dist.reduce(getattr(self, stat)))
        print(f'Gathered {self.fake_features_num_samples} samples for FID computation')

        def moments(total, outer_total, count):
            # Mean as a (1, D) row plus the unbiased covariance built from the
            # accumulated sum and sum of outer products.
            mean = (total / count).unsqueeze(0)
            cov_num = outer_total - count * mean.t().mm(mean)
            return mean, cov_num / (count - 1)

        mean_real, cov_real = moments(
            self.real_features_sum, self.real_features_cov_sum, self.real_features_num_samples)
        mean_fake, cov_fake = moments(
            self.fake_features_sum, self.fake_features_cov_sum, self.fake_features_num_samples)

        target_device = self.real_features_sum.device
        if dist.rank == 0:
            fid_score = _compute_fid(mean_real.squeeze(0), cov_real, mean_fake.squeeze(0), cov_fake).to(
                dtype=self.orig_dtype, device=target_device)
            print(f'FID: {fid_score.item()} DONE')
        else:
            # Placeholder so every rank returns a tensor.
            fid_score = torch.tensor(0.0, dtype=self.orig_dtype, device=target_device)
        dist.barrier()

        # Clear the accumulated statistics in place.
        for stat in tracked:
            getattr(self, stat).mul_(0)
        return fid_score
class IS(InceptionScore):
    """Inception Score with a manual, ``dist``-driven gather of the logits."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def manual_compute(self, dist):
        """Gather the accumulated features via ``dist`` and compute the score.

        Args:
            dist: helper exposing ``gather_concat(tensor)``, ``rank`` and
                ``barrier()``.

        Returns:
            ``(mean, std)`` of the per-split scores on rank 0; two zero
            tensors on every other rank.

        The accumulated feature list is cleared afterwards, matching the
        "reset the state" behaviour of the sibling FID and CLIP metrics.
        """
        # Concatenate into a local tensor: overwriting ``self.features`` with
        # a tensor would break later ``update`` calls, which append to a list.
        local_features = dim_zero_cat(self.features)
        device = local_features.device
        features = dist.gather_concat(local_features)
        print(f'Gathered {features.shape[0]} samples for IS computation')

        if dist.rank == 0:
            # Shuffle so that each split sees a random subset of samples.
            idx = torch.randperm(features.shape[0])
            features = features[idx]

            # calculate probs and logits
            prob = features.softmax(dim=1)
            log_prob = features.log_softmax(dim=1)

            # split into groups
            prob = prob.chunk(self.splits, dim=0)
            log_prob = log_prob.chunk(self.splits, dim=0)

            # Per split: exp of the mean KL divergence between each sample's
            # distribution and the split's marginal distribution.
            mean_prob = [p.mean(dim=0, keepdim=True) for p in prob]
            kl_ = [p * (log_p - m_p.log()) for p, log_p, m_p in zip(prob, log_prob, mean_prob)]
            kl_ = [k.sum(dim=1).mean().exp() for k in kl_]
            kl = torch.stack(kl_)

            mean = kl.mean()
            std = kl.std()
        else:
            # Placeholders so every rank returns tensors.
            mean = torch.tensor(0.0, device=device)
            std = torch.tensor(0.0, device=device)
        dist.barrier()

        # reset the state
        self.features = []
        return mean, std
class CLIP(CLIPScore):
    """CLIP score with a manual, ``dist``-driven reduction of its accumulators."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def manual_compute(self, dist):
        """Reduce ``n_samples`` and ``score`` via ``dist`` and return the mean score.

        The result is floored at zero, printed, and the accumulators are
        cleared in place before returning.
        """
        # Reduce the sample count first, then the score sum.
        self.n_samples = dist.reduce(self.n_samples)
        self.score = dist.reduce(self.score)
        print(f'Gathered {self.n_samples} samples for CLIP computation')

        # Average score per sample, clamped from below at zero.
        average = self.score / self.n_samples
        floor = torch.zeros_like(self.score)
        clip_score = torch.max(average, floor)
        print(f'CLIP: {clip_score.item()} DONE')

        # Clear the accumulators in place.
        for accumulator in (self.n_samples, self.score):
            accumulator.mul_(0)
        return clip_score
class Metrics:
    """Collects scalar metrics by name and reports their averages."""

    def __init__(self):
        # Maps each metric name to the list of values recorded so far.
        self.metrics: dict[str, list[float]] = {}

    def update(self, metrics: dict[str, torch.Tensor | float]):
        """Record one value per metric name; tensors are converted with ``.item()``."""
        for name, value in metrics.items():
            scalar = value.item() if isinstance(value, torch.Tensor) else value
            self.metrics.setdefault(name, []).append(scalar)

    def compute(self, dist: Distributed | None) -> dict[str, float]:
        """Return the average of every recorded metric.

        With ``dist`` given, the local average is placed on the ``'cuda'``
        device, passed through ``dist.gather_concat`` and averaged again.
        """
        averaged: dict[str, float] = {}
        for name, history in self.metrics.items():
            local_mean = sum(history) / len(history)
            if dist is None:
                averaged[name] = local_mean
                continue
            gathered = dist.gather_concat(torch.tensor(local_mean, device='cuda').view(1))
            averaged[name] = gathered.mean().item()
        return averaged

    @staticmethod
    def print(metrics: dict[str, float], epoch: int):
        """Print a timestamped header followed by one line per metric, sorted by name."""
        print(f'Epoch {epoch} Time {datetime.datetime.now()}')
        rows = [f'\t{k:40s}: {v: .4g}' for k, v in sorted(metrics.items())]
        print('\n'.join(rows))
# ==== Denoising Functions (from starflow_utils.py) ====
def apply_denoising(model, x_chunk: torch.Tensor, y_batch,
                    text_encoder, tokenizer, args,
                    text_encoder_kwargs: dict, sigma_curr: float, sigma_next: float = 0) -> torch.Tensor:
    """Denoise one chunk, using the learnable denoiser when the model has one.

    ``y_batch`` may be an already-encoded tensor (used as-is), ``None`` (no
    conditioning) or raw captions, which are encoded with ``encode_text``.
    When ``args.disable_learnable_denoiser`` is truthy or the model has no
    ``learnable_self_denoiser`` attribute, the work is delegated to
    ``self_denoise``; otherwise the model is called with ``denoiser=True`` and
    its output, scaled, is subtracted from ``x_chunk``.
    """
    from .common import encode_text  # Import here to avoid circular imports

    denoiser_scale = 0.3  # a constant used for noise levels.

    # Resolve the conditioning: none, pre-encoded tensor, or raw captions.
    if y_batch is None:
        conditioning = None
    elif isinstance(y_batch, torch.Tensor):
        conditioning = y_batch
    else:
        conditioning = encode_text(text_encoder, tokenizer, y_batch, args.txt_size,
                                   text_encoder.device, **text_encoder_kwargs)

    learnable_disabled = getattr(args, 'disable_learnable_denoiser', False)
    if learnable_disabled or not hasattr(model, 'learnable_self_denoiser'):
        # Single step of the score-based fallback.
        return self_denoise(
            model, x_chunk, conditioning,
            noise_std=sigma_curr,
            steps=1,
            disable_learnable_denoiser=learnable_disabled,
        )

    # Learnable denoiser
    # NOTE(review): this writes into the caller's dict after encode_text has
    # already run, so it can only influence later calls that reuse the same
    # dict -- confirm this is intended.
    if sigma_curr is not None and isinstance(y_batch, (list, type(None))):
        text_encoder_kwargs['noise_std'] = sigma_curr

    denoiser_output = model(x_chunk, conditioning, denoiser=True)
    correction = denoiser_output * denoiser_scale * (sigma_curr - sigma_next) / sigma_curr
    return x_chunk - correction
def self_denoise(model, samples, y, noise_std=0.1, lr=1, steps=1, disable_learnable_denoiser=False):
    """Take one gradient step on ``samples`` using the model's own loss.

    The model is run in training mode on a copy of ``samples``, the gradient
    of ``model.get_loss(...)['loss']`` with respect to that copy is computed,
    and the copy is moved against the gradient with step size
    ``noise_std ** 2 * lr``. The original note on this function says it
    mirrors the routine in train.py.

    Args:
        model: callable as ``model(x, y)`` returning four values, of which the
            first and last are passed to ``model.get_loss``; must also provide
            ``train()`` and ``eval()``.
        samples: tensor to denoise; it is not modified.
        y: conditioning forwarded to the model.
        noise_std: noise level; its square scales the step.
        lr: extra multiplier on the step size.
        steps: ``0`` returns ``samples`` untouched; any other value performs a
            single update.
        disable_learnable_denoiser: accepted for interface compatibility; unused.

    Returns:
        A detached tensor shaped like ``samples``.
    """
    if steps == 0:
        return samples

    # NOTE(review): the loss is scaled up and the gradient scaled back down by
    # the same factor, presumably to avoid underflow in reduced precision.
    loss_scale = 65536
    step_size = noise_std ** 2 * lr
    x = samples.detach().clone()

    # Remember the current mode so the switch to train() below does not leak
    # into whatever the caller does next with the model.
    was_training = getattr(model, 'training', True)
    with torch.enable_grad():
        x.requires_grad_(True)
        model.train()
        try:
            z, _, _, logdets = model(x, y)
            loss = model.get_loss(z, logdets)['loss'] * loss_scale
            grad = float(samples.numel()) / loss_scale * torch.autograd.grad(loss, [x])[0]
        finally:
            if not was_training:
                model.eval()
    return (x - grad * step_size).detach()
def process_denoising(samples: torch.Tensor, y: List[str], args,
                      model, text_encoder, tokenizer, text_encoder_kwargs: dict,
                      noise_std: float) -> torch.Tensor:
    """Denoise ``samples`` in mini-batches when denoising applies.

    Denoising runs only when ``args.finetuned_vae == 'none'``, no
    ``vae_adapter`` is set and ``return_sequence`` is 0; otherwise ``samples``
    is returned unchanged. 5-D inputs are treated as videos and denoised in
    windows along dimension 1; other inputs are denoised one mini-batch at a
    time.

    Args:
        samples: tensor to denoise; mini-batches are taken along dimension 0.
        y: per-sample conditioning sliced alongside ``samples``, or ``None``.
        args: options read here: ``finetuned_vae``, ``vae_adapter``,
            ``return_sequence``, ``denoising_batch_size`` (default 1),
            ``local_attn_window`` and ``denoiser_window``.
        model, text_encoder, tokenizer, text_encoder_kwargs: forwarded to
            ``apply_denoising``.
        noise_std: noise level passed as ``sigma_curr``.

    Returns:
        The denoised samples on the CUDA device, or ``samples`` itself when
        denoising does not apply or the batch is empty.

    Raises:
        ValueError: if the video window would never advance.
    """
    denoising_applicable = (
        args.finetuned_vae == 'none'
        and getattr(args, 'vae_adapter', None) is None
        and getattr(args, 'return_sequence', 0) == 0
    )
    if not denoising_applicable:
        # Denoising not enabled or not applicable
        return samples

    torch.cuda.empty_cache()
    assert isinstance(samples, torch.Tensor)

    b = samples.size(0)
    if b == 0:
        # Nothing to denoise; also avoids a zero batch size below.
        return samples

    # Keep the full set on the CPU and move one mini-batch at a time to the
    # GPU to limit memory use.
    samples = samples.cpu()
    db = max(1, min(getattr(args, 'denoising_batch_size', 1), b))
    is_video = samples.dim() == 5

    denoised_samples = []
    # Step by the batch size so a trailing partial batch is processed too.
    for start in range(0, b, db):
        x_all = samples[start:start + db].clone().detach().cuda()
        y_batch = y[start:start + db] if y is not None else None

        if is_video:
            # Chunk-wise denoising for videos
            num_frames = x_all.size(1)
            local_window = getattr(args, 'local_attn_window', None)
            steps = num_frames if local_window is None else local_window
            denoiser_window = getattr(args, 'denoiser_window', None)
            next_overlap = steps - 1 if denoiser_window is None else denoiser_window
            stride = steps - next_overlap
            if stride <= 0:
                raise ValueError(
                    f'denoiser window ({next_overlap}) must be smaller than the chunk size ({steps})')

            # The first chunk is written back whole; later chunks only
            # contribute the part beyond the overlap with the previous one.
            s_idx, overlap = 0, 0
            while s_idx < num_frames:
                x_chunk = x_all[:, s_idx:s_idx + steps].detach().clone()
                x_denoised = apply_denoising(
                    model, x_chunk, y_batch, text_encoder, tokenizer,
                    args, text_encoder_kwargs, noise_std
                )
                x_all[:, s_idx + overlap:s_idx + steps] = x_denoised[:, overlap:]
                overlap = next_overlap
                s_idx += stride
        else:
            # Process entire batch for images
            x_all = apply_denoising(
                model, x_all, y_batch, text_encoder, tokenizer,
                args, text_encoder_kwargs, noise_std
            )

        torch.cuda.empty_cache()
        denoised_samples.append(x_all.detach().cpu())

    return torch.cat(denoised_samples, dim=0).cuda()
def simple_denoising(model, samples: torch.Tensor, y_encoded,
                     text_encoder, tokenizer, args, noise_std: float) -> torch.Tensor:
    """Denoise ``samples`` in a single ``apply_denoising`` call (no chunking).

    ``samples`` is returned untouched when ``args.finetuned_vae`` is set to
    anything other than ``'none'`` or ``None``.
    """
    vae_choice = args.finetuned_vae
    if vae_choice is not None and vae_choice != 'none':
        return samples

    # apply_denoising accepts both encoded tensors and raw captions; give it
    # a fresh kwargs dict so nothing is shared between calls.
    fresh_kwargs = {}
    return apply_denoising(
        model, samples, y_encoded, text_encoder, tokenizer,
        args, fresh_kwargs, noise_std, sigma_next=0,
    )