Shalmoni commited on
Commit
aa92cac
Β·
verified Β·
1 Parent(s): 2b3eb94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -43
app.py CHANGED
@@ -1,15 +1,14 @@
1
- import os, io, time, base64, random, subprocess
2
  from typing import Optional
3
  from urllib.parse import quote
4
-
5
- import requests
6
  from PIL import Image
7
  import gradio as gr
 
8
 
9
  # -------- Modal inference endpoint --------
10
  INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
11
 
12
- # -------- Helpers --------
13
  def _save_video_bytes(data: bytes, tag: str) -> str:
14
  os.makedirs("/tmp", exist_ok=True)
15
  path = f"/tmp/{tag}_{int(time.time())}.mp4"
@@ -46,28 +45,26 @@ def stitch_call(start_img: Image.Image, end_img: Image.Image, prompt: str, seed:
46
  resp = requests.post(url, files=files, headers=headers, timeout=600)
47
  ctype = (resp.headers.get("content-type") or "").lower()
48
 
49
- # Raw video bytes
50
  if "application/json" not in ctype:
51
  resp.raise_for_status()
52
  return _save_video_bytes(resp.content, "stitch")
53
 
54
- # JSON with url or base64
55
  data = resp.json()
56
  video_url = data.get("video_url") or data.get("url") or data.get("result") or data.get("output")
57
- if isinstance(video_url, str) and (video_url.startswith("http://") or video_url.startswith("https://")):
58
  b = _download_to_bytes(video_url)
59
  return _save_video_bytes(b, "stitch")
60
 
61
  video_b64 = data.get("video_b64") or data.get("videoBase64")
62
  if isinstance(video_b64, str):
63
  pad = (-len(video_b64)) % 4
64
- if pad:
65
  video_b64 += "=" * pad
66
  b = base64.b64decode(video_b64)
67
  return _save_video_bytes(b, "stitch")
68
 
69
  except Exception as e:
70
- print("Stitch call failed:", e)
71
 
72
  return None
73
 
@@ -78,7 +75,7 @@ def stitch_12(prompt12, seed, img1, img2):
78
  return None
79
  path = stitch_call(img1, img2, prompt12 or "", int(seed or 0))
80
  if path is None:
81
- gr.Warning("Stitch 1&2 failed. Try again or adjust the prompt.")
82
  return path
83
 
84
  def stitch_23(prompt23, seed, img2, img3):
@@ -87,31 +84,23 @@ def stitch_23(prompt23, seed, img2, img3):
87
  return None
88
  path = stitch_call(img2, img3, prompt23 or "", int(seed or 0))
89
  if path is None:
90
- gr.Warning("Stitch 2&3 failed. Try again or adjust the prompt.")
91
  return path
92
 
93
  def stitch_all(video12, video23):
94
- if not video12 or not video23:
95
- gr.Warning("Please generate both stitched videos first.")
96
  return None
97
-
98
  try:
99
- # Final output path
 
 
100
  out_path = f"/tmp/stitch_all_{int(time.time())}.mp4"
101
-
102
- # Concatenate with ffmpeg
103
- txt_file = f"/tmp/concat_{int(time.time())}.txt"
104
- with open(txt_file, "w") as f:
105
- f.write(f"file '{video12}'\n")
106
- f.write(f"file '{video23}'\n")
107
-
108
- cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", txt_file, "-c", "copy", out_path]
109
- subprocess.run(cmd, check=True)
110
-
111
  return out_path
112
  except Exception as e:
113
- print("Stitch all failed:", e)
114
- gr.Warning("Failed to stitch all videos together.")
115
  return None
116
 
117
  # -------- UI --------
@@ -121,29 +110,32 @@ CSS = """
121
  .rounded textarea { border-radius: 16px !important; }
122
  """
123
 
124
- with gr.Blocks(css=CSS, title="Stitch β€” 3 uploads, 2 stitches, concat") as demo:
125
  gr.Markdown("## Stitch β€” Upload 3 images, generate videos between 1β†’2 and 2β†’3, then merge them.")
126
 
 
127
  with gr.Row():
128
- with gr.Column(scale=1, min_width=320):
129
- img1 = gr.Image(label="Image 1 upload", type="pil")
130
- img2 = gr.Image(label="Image 2 upload", type="pil")
131
- img3 = gr.Image(label="Image 3 upload", type="pil")
 
132
 
133
- with gr.Column(scale=1, min_width=320):
134
- seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
135
- prompt12 = gr.Textbox(placeholder="Prompt for stitching 1β†’2", lines=2, label="Prompt (1β†’2)", elem_classes=["rounded"])
136
- btn12 = gr.Button("Stitch 1&2", elem_classes=["pill"])
137
- vid12 = gr.Video(label="Video (image 1+2) output")
138
 
139
- prompt23 = gr.Textbox(placeholder="Prompt for stitching 2β†’3", lines=2, label="Prompt (2β†’3)", elem_classes=["rounded"])
140
- btn23 = gr.Button("Stitch 2&3", elem_classes=["pill"])
141
- vid23 = gr.Video(label="Video (image 2+3) output")
 
142
 
143
- btn_all = gr.Button("Stitch All", elem_classes=["pill"])
144
- vid_all = gr.Video(label="Final concatenated video")
 
145
 
146
- # Wire buttons
147
  btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
148
  btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])
149
  btn_all.click(stitch_all, inputs=[vid12, vid23], outputs=[vid_all])
 
1
+ import os, io, time, base64, random, requests
2
  from typing import Optional
3
  from urllib.parse import quote
 
 
4
  from PIL import Image
5
  import gradio as gr
6
+ from moviepy.editor import VideoFileClip, concatenate_videoclips
7
 
8
  # -------- Modal inference endpoint --------
9
  INFERENCE_URL = "https://moonmath-ai-dev--moonmath-i2v-backend-moonmathinference-run.modal.run"
10
 
11
+ # -------- small helpers --------
12
  def _save_video_bytes(data: bytes, tag: str) -> str:
13
  os.makedirs("/tmp", exist_ok=True)
14
  path = f"/tmp/{tag}_{int(time.time())}.mp4"
 
45
  resp = requests.post(url, files=files, headers=headers, timeout=600)
46
  ctype = (resp.headers.get("content-type") or "").lower()
47
 
 
48
  if "application/json" not in ctype:
49
  resp.raise_for_status()
50
  return _save_video_bytes(resp.content, "stitch")
51
 
 
52
  data = resp.json()
53
  video_url = data.get("video_url") or data.get("url") or data.get("result") or data.get("output")
54
+ if isinstance(video_url, str) and video_url.startswith(("http://", "https://")):
55
  b = _download_to_bytes(video_url)
56
  return _save_video_bytes(b, "stitch")
57
 
58
  video_b64 = data.get("video_b64") or data.get("videoBase64")
59
  if isinstance(video_b64, str):
60
  pad = (-len(video_b64)) % 4
61
+ if pad:
62
  video_b64 += "=" * pad
63
  b = base64.b64decode(video_b64)
64
  return _save_video_bytes(b, "stitch")
65
 
66
  except Exception as e:
67
+ print("stitch_call failed:", e)
68
 
69
  return None
70
 
 
75
  return None
76
  path = stitch_call(img1, img2, prompt12 or "", int(seed or 0))
77
  if path is None:
78
+ gr.Warning("Stitch 1β†’2 failed.")
79
  return path
80
 
81
  def stitch_23(prompt23, seed, img2, img3):
 
84
  return None
85
  path = stitch_call(img2, img3, prompt23 or "", int(seed or 0))
86
  if path is None:
87
+ gr.Warning("Stitch 2β†’3 failed.")
88
  return path
89
 
90
  def stitch_all(video12, video23):
91
+ if video12 is None or video23 is None:
92
+ gr.Warning("Need both videos to stitch all.")
93
  return None
 
94
  try:
95
+ clip1 = VideoFileClip(video12)
96
+ clip2 = VideoFileClip(video23)
97
+ final = concatenate_videoclips([clip1, clip2])
98
  out_path = f"/tmp/stitch_all_{int(time.time())}.mp4"
99
+ final.write_videofile(out_path, codec="libx264", audio=False, verbose=False, logger=None)
 
 
 
 
 
 
 
 
 
100
  return out_path
101
  except Exception as e:
102
+ print("stitch_all failed:", e)
103
+ gr.Warning("Stitch All failed.")
104
  return None
105
 
106
  # -------- UI --------
 
110
  .rounded textarea { border-radius: 16px !important; }
111
  """
112
 
113
+ with gr.Blocks(css=CSS, title="Stitch Master") as demo:
114
  gr.Markdown("## Stitch β€” Upload 3 images, generate videos between 1β†’2 and 2β†’3, then merge them.")
115
 
116
+ # --- Uploads row ---
117
  with gr.Row():
118
+ img1 = gr.Image(label="Image 1 upload", type="pil")
119
+ img2 = gr.Image(label="Image 2 upload", type="pil")
120
+ img3 = gr.Image(label="Image 3 upload", type="pil")
121
+
122
+ seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
123
 
124
+ # --- First stitch ---
125
+ prompt12 = gr.Textbox(placeholder="Prompt for stitching 1β†’2", lines=2, label="Prompt (1β†’2)", elem_classes=["rounded"])
126
+ btn12 = gr.Button("Stitch 1β†’2", elem_classes=["pill"])
127
+ vid12 = gr.Video(label="Video (image 1+2) output", interactive=False)
 
128
 
129
+ # --- Second stitch ---
130
+ prompt23 = gr.Textbox(placeholder="Prompt for stitching 2β†’3", lines=2, label="Prompt (2β†’3)", elem_classes=["rounded"])
131
+ btn23 = gr.Button("Stitch 2β†’3", elem_classes=["pill"])
132
+ vid23 = gr.Video(label="Video (image 2+3) output", interactive=False)
133
 
134
+ # --- Final merge ---
135
+ btn_all = gr.Button("Stitch All", elem_classes=["pill"])
136
+ vid_all = gr.Video(label="Final concatenated video", interactive=False)
137
 
138
+ # --- Wire buttons ---
139
  btn12.click(stitch_12, inputs=[prompt12, seed, img1, img2], outputs=[vid12])
140
  btn23.click(stitch_23, inputs=[prompt23, seed, img2, img3], outputs=[vid23])
141
  btn_all.click(stitch_all, inputs=[vid12, vid23], outputs=[vid_all])