import gradio as gr
import os
import numpy as np
import cv2
import time
import shutil
from pathlib import Path
from einops import rearrange
from typing import Union
try:
    import spaces
except ImportError:
    # Fallback when the Hugging Face `spaces` package is unavailable:
    # provide a no-op stand-in for the @spaces.GPU decorator.
    class spaces:
        @staticmethod
        def GPU(func=None, **kwargs):
            return func if func is not None else (lambda f: f)
import torch
import torchvision.transforms as T
import logging
from concurrent.futures import ThreadPoolExecutor
import atexit
import uuid
import decord
from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
from models.SpaTrackV2.models.predictor import Predictor
from models.SpaTrackV2.models.utils import get_points_on_a_grid
from diffusers.utils import export_to_video, load_image
from pipelines.wan_pipeline import WanImageToVideoTTMPipeline
from pipelines.utils import compute_hw_from_area
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
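# Pipeline configuration: frame budget, output frame rate, draft render resolution,
# the Wan 2.2 I2V checkpoint, and the supported camera movements.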
MAX_FRAMES = 81
OUTPUT_FPS = 24
RENDER_WIDTH = 512
RENDER_HEIGHT = 384
WAN_MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
CAMERA_MOVEMENTS = [
"static",
"move_forward",
"move_backward",
"move_left",
"move_right",
"move_up",
"move_down"
]
thread_pool_executor = ThreadPoolExecutor(max_workers=2)
def delete_later(path: Union[str, os.PathLike], delay: int = 600):
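    """Schedule deletion of `path` after `delay` seconds (via the thread pool) and at interpreter exit."""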
def _delete():
try:
if os.path.isfile(path):
os.remove(path)
elif os.path.isdir(path):
shutil.rmtree(path)
except Exception as e:
logger.warning(f"Failed to delete {path}: {e}")
def _wait_and_delete():
time.sleep(delay)
_delete()
thread_pool_executor.submit(_wait_and_delete)
atexit.register(_delete)
def create_user_temp_dir():
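    """Create a unique per-session directory under temp_local/ that is auto-deleted after 10 minutes."""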
session_id = str(uuid.uuid4())[:8]
temp_dir = os.path.join("temp_local", f"session_{session_id}")
os.makedirs(temp_dir, exist_ok=True)
delete_later(temp_dir, delay=600)
return temp_dir
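# Models are instantiated once at import time and moved to CUDA inside the
# @spaces.GPU-decorated functions below.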
print("🚀 Initializing tracking models...")
vggt4track_model = VGGT4Track.from_pretrained(
"Yuxihenry/SpatialTrackerV2_Front")
vggt4track_model.eval()
if not hasattr(vggt4track_model, 'infer'):
vggt4track_model.infer = vggt4track_model.forward
tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
tracker_model.eval()
wan_pipeline = WanImageToVideoTTMPipeline.from_pretrained(
WAN_MODEL_ID,
torch_dtype=torch.bfloat16
)
wan_pipeline.vae.enable_tiling()
wan_pipeline.vae.enable_slicing()
print("✅ Tracking models loaded successfully!")
gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
def generate_camera_trajectory(num_frames: int, movement_type: str, base_intrinsics: np.ndarray, scene_scale: float = 1.0) -> np.ndarray:
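    """Build a (num_frames, 4, 4) array of relative camera extrinsics that translate the camera
    along a single axis at `scene_scale * 0.02` units per frame; "static" keeps identity poses."""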
speed = scene_scale * 0.02
extrinsics = np.zeros((num_frames, 4, 4), dtype=np.float32)
for t in range(num_frames):
ext = np.eye(4, dtype=np.float32)
if movement_type == "move_forward":
ext[2, 3] = -speed * t
elif movement_type == "move_backward":
ext[2, 3] = speed * t
elif movement_type == "move_left":
ext[0, 3] = -speed * t
elif movement_type == "move_right":
ext[0, 3] = speed * t
elif movement_type == "move_up":
ext[1, 3] = -speed * t
elif movement_type == "move_down":
ext[1, 3] = speed * t
extrinsics[t] = ext
return extrinsics
def render_from_pointcloud(rgb_frames, depth_frames, intrinsics, original_extrinsics, new_extrinsics, output_path, fps=24, generate_ttm_inputs=False):
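    """Re-render each RGB-D frame from a new camera trajectory by forward point splatting:
    unproject pixels with the inverse intrinsics and depth, lift them to world space with the
    original camera-to-world pose, reproject into the new camera, and resolve occlusions with a
    z-buffer. Optionally also writes the hole-filled motion-signal video and validity-mask
    video used as TTM inputs."""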
    num_frames, H, W, _ = rgb_frames.shape
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
motion_signal_path = mask_path = out_motion_signal = out_mask = None
if generate_ttm_inputs:
base_dir = os.path.dirname(output_path)
motion_signal_path = os.path.join(base_dir, "motion_signal.mp4")
mask_path = os.path.join(base_dir, "mask.mp4")
out_motion_signal = cv2.VideoWriter(
motion_signal_path, fourcc, fps, (W, H))
out_mask = cv2.VideoWriter(mask_path, fourcc, fps, (W, H))
u, v = np.meshgrid(np.arange(W), np.arange(H))
    for t in range(num_frames):
rgb, depth, K = rgb_frames[t], depth_frames[t], intrinsics[t]
orig_c2w = np.linalg.inv(original_extrinsics[t])
if t == 0:
base_c2w = orig_c2w.copy()
new_c2w = base_c2w @ new_extrinsics[t]
new_w2c = np.linalg.inv(new_c2w)
K_inv = np.linalg.inv(K)
pixels = np.stack([u, v, np.ones_like(u)], axis=-1).reshape(-1, 3)
rays_cam = (K_inv @ pixels.T).T
points_cam = rays_cam * depth.reshape(-1, 1)
points_world = (orig_c2w[:3, :3] @ points_cam.T).T + orig_c2w[:3, 3]
points_new_cam = (new_w2c[:3, :3] @ points_world.T).T + new_w2c[:3, 3]
points_proj = (K @ points_new_cam.T).T
z = np.clip(points_proj[:, 2:3], 1e-6, None)
uv_new = points_proj[:, :2] / z
rendered = np.zeros((H, W, 3), dtype=np.uint8)
z_buffer = np.full((H, W), np.inf, dtype=np.float32)
colors, depths_new = rgb.reshape(-1, 3), points_new_cam[:, 2]
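        # Nearest-point splat with a z-buffer: keep the closest projected point per target pixel.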
for i in range(len(uv_new)):
uu, vv = int(round(uv_new[i, 0])), int(round(uv_new[i, 1]))
if 0 <= uu < W and 0 <= vv < H and depths_new[i] > 0:
if depths_new[i] < z_buffer[vv, uu]:
z_buffer[vv, uu] = depths_new[i]
rendered[vv, uu] = colors[i]
valid_mask = (rendered.sum(axis=-1) > 0).astype(np.uint8) * 255
motion_signal_frame = rendered.copy()
hole_mask = (motion_signal_frame.sum(axis=-1) == 0).astype(np.uint8)
if hole_mask.sum() > 0:
kernel = np.ones((3, 3), np.uint8)
for _ in range(10): # Iterative fill
if hole_mask.sum() == 0:
break
dilated = cv2.dilate(motion_signal_frame, kernel)
motion_signal_frame = np.where(
hole_mask[:, :, None] > 0, dilated, motion_signal_frame)
hole_mask = (motion_signal_frame.sum(
axis=-1) == 0).astype(np.uint8)
if generate_ttm_inputs:
out_motion_signal.write(cv2.cvtColor(
motion_signal_frame, cv2.COLOR_RGB2BGR))
out_mask.write(np.stack([valid_mask]*3, axis=-1))
out.write(cv2.cvtColor(motion_signal_frame, cv2.COLOR_RGB2BGR))
out.release()
if generate_ttm_inputs:
out_motion_signal.release()
out_mask.release()
return {'rendered': output_path, 'motion_signal': motion_signal_path, 'mask': mask_path}
@spaces.GPU
def run_spatial_tracker(video_tensor: torch.Tensor):
"""
GPU-intensive spatial tracking function.
Args:
video_tensor: Preprocessed video tensor (T, C, H, W)
Returns:
        Dictionary with the (possibly downscaled) video frames, point map, depth confidence,
        intrinsics, and camera-to-world trajectory, all moved to CPU.
"""
global vggt4track_model
global tracker_model
video_input = preprocess_image(video_tensor)[None].cuda()
vggt4track_model = vggt4track_model.to("cuda")
with torch.no_grad():
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
predictions = vggt4track_model(video_input / 255)
extrinsic = predictions["poses_pred"]
intrinsic = predictions["intrs"]
depth_map = predictions["points_map"][..., 2]
depth_conf = predictions["unc_metric"]
depth_tensor = depth_map.squeeze().cpu().numpy()
extrs = extrinsic.squeeze().cpu().numpy()
intrs = intrinsic.squeeze().cpu().numpy()
video_tensor_gpu = video_input.squeeze()
unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
tracker_model.spatrack.track_num = 512
tracker_model.to("cuda")
frame_H, frame_W = video_tensor_gpu.shape[2:]
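    # Seed the tracker with a regular grid of query points (grid size 30), timestamped to frame 0.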
grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu")
    query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
(
c2w_traj, intrs_out, point_map, conf_depth,
track3d_pred, track2d_pred, vis_pred, conf_pred, video_out
) = tracker_model.forward(
video_tensor_gpu, depth=depth_tensor,
intrs=intrs, extrs=extrs,
queries=query_xyt,
fps=1, full_point=False, iters_track=4,
query_no_BA=True, fixed_cam=False, stage=1,
unc_metric=unc_metric,
support_frame=len(video_tensor_gpu)-1, replace_ratio=0.2
)
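    # Downscale tracker outputs so the larger side is at most 384 px and rescale intrinsics to match.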
max_size = 384
h, w = video_out.shape[2:]
scale = min(max_size / h, max_size / w)
if scale < 1:
new_h, new_w = int(h * scale), int(w * scale)
video_out = T.Resize((new_h, new_w))(video_out)
point_map = T.Resize((new_h, new_w))(point_map)
conf_depth = T.Resize((new_h, new_w))(conf_depth)
intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale
return {
'video_out': video_out.cpu(),
'point_map': point_map.cpu(),
'conf_depth': conf_depth.cpu(),
'intrs_out': intrs_out.cpu(),
'c2w_traj': c2w_traj.cpu(),
}
@spaces.GPU
def run_wan_ttm_generation(prompt, tweak_index, tstrong_index, first_frame_path, motion_video_path, mask_video_path, progress=gr.Progress()):
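    """Generate the final video with the Wan 2.2 image-to-video TTM pipeline, conditioned on the
    first frame, the rendered motion-signal video, and its validity mask. The frame count is
    snapped down to the 4k+1 form expected by Wan."""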
if not first_frame_path or not motion_video_path or not mask_video_path:
        return None, "❌ TTM inputs missing. Please run 3D tracking first."
progress(0, desc="Loading Wan TTM Pipeline...")
vr = decord.VideoReader(motion_video_path)
actual_frame_count = len(vr)
target_num_frames = ((actual_frame_count - 1) // 4) * 4 + 1
if target_num_frames < 5:
return None, f"❌ Video too short. Only {actual_frame_count} frames tracked."
logger.info(f"Setting Wan num_frames to {target_num_frames} based on tracking output.")
progress(0.2, desc="Preparing inputs...")
image = load_image(first_frame_path)
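    # Standard Chinese negative prompt commonly used with Wan models; roughly: "garish colors,
    # overexposed, static, blurry details, subtitles, artwork/painting style, overall gray tones,
    # worst quality, low quality, JPEG artifacts, ugly, incomplete, extra fingers, poorly drawn
    # hands/faces, deformed or disfigured limbs, fused fingers, frozen frames, cluttered
    # background, three legs, crowded background, walking backwards".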
negative_prompt = (
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,"
"低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,"
"毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
)
wan_pipeline.to("cuda")
max_area = 480 * 832
mod_value = wan_pipeline.vae_scale_factor_spatial * \
wan_pipeline.transformer.config.patch_size[1]
height, width = compute_hw_from_area(
image.height, image.width, max_area, mod_value)
image = image.resize((width, height))
progress(0.4, desc=f"Generating {target_num_frames} frames (this may take a few minutes)...")
generator = torch.Generator(device="cuda").manual_seed(0)
with torch.inference_mode():
result = wan_pipeline(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=target_num_frames,
guidance_scale=3.5,
num_inference_steps=50,
generator=generator,
motion_signal_video_path=motion_video_path,
motion_signal_mask_path=mask_video_path,
tweak_index=int(tweak_index),
tstrong_index=int(tstrong_index),
)
output_path = os.path.join(os.path.dirname(
first_frame_path), "wan_ttm_output.mp4")
export_to_video(result.frames[0], output_path, fps=16)
return output_path, f"✅ TTM Video ({target_num_frames} frames) generated successfully!"
# --- VIDEO PROCESSING: tracking, re-rendering, and TTM input generation ---
def process_video(video_path, camera_movement, generate_ttm=True, progress=gr.Progress()):
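    """Run the full stage-1 pipeline: load and subsample the input video, run 3D tracking,
    synthesize the requested camera trajectory, render the point-cloud video, and save the
    TTM inputs (motion signal, mask, first frame). Returns file paths plus a status message."""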
if video_path is None:
return None, None, None, None, "❌ Please upload a video first"
progress(0, desc="Initializing...")
temp_dir = create_user_temp_dir()
out_dir = os.path.join(temp_dir, "results")
os.makedirs(out_dir, exist_ok=True)
try:
progress(0.1, desc="Loading video...")
video_reader = decord.VideoReader(video_path)
video_tensor = torch.from_numpy(video_reader.get_batch(
range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2).float()
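        # Temporally subsample by a uniform stride so at most MAX_FRAMES frames remain.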
video_tensor = video_tensor[::max(
1, len(video_tensor)//MAX_FRAMES)][:MAX_FRAMES]
h, w = video_tensor.shape[2:]
scale = 336 / min(h, w)
if scale < 1:
video_tensor = T.Resize(
(int(h*scale)//2*2, int(w*scale)//2*2))(video_tensor)
progress(0.4, desc="Running 3D tracking...")
tracking_results = run_spatial_tracker(video_tensor)
rgb_frames = rearrange(
tracking_results['video_out'].numpy(), "T C H W -> T H W C").astype(np.uint8)
depth_frames = tracking_results['point_map'][:, 2].numpy()
depth_frames[tracking_results['conf_depth'].numpy() < 0.5] = 0
scene_scale = np.median(depth_frames[depth_frames > 0]) if np.any(
depth_frames > 0) else 1.0
new_exts = generate_camera_trajectory(len(
rgb_frames), camera_movement, tracking_results['intrs_out'].numpy(), scene_scale)
progress(0.8, desc="Rendering viewpoint...")
output_video_path = os.path.join(out_dir, "rendered_video.mp4")
        render_results = render_from_pointcloud(
            rgb_frames, depth_frames, tracking_results['intrs_out'].numpy(),
            torch.inverse(tracking_results['c2w_traj']).numpy(),
            new_exts, output_video_path, fps=OUTPUT_FPS, generate_ttm_inputs=generate_ttm)
first_frame_path = os.path.join(out_dir, "first_frame.png")
cv2.imwrite(first_frame_path, cv2.cvtColor(
rgb_frames[0], cv2.COLOR_RGB2BGR))
        status_msg = "✅ 3D results ready! You can now use the prompt below to generate a high-quality TTM video."
return render_results['rendered'], render_results['motion_signal'], render_results['mask'], first_frame_path, status_msg
except Exception as e:
logger.error(f"Error: {e}")
return None, None, None, None, f"❌ Error: {str(e)}"
# --- GRADIO INTERFACE ---
with gr.Blocks(theme=gr.themes.Soft(), title="🎬 TTM Wan Video Generator") as demo:
gr.Markdown("# 🎬 Video to Point Cloud & TTM Wan Generator")
gr.Markdown(
"Transform standard videos into 3D-aware motion signals for Time-to-Move (TTM) generation.")
first_frame_file = gr.State("")
motion_signal_file = gr.State("")
mask_file = gr.State("")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### 1. Tracking & Viewpoint")
video_input = gr.Video(label="Upload Video")
camera_movement = gr.Dropdown(
choices=CAMERA_MOVEMENTS,
value="static",
label="Camera Movement"
)
generate_btn = gr.Button(
"🚀 1. Run Spatial Tracker", variant="primary")
output_video = gr.Video(label="Point Cloud Render (Draft)")
status_text = gr.Markdown("Ready...")
with gr.Column(scale=1):
gr.Markdown("### 2. Time-to-Move (Wan 2.2)")
ttm_prompt = gr.Textbox(
label="Prompt",
placeholder="Describe the scene (e.g., 'A monkey walking in the forest, high quality')"
)
with gr.Row():
tweak_idx = gr.Number(
label="Tweak Index", value=3, precision=0)
tstrong_idx = gr.Number(
label="Tstrong Index", value=6, precision=0)
wan_generate_btn = gr.Button(
"✨ 2. Generate TTM Video (Wan)", variant="secondary")
wan_output_video = gr.Video(label="Final High-Quality TTM Video")
wan_status = gr.Markdown("Awaiting 3D inputs...")
with gr.Accordion("Debug: TTM Intermediate Inputs", open=False):
with gr.Row():
motion_signal_output = gr.Video(label="motion_signal.mp4")
mask_output = gr.Video(label="mask.mp4")
first_frame_output = gr.Image(
label="first_frame.png", type="filepath")
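    # Stage 1: run tracking + rendering, then copy the generated file paths into
    # gr.State so the Wan TTM stage (stage 2) can consume them.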
generate_btn.click(
fn=process_video,
inputs=[video_input, camera_movement],
outputs=[
output_video,
motion_signal_output,
mask_output,
first_frame_output,
status_text
]
).then(
fn=lambda a, b, c, d, e: (b, c, d),
inputs=[
output_video,
motion_signal_output,
mask_output,
first_frame_output,
status_text
],
outputs=[motion_signal_file, mask_file, first_frame_file]
)
wan_generate_btn.click(
fn=run_wan_ttm_generation,
inputs=[
ttm_prompt,
tweak_idx,
tstrong_idx,
first_frame_file,
motion_signal_file,
mask_file
],
outputs=[wan_output_video, wan_status]
)
if __name__ == "__main__":
demo.launch(share=False)