ObjectInsertion / pipeline_adapter.py
Leema Krishna Murali
Initial commit
f3d0a26
# pipeline_adapter.py
import numpy as np
import tempfile
from utils.video_utils import load_video, save_video
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
def compute_psnr(original, result):
    """Return the PSNR averaged over paired frames of two videos.

    Frames are paired positionally with ``zip``, so scoring stops at the
    shorter of the two sequences. Each pair is scored by
    ``peak_signal_noise_ratio`` with ``data_range=255``.
    """
    per_frame = [
        peak_signal_noise_ratio(ref_frame, out_frame, data_range=255)
        for ref_frame, out_frame in zip(original, result)
    ]
    return float(np.mean(per_frame))
def compute_ssim_video(original, result):
    """Return the SSIM averaged over paired frames of two videos.

    Frames are paired positionally with ``zip``, so scoring stops at the
    shorter of the two sequences. Each pair is scored by
    ``structural_similarity`` with the last axis treated as the channel
    axis and ``data_range=255``.
    """
    per_frame = [
        structural_similarity(
            ref_frame, out_frame, channel_axis=-1, data_range=255
        )
        for ref_frame, out_frame in zip(original, result)
    ]
    return float(np.mean(per_frame))
def compute_lpips_video(original, result, device="cuda"):
    """Mean LPIPS across all frames (lower = better).

    Args:
        original: Iterable of reference frames; each frame is expected to be
            an [H, W, 3] uint8 NumPy array.
        result: Iterable of frames to compare, paired positionally with
            ``original`` (scoring stops at the shorter sequence).
        device: Torch device to run the LPIPS network on. If a CUDA device
            is requested but CUDA is not available, the CPU is used instead
            so that metric computation does not abort the caller.

    Returns:
        The mean of the per-frame LPIPS scores as a Python float.
    """
    import torch
    import lpips

    # The default is "cuda"; degrade to CPU on machines without a GPU rather
    # than failing after the (expensive) synthesis step has already run.
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"

    loss_fn = lpips.LPIPS(net="alex").to(device)

    def _to_lpips_tensor(frame):
        # Convert [H, W, 3] uint8 -> [1, 3, H, W] float in [-1, 1]
        tensor = torch.from_numpy(frame).permute(2, 0, 1).unsqueeze(0).float()
        return (tensor / 127.5 - 1.0).to(device)

    scores = []
    # Inference only: a single no_grad context covers every frame pair.
    with torch.no_grad():
        for f1, f2 in zip(original, result):
            score = loss_fn(_to_lpips_tensor(f1), _to_lpips_tensor(f2))
            scores.append(score.item())
    return float(np.mean(scores))
def extract_first_frame(video_path: str) -> np.ndarray:
    """Load at most one frame from ``video_path`` and return it."""
    return load_video(video_path, max_frames=1)[0]
def load_all_frames(video_path: str, max_frames: int = 81) -> np.ndarray:
    """Load the frames of ``video_path`` via ``load_video``.

    Args:
        video_path: Path of the video to read.
        max_frames: Upper bound on the number of frames to load, forwarded
            to ``load_video``. Defaults to 81, the cap this module
            previously hard-coded.

    Returns:
        Whatever ``load_video`` returns for the given arguments.
    """
    return load_video(video_path, max_frames=max_frames)
def run_pipeline_motion_edit(
    video_path: str,
    start_box: list,
    end_box: list,
    prompt: str,
    stage1_method: str = "linear",
    use_vace: bool = False,
    progress_callback=None
) -> tuple:
    """Run the motion-edit pipeline on a video and report quality metrics.

    Steps: load the frames, compute a box trajectory between the first and
    last frame (stage 1), run ``TRACEPrototype.run_motion_edit``, write the
    result to a temporary ``.mp4`` and score it against the input frames.

    Args:
        video_path: Path of the input video.
        start_box: Box assigned to frame 0.
        end_box: Box assigned to the last loaded frame.
        prompt: Text prompt forwarded to the synthesis step.
        stage1_method: ``"cotracker"`` selects CoTracker-based trajectory
            estimation when the prototype exposes a tracker; any other
            value (or a missing tracker) uses linear interpolation.
        use_vace: Forwarded to ``TRACEPrototype``.
        progress_callback: Optional callable ``(fraction, message)`` invoked
            at each stage.

    Returns:
        A tuple ``(output_path, result, pred_boxes, metrics_text)``:
        the path of the saved temporary video (not deleted automatically),
        the synthesized frames, the stage-1 boxes, and a Markdown summary.
    """
    from pipeline import TRACEPrototype
    from stage1_approx import stage1_linear, stage1_cotracker

    def report(fraction, message):
        # Progress reporting is optional; do nothing without a callback.
        if progress_callback:
            progress_callback(fraction, message)

    report(0.1, "Loading video...")
    frames = load_all_frames(video_path)
    T, H, W, _ = frames.shape
    keyboxes = {0: start_box, T - 1: end_box}

    proto = TRACEPrototype(
        use_vace=use_vace,
        use_cotracker=(stage1_method == "cotracker")
    )

    report(0.3, "Computing trajectory...")
    if stage1_method == "cotracker" and proto.cotracker is not None:
        pred_boxes = stage1_cotracker(frames, keyboxes, proto.cotracker)
    else:
        pred_boxes = stage1_linear(keyboxes, T)

    report(0.5, "Running video synthesis...")
    result = proto.run_motion_edit(
        video_path=video_path,
        keyboxes=keyboxes,
        text_prompt=prompt,
        output_path=None
    )

    # Reserve a unique path, then close the handle before the writer opens
    # the file: leaving it open leaks a descriptor and blocks a second open
    # on Windows. delete=False keeps the file for the caller.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        output_path = tmp.name
    save_video(result, output_path)

    report(0.9, "Computing metrics...")
    # Metric helpers take (original, result) in that order.
    psnr = compute_psnr(frames, result)
    ssim = compute_ssim_video(frames, result)
    lpips_score = compute_lpips_video(frames, result)

    metrics_text = (
        f"**Video Quality**\n"
        f"- PSNR: {psnr:.2f} dB (TRACE paper: 20.48)\n"
        f"- SSIM: {ssim:.3f} (TRACE paper: 0.71)\n"
        f"- LPIPS: {lpips_score:.3f} (TRACE paper: 0.19)\n\n"
        f"**Settings**\n"
        f"- Stage 1: `{stage1_method}`\n"
        f"- Frames: {T} | Resolution: {W}x{H}\n"
    )

    report(1.0, "Done!")
    return output_path, result, pred_boxes, metrics_text
def run_pipeline_insertion(
    video_path: str,
    edited_first_frame: np.ndarray,  # Qwen/FLUX output, already edited
    start_box: list,
    end_box: list,
    prompt: str,
    use_vace: bool = False,
    progress_callback=None
) -> tuple:
    """
    Run insertion pipeline using a pre-edited first frame.

    The first frame has already been modified by Qwen or FLUX-Fill
    before this function is called; this function handles
    the trajectory + video synthesis steps only.

    Args:
        video_path: Path of the input video.
        edited_first_frame: First frame with the object already inserted.
        start_box: Box for frame 0, read as ``[x1, y1, x2, y2]`` when
            cropping the object in debug mode.
        end_box: Box for the last loaded frame.
        prompt: Text prompt forwarded to VACE synthesis.
        use_vace: If true, synthesize with ``VACEWrapper``; otherwise use
            ``SimpleCompositeStage2`` (debug compositing).
        progress_callback: Optional callable ``(fraction, message)`` invoked
            at each stage.

    Returns:
        A tuple ``(output_path, result, pred_boxes, metrics_text)``:
        the path of the saved temporary video (not deleted automatically),
        the synthesized frames, the interpolated boxes, and a Markdown
        summary.
    """
    from stage1_approx import stage1_linear
    from stage2_vace import VACEWrapper, SimpleCompositeStage2
    from utils.box_utils import boxes_to_mask_sequence

    def report(fraction, message):
        # Progress reporting is optional; do nothing without a callback.
        if progress_callback:
            progress_callback(fraction, message)

    report(0.1, "Loading video...")
    frames = load_all_frames(video_path)
    T, H, W, _ = frames.shape
    keyboxes = {0: start_box, T - 1: end_box}

    report(0.3, "Computing trajectory...")
    # Stage 1: interpolate trajectory
    # (cotracker optional; linear is fine for the insertion prototype)
    pred_boxes = stage1_linear(keyboxes, T)

    # Build masks
    synthesis_masks = boxes_to_mask_sequence(pred_boxes, H, W)
    # No inpainting mask: the object wasn't in the original video
    inpaint_masks = np.zeros_like(synthesis_masks)

    report(0.5, "Running video synthesis...")
    if use_vace:
        stage2 = VACEWrapper()
        result = stage2.synthesize(
            original_frames=frames,
            synthesis_masks=synthesis_masks,
            inpaint_masks=inpaint_masks,
            first_frame_ref=edited_first_frame,  # the Qwen-edited frame
            text_prompt=prompt
        )
    else:
        # Debug mode: simple alpha compositing
        stage2 = SimpleCompositeStage2()
        x1, y1, x2, y2 = [int(v) for v in start_box]
        obj_crop = edited_first_frame[y1:y2, x1:x2]
        # Build object mask from non-black pixels in crop
        obj_mask = (obj_crop.sum(axis=2) > 10).astype(np.float32)
        result = stage2.synthesize(
            original_frames=frames,
            synthesis_masks=synthesis_masks,
            inpaint_masks=inpaint_masks,
            object_crop=obj_crop,
            object_mask=obj_mask
        )

    report(0.9, "Saving output...")
    # Reserve a unique path, then close the handle before the writer opens
    # the file: leaving it open leaks a descriptor and blocks a second open
    # on Windows. delete=False keeps the file for the caller.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        output_path = tmp.name
    save_video(result, output_path)

    # Metric helpers take (original, result) in that order.
    psnr = compute_psnr(frames, result)
    ssim = compute_ssim_video(frames, result)

    metrics_text = (
        f"**Insertion Result**\n"
        f"- PSNR: {psnr:.2f} dB\n"
        f"- SSIM: {ssim:.3f}\n\n"
        f"**Settings**\n"
        f"- First frame editor: Qwen/FLUX (run separately)\n"
        f"- VACE synthesis: {'on' if use_vace else 'off (debug mode)'}\n"
        f"- Frames: {T} | Resolution: {W}x{H}\n"
    )

    report(1.0, "Done!")
    return output_path, result, pred_boxes, metrics_text