ObjectInsertion / utils /video_utils.py
Leema Krishna Murali
Initial commit
f3d0a26
# utils/video_utils.py
import cv2
import numpy as np
import imageio
import torch
def load_video(path: str, max_frames: int = 81) -> np.ndarray:
    """Read up to ``max_frames`` frames from the start of a video file.

    Args:
        path: Location of the video, handed to ``cv2.VideoCapture``.
        max_frames: Upper bound on the number of frames to read.

    Returns:
        [T, H, W, 3] uint8 RGB array (frames are converted from OpenCV's
        BGR order), with T <= max_frames.

    Raises:
        ValueError: If not a single frame could be read, e.g. the file is
            missing or cannot be decoded, or ``max_frames`` is not positive.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        while len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                # End of stream, or the capture could not be opened/decoded.
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture even when reading or conversion raises.
        cap.release()

    if not frames:
        # np.stack([]) would raise a ValueError with an unrelated message;
        # keep the exception type but say what actually went wrong.
        raise ValueError(
            f"no frames could be read from {path!r} (max_frames={max_frames})"
        )
    return np.stack(frames)
def save_video(frames: np.ndarray, path: str, fps: int = 24):
    """Write a sequence of frames to a video file via imageio.

    Args:
        frames: [T, H, W, 3] uint8 RGB; iterated along the first axis.
        path: Output location, handed to ``imageio.get_writer``.
        fps: Frame rate passed to the writer.
    """
    writer = imageio.get_writer(path, fps=fps)
    try:
        for frame in frames:
            writer.append_data(frame)
    finally:
        # Always close the writer so the file handle/encoder is released
        # even if appending a frame fails part-way through.
        writer.close()
def frames_to_tensor(frames: np.ndarray) -> torch.Tensor:
    """Turn channel-last frames into a normalised channel-first tensor.

    Args:
        frames: [T, H, W, 3] uint8 array.

    Returns:
        [T, 3, H, W] float32 tensor; values are mapped as ``x / 127.5 - 1.0``,
        so 0 becomes -1 and 255 becomes 1.
    """
    as_float = torch.from_numpy(frames).to(torch.float32)
    normalised = as_float.div(127.5).sub(1.0)
    # Move the channel axis from last position to directly after time.
    channels_first = (0, 3, 1, 2)
    return normalised.permute(*channels_first)
def tensor_to_frames(t: torch.Tensor) -> np.ndarray:
    """Convert a normalised channel-first tensor back to displayable frames.

    Args:
        t: [T, 3, H, W] float32 in [-1, 1]. Values outside that range are
            clipped. The tensor may require grad or live on another device.

    Returns:
        [T, H, W, 3] uint8 array.
    """
    # Detach and move to host memory first: ``.numpy()`` raises on tensors
    # that require grad or are not on the CPU.
    t = t.detach().cpu()
    scaled = ((t + 1.0) * 127.5).clamp(0, 255)
    # Round before the integer cast. ``.byte()`` truncates, and float error
    # in the [-1, 1] mapping puts many values just below their integer
    # (e.g. 99.99999), which would otherwise come back one too low.
    quantised = scaled.round().byte()
    return quantised.permute(0, 2, 3, 1).contiguous().numpy()