# StitchTool / app.py
import os, io, time, base64, random, subprocess
from typing import Optional, List
from urllib.parse import quote
import requests
from PIL import Image
import gradio as gr
# -------- Modal inference endpoint (dev) --------
INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
# -------- settings --------
MAX_SLOTS = 12 # max image slots user can reveal
# -------- small helpers --------
def _save_video_bytes(data: bytes, tag: str) -> str:
    os.makedirs("/tmp", exist_ok=True)
    path = f"/tmp/{tag}_{int(time.time())}.mp4"
    with open(path, "wb") as f:
        f.write(data)
    return path

def _png_bytes(img: Image.Image) -> bytes:
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()

def _download_to_bytes(url: str) -> bytes:
    r = requests.get(url, timeout=180)
    r.raise_for_status()
    return r.content

def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed: Optional[int]) -> Optional[str]:
    """
    Calls your Modal backend with two images + prompt + seed and returns a local /tmp video path.
    """
    if start_img is None or end_img is None:
        return None
    if seed in (None, 0, -1):
        # 0 / -1 / None all mean "pick a random seed"
        seed = random.randint(1, 2**31 - 1)
    url = f"{INFERENCE_URL}?prompt={quote(prompt or '')}&seed={seed}"
    files = {
        "image_bytes": ("start.png", _png_bytes(start_img), "image/png"),
        "image_bytes_end": ("end.png", _png_bytes(end_img), "image/png"),
    }
    headers = {"accept": "application/json"}
    try:
        resp = requests.post(url, files=files, headers=headers, timeout=600)
        ctype = (resp.headers.get("content-type") or "").lower()
        # Raw video bytes
        if "application/json" not in ctype:
            resp.raise_for_status()
            return _save_video_bytes(resp.content, "stitch")
        # JSON with a video URL or a base64 payload
        data = resp.json()
        video_url = data.get("video_url") or data.get("url") or data.get("result") or data.get("output")
        if isinstance(video_url, str) and video_url.startswith(("http://", "https://")):
            return _save_video_bytes(_download_to_bytes(video_url), "stitch")
        video_b64 = data.get("video_b64") or data.get("videoBase64")
        if isinstance(video_b64, str):
            # restore any stripped base64 padding before decoding
            pad = (-len(video_b64)) % 4
            if pad:
                video_b64 += "=" * pad
            return _save_video_bytes(base64.b64decode(video_b64), "stitch")
    except Exception as e:
        print("stitch_call error:", e)
    return None
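
# For reference, the request above is equivalent to the following curl call
# (assuming the backend accepts plain multipart uploads, which is what
# `requests` sends here -- the field names and query params come straight
# from stitch_call):
#
#   curl -X POST "$INFERENCE_URL?prompt=hello&seed=42" \
#        -H "accept: application/json" \
#        -F "image_bytes=@start.png;type=image/png" \
#        -F "image_bytes_end=@end.png;type=image/png" \
#        -o out.mp4
#
# Handy for checking whether a failure comes from this client or the backend.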
# -------- FFmpeg-based concatenation (N clips) --------
def concat_many(videos: List[str]) -> Optional[str]:
    vids = [v for v in videos if v]
    if len(vids) < 2:
        return None
    try:
        os.makedirs("/tmp", exist_ok=True)
        out_path = f"/tmp/final_{int(time.time())}.mp4"
        list_file = f"/tmp/list_{int(time.time())}.txt"
        with open(list_file, "w") as f:
            for v in vids:
                f.write(f"file '{v}'\n")
        subprocess.run(
            ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", list_file, "-c", "copy", out_path],
            check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
        )
        return out_path
    except Exception as e:
        print("concat_many error:", e)
        return None
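
# Note: `-c copy` with the concat demuxer only works when every clip shares the
# same codec, resolution, and timebase -- which should hold here, since every
# clip comes from the same backend. If that ever changes, a re-encode is the
# safe fallback, e.g.:
#
#   ffmpeg -y -f concat -safe 0 -i list.txt -c:v libx264 -pix_fmt yuv420p out.mp4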
# -------- Timeline HTML renderer --------
def render_timeline_html(paths: List[str]):
    vids = [p for p in (paths or []) if p]
    if not vids:
        return "<div class='tl-grid tl-empty'>No clips yet. Generate and click ‘Add to timeline’.</div>"
    items = []
    for i, p in enumerate(vids, 1):
        items.append(
            f"""
            <div class="tl-item">
              <video src="{p}" controls playsinline></video>
              <div class="tl-label">Clip {i}</div>
            </div>
            """
        )
    return f"<div class='tl-grid'>{''.join(items)}</div>"
# =========================
# Gradio callbacks / state ops
# =========================
def add_image_slot(visible_slots: int):
    """Reveal one more upload slot (up to MAX_SLOTS)."""
    return min(MAX_SLOTS, int(visible_slots) + 1)

def _reveal_slots(n, *imgs):
    """Update visibility of image upload components based on visible_slots state."""
    # *imgs is unused here; the .change() binding lists the image components as
    # inputs, so the signature has to accept them.
    n = int(n)
    return [gr.update(visible=(i < n)) for i in range(MAX_SLOTS)]

def collect_choices(*imgs):
    """Build dropdown choices of available indices (1-based labels) based on non-empty slots."""
    choices = [str(i) for i, img in enumerate(imgs, start=1) if img is not None]
    return gr.update(choices=choices), gr.update(choices=choices)

def stitch_selected(prompt, seed, start_idx_str, end_idx_str, *imgs):
    """Run inference for the selected start/end indices (1-based strings)."""
    if not start_idx_str or not end_idx_str:
        gr.Warning("Please select Start and End frames.")
        return None
    try:
        s = int(start_idx_str) - 1
        e = int(end_idx_str) - 1
    except Exception:
        gr.Warning("Invalid Start/End selection.")
        return None
    if s < 0 or e < 0 or s >= len(imgs) or e >= len(imgs):
        gr.Warning("Start/End out of range.")
        return None
    start_img = imgs[s]
    end_img = imgs[e]
    if start_img is None or end_img is None:
        gr.Warning("Selected slots are empty.")
        return None
    vid = stitch_call(start_img, end_img, prompt or "", int(seed or 0))
    if not vid:
        gr.Warning("Generation failed.")
        return None
    return vid  # path for preview

def add_to_timeline(preview_path, timeline_paths: List[str]):
    """Append preview to timeline; return updated state and HTML."""
    tl = list(timeline_paths or [])
    if not preview_path:
        gr.Warning("Generate a clip first.")
        return tl, gr.update(value=render_timeline_html(tl))
    tl.append(preview_path)
    return tl, gr.update(value=render_timeline_html(tl))

def stitch_all_from_timeline(timeline_paths: List[str]):
    vids = list(timeline_paths or [])
    if len(vids) < 2:
        gr.Warning("Add at least two clips to the timeline first.")
        return None
    out = concat_many(vids)
    if not out:
        gr.Warning("Failed to concatenate clips.")
    return out
# =========================
# UI
# =========================
CSS = """
.gradio-container { padding: 24px; }
.pill button { border-radius: 999px !important; padding: 10px 18px; }
.rounded textarea { border-radius: 16px !important; }
.gallery-row { display:flex; gap:16px; overflow-x:auto; padding:8px 4px; }
.gallery-row .gradio-image { min-width: 220px; }
.tl-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(180px, 1fr));
gap: 12px;
}
.stitch-box {
background-color: #f0f4ff; /* pick any color you like */
border-radius: 12px;
padding: 16px;
}
.tl-grid video {
width: 100%;
height: 120px;
object-fit: cover;
border-radius: 12px;
display: block;
}
.tl-label {
font-size: 12px;
color: #9aa0a6;
margin-top: 4px;
text-align: center;
}
.tl-empty { color: #9aa0a6; padding: 8px 4px; }
"""
with gr.Blocks(css=CSS, title="StitchTool") as demo:
    gr.Markdown("## StitchTool")

    # --- State ---
    visible_slots = gr.State(value=3)    # number of visible image slots
    timeline_state = gr.State(value=[])  # list[str] of video file paths (timeline)

    # --- Image gallery (horizontal, grows on demand) ---
    with gr.Row(elem_classes=["gallery-row"]):
        img_comps = []
        for i in range(MAX_SLOTS):
            comp = gr.Image(label=f"Image {i+1} upload", type="pil", visible=(i < 3))
            img_comps.append(comp)

    add_btn = gr.Button("+ Add image")
    # clicking add → reveal one more slot
    add_btn.click(
        fn=add_image_slot,
        inputs=[visible_slots],
        outputs=[visible_slots],
    )
    # reflect visibility changes whenever visible_slots changes
    visible_slots.change(
        fn=_reveal_slots,
        inputs=[visible_slots] + img_comps,
        outputs=img_comps,
    )

    # Seed + Start/End selection + Prompt + Stitch + Preview
    seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
    with gr.Row():
        # Left column: controls (with colored background via .stitch-box)
        with gr.Column(scale=1, min_width=420, elem_classes=["stitch-box"]):
            start_dd = gr.Dropdown(label="Start frame", choices=[], interactive=True)
            end_dd = gr.Dropdown(label="End frame", choices=[], interactive=True)
            prompt = gr.Textbox(
                placeholder="Describe the transition between the selected start and end frames…",
                lines=3,
                label="Prompt",
                elem_classes=["rounded"],
            )
            run_btn = gr.Button("Generate", elem_classes=["pill"])
            add_tl_btn = gr.Button("Add to timeline", elem_classes=["pill"])
        # Right column: preview video
        with gr.Column(scale=1, min_width=420):
            preview = gr.Video(label="Video output", interactive=False)

    # keep start/end dropdowns up to date based on which slots have images
    for comp in img_comps:
        comp.change(
            fn=collect_choices,
            inputs=img_comps,
            outputs=[start_dd, end_dd],
        )

    # stitch action → preview
    run_btn.click(
        fn=stitch_selected,
        inputs=[prompt, seed, start_dd, end_dd] + img_comps,
        outputs=[preview],
    )

    # --- Dynamic timeline (no placeholders) ---
    with gr.Row():
        timeline_html = gr.HTML(value=render_timeline_html([]))
    add_tl_btn.click(
        fn=add_to_timeline,
        inputs=[preview, timeline_state],
        outputs=[timeline_state, timeline_html],
    )

    # final stitch all (concatenate in order)
    with gr.Row():
        with gr.Column(scale=1, min_width=420):
            stitch_all_btn = gr.Button("Stitch All", elem_classes=["pill"])
        with gr.Column(scale=1, min_width=420):
            final_vid = gr.Video(label="Stitched Video Output", interactive=False)
    stitch_all_btn.click(
        fn=stitch_all_from_timeline,
        inputs=[timeline_state],
        outputs=[final_vid],
    )

if __name__ == "__main__":
    demo.queue().launch()