# Scraped page header: Spaces status "Runtime error"; file size: 6,794 bytes; f3d0a26
# pipeline_adapter.py
import numpy as np
import tempfile
from utils.video_utils import load_video, save_video
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
def compute_psnr(original, result):
    """Average the per-frame PSNR of two frame sequences.

    Frames are paired positionally with ``zip`` (pairing stops at the
    shorter sequence) and each pair is scored by
    ``peak_signal_noise_ratio`` with ``data_range=255``.

    Returns:
        The mean of the per-frame scores as a Python ``float``.
    """
    per_frame = [
        peak_signal_noise_ratio(ref_frame, out_frame, data_range=255)
        for ref_frame, out_frame in zip(original, result)
    ]
    return float(np.mean(per_frame))
def compute_ssim_video(original, result):
    """Average the per-frame SSIM of two frame sequences.

    Frames are paired positionally with ``zip`` (pairing stops at the
    shorter sequence). Each pair is scored by ``structural_similarity``
    with the last axis treated as the channel axis and ``data_range=255``.

    Returns:
        The mean of the per-frame scores as a Python ``float``.
    """
    per_frame = [
        structural_similarity(ref_frame, out_frame, channel_axis=-1, data_range=255)
        for ref_frame, out_frame in zip(original, result)
    ]
    return float(np.mean(per_frame))
def compute_lpips_video(original, result, device="cuda"):
    """Mean LPIPS across all paired frames (lower = better).

    Args:
        original: Sequence of reference frames. Each frame is permuted with
            ``(2, 0, 1)``, i.e. it is treated as channel-last ``[H, W, 3]``,
            and scaled from the 0..255 range to [-1, 1].
        result: Sequence of frames to compare, same layout as ``original``.
            Frames are paired with ``zip``, so pairing stops at the shorter
            sequence.
        device: Torch device for the LPIPS network and the frame tensors.
            When a CUDA device is requested but CUDA is not available, the
            computation falls back to ``"cpu"`` instead of raising.

    Returns:
        The mean LPIPS distance as a Python ``float``.
    """
    import torch
    import lpips

    # The default is "cuda" and callers in this file never override it, so a
    # CPU-only host would otherwise fail on the `.to(device)` call below.
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"

    loss_fn = lpips.LPIPS(net="alex").to(device)

    def to_lpips_tensor(frame):
        # [H, W, 3] uint8 -> [1, 3, H, W] float in [-1, 1].
        # ascontiguousarray guards against views with negative strides
        # (e.g. channel-flipped frames), which torch.from_numpy rejects.
        tensor = torch.from_numpy(np.ascontiguousarray(frame))
        tensor = tensor.permute(2, 0, 1).unsqueeze(0).float() / 127.5 - 1.0
        return tensor.to(device)

    scores = []
    with torch.no_grad():
        for f1, f2 in zip(original, result):
            score = loss_fn(to_lpips_tensor(f1), to_lpips_tensor(f2))
            scores.append(score.item())
    return float(np.mean(scores))
def extract_first_frame(video_path: str) -> np.ndarray:
    """Return the first frame of the video at ``video_path``.

    Asks ``load_video`` for at most one frame and returns element 0 of
    whatever it hands back.
    """
    return load_video(video_path, max_frames=1)[0]
def load_all_frames(video_path: str, max_frames: int = 81) -> np.ndarray:
    """Load the frames of a video, capped at ``max_frames``.

    Despite the name, this does not necessarily load every frame: the
    request to ``load_video`` is limited to ``max_frames`` frames.

    Args:
        video_path: Path of the video, forwarded to ``load_video``.
        max_frames: Upper bound on the number of frames requested. Defaults
            to 81, the value that was previously hard-coded here
            (presumably the clip length the synthesis stage expects;
            TODO confirm).

    Returns:
        Whatever ``load_video`` returns for this path and cap.
    """
    return load_video(video_path, max_frames=max_frames)
def run_pipeline_motion_edit(
    video_path: str,
    start_box: list,
    end_box: list,
    prompt: str,
    stage1_method: str = "linear",
    use_vace: bool = False,
    progress_callback=None
) -> tuple:
    """Run the motion-edit pipeline on a video and report quality metrics.

    Args:
        video_path: Path of the input video; used both to load frames for
            the trajectory/metrics and as the input of
            ``TRACEPrototype.run_motion_edit``.
        start_box: Key box assigned to frame 0.
        end_box: Key box assigned to the last loaded frame (index ``T - 1``).
        prompt: Text prompt forwarded to the prototype as ``text_prompt``.
        stage1_method: ``"cotracker"`` uses ``stage1_cotracker`` when the
            prototype exposes a tracker; any other value (or a missing
            tracker) uses ``stage1_linear``.
        use_vace: Forwarded to ``TRACEPrototype``.
        progress_callback: Optional callable invoked as
            ``progress_callback(fraction, message)`` at each stage.

    Returns:
        ``(output_path, result, pred_boxes, metrics_text)``: the path of the
        saved ``.mp4``, the frames returned by the prototype, the stage-1
        boxes, and a Markdown summary of metrics and settings.
    """
    from pipeline import TRACEPrototype
    from stage1_approx import stage1_linear, stage1_cotracker

    def report(fraction, message):
        # Progress reporting is optional; skip it when no callback is given.
        if progress_callback:
            progress_callback(fraction, message)

    report(0.1, "Loading video...")
    frames = load_all_frames(video_path)
    T, H, W, _ = frames.shape
    keyboxes = {0: start_box, T - 1: end_box}

    proto = TRACEPrototype(
        use_vace=use_vace,
        use_cotracker=(stage1_method == "cotracker")
    )

    report(0.3, "Computing trajectory...")
    if stage1_method == "cotracker" and proto.cotracker is not None:
        pred_boxes = stage1_cotracker(frames, keyboxes, proto.cotracker)
    else:
        pred_boxes = stage1_linear(keyboxes, T)

    report(0.5, "Running video synthesis...")
    # NOTE(review): pred_boxes is returned to the caller but is not passed to
    # run_motion_edit; presumably the prototype derives its own trajectory
    # from keyboxes - confirm.
    result = proto.run_motion_edit(
        video_path=video_path,
        keyboxes=keyboxes,
        text_prompt=prompt,
        output_path=None
    )

    # Reserve a unique path, then close the handle before writing: leaving it
    # open leaks a file descriptor, and an open NamedTemporaryFile cannot be
    # reopened by name on every platform. delete=False keeps the file on disk
    # for the caller.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        output_path = tmp.name
    save_video(result, output_path)

    report(0.9, "Computing metrics...")
    # Arguments follow the helpers' declared (original, result) order.
    psnr = compute_psnr(frames, result)
    ssim = compute_ssim_video(frames, result)
    lpips_score = compute_lpips_video(frames, result)

    metrics_text = (
        f"**Video Quality**\n"
        f"- PSNR: {psnr:.2f} dB (TRACE paper: 20.48)\n"
        f"- SSIM: {ssim:.3f} (TRACE paper: 0.71)\n"
        f"- LPIPS: {lpips_score:.3f} (TRACE paper: 0.19)\n\n"
        f"**Settings**\n"
        f"- Stage 1: `{stage1_method}`\n"
        f"- Frames: {T} | Resolution: {W}x{H}\n"
    )

    report(1.0, "Done!")
    return output_path, result, pred_boxes, metrics_text
def run_pipeline_insertion(
    video_path: str,
    edited_first_frame: np.ndarray,  # Qwen/FLUX output - already edited
    start_box: list,
    end_box: list,
    prompt: str,
    use_vace: bool = False,
    progress_callback=None
) -> tuple:
    """
    Run the insertion pipeline using a pre-edited first frame.

    The first frame has already been modified by Qwen or FLUX-Fill before
    this function is called; this function handles the trajectory and video
    synthesis steps only.

    Args:
        video_path: Path of the input video, loaded via ``load_all_frames``.
        edited_first_frame: The already-edited first frame. Used as the
            reference frame for VACE, or cropped to ``start_box`` in debug
            mode.
        start_box: Key box for frame 0; unpacked as ``x1, y1, x2, y2`` in
            debug mode.
        end_box: Key box for the last loaded frame (index ``T - 1``).
        prompt: Text prompt forwarded to the VACE stage.
        use_vace: ``True`` runs ``VACEWrapper``; ``False`` runs the
            ``SimpleCompositeStage2`` debug compositor.
        progress_callback: Optional callable invoked as
            ``progress_callback(fraction, message)`` at each stage.

    Returns:
        ``(output_path, result, pred_boxes, metrics_text)``: the path of the
        saved ``.mp4``, the synthesized frames, the interpolated boxes, and a
        Markdown summary of metrics and settings.
    """
    # NOTE(review): TRACEPrototype is not used in this function; the import
    # is kept in case importing `pipeline` has side effects - confirm and drop.
    from pipeline import TRACEPrototype  # noqa: F401
    from stage1_approx import stage1_linear
    from stage2_vace import VACEWrapper, SimpleCompositeStage2
    from utils.box_utils import boxes_to_mask_sequence

    def report(fraction, message):
        # Progress reporting is optional; skip it when no callback is given.
        if progress_callback:
            progress_callback(fraction, message)

    report(0.1, "Loading video...")
    frames = load_all_frames(video_path)
    T, H, W, _ = frames.shape
    keyboxes = {0: start_box, T - 1: end_box}

    report(0.3, "Computing trajectory...")
    # Stage 1: interpolate trajectory
    # (cotracker optional - linear is fine for the insertion prototype)
    pred_boxes = stage1_linear(keyboxes, T)

    # Build masks
    synthesis_masks = boxes_to_mask_sequence(pred_boxes, H, W)
    # No inpainting mask - the object wasn't in the original video
    inpaint_masks = np.zeros_like(synthesis_masks)

    report(0.5, "Running video synthesis...")
    if use_vace:
        stage2 = VACEWrapper()
        result = stage2.synthesize(
            original_frames=frames,
            synthesis_masks=synthesis_masks,
            inpaint_masks=inpaint_masks,
            first_frame_ref=edited_first_frame,  # the Qwen-edited frame
            text_prompt=prompt
        )
    else:
        # Debug mode: simple alpha compositing
        stage2 = SimpleCompositeStage2()
        x1, y1, x2, y2 = [int(v) for v in start_box]
        obj_crop = edited_first_frame[y1:y2, x1:x2]
        # Build object mask from non-black pixels in the crop
        # (channel sum above 10 counts as "object").
        obj_mask = (obj_crop.sum(axis=2) > 10).astype(np.float32)
        result = stage2.synthesize(
            original_frames=frames,
            synthesis_masks=synthesis_masks,
            inpaint_masks=inpaint_masks,
            object_crop=obj_crop,
            object_mask=obj_mask
        )

    report(0.9, "Saving output...")
    # Reserve a unique path, then close the handle before writing: leaving it
    # open leaks a file descriptor, and an open NamedTemporaryFile cannot be
    # reopened by name on every platform. delete=False keeps the file on disk
    # for the caller.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        output_path = tmp.name
    save_video(result, output_path)

    # Arguments follow the helpers' declared (original, result) order.
    psnr = compute_psnr(frames, result)
    ssim = compute_ssim_video(frames, result)

    metrics_text = (
        f"**Insertion Result**\n"
        f"- PSNR: {psnr:.2f} dB\n"
        f"- SSIM: {ssim:.3f}\n\n"
        f"**Settings**\n"
        f"- First frame editor: Qwen/FLUX (run separately)\n"
        f"- VACE synthesis: {'on' if use_vace else 'off (debug mode)'}\n"
        f"- Frames: {T} | Resolution: {W}x{H}\n"
    )

    report(1.0, "Done!")
    return output_path, result, pred_boxes, metrics_text
|