import spaces
from functools import lru_cache
import gradio as gr
from gradio_toggle import Toggle
import torch
from huggingface_hub import snapshot_download
from transformers import CLIPProcessor, CLIPModel, pipeline
import random
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from xora.models.transformers.transformer3d import Transformer3DModel
from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
from xora.schedulers.rf import RectifiedFlowScheduler
from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
from transformers import T5EncoderModel, T5Tokenizer
from xora.utils.conditioning_method import ConditioningMethod
from pathlib import Path
import safetensors.torch
import json
import numpy as np
import cv2
from PIL import Image
import tempfile
import os
import gc
import csv
from datetime import datetime
from openai import OpenAI
# Initialize the Korean-to-English translator
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.backends.cuda.preferred_blas_library("cublas")  # preferred_blas_library is a function, not an attribute
torch.set_float32_matmul_precision("highest")
MAX_SEED = np.iinfo(np.int32).max

# Load Hugging Face token if needed
hf_token = os.getenv("HF_TOKEN")
openai_api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=openai_api_key)

system_prompt_t2v_path = "assets/system_prompt_t2v.txt"
with open(system_prompt_t2v_path, "r") as f:
    system_prompt_t2v = f.read()
# Set model download directory within Hugging Face Spaces
model_path = "asset"
commit_hash = "c7c8ad4c2ddba847b94e8bfaefbd30bd8669fafc"
if not os.path.exists(model_path):
    snapshot_download("Lightricks/LTX-Video", revision=commit_hash, local_dir=model_path, repo_type="model", token=hf_token)
# Global variables to load components
vae_dir = Path(model_path) / "vae"
unet_dir = Path(model_path) / "unet"
scheduler_dir = Path(model_path) / "scheduler"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path).to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path)
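
# Prompts may be entered in Korean; translate them to English before they reach the text encoder.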
def process_prompt(prompt):
    # Check whether the prompt contains Korean (Hangul) characters
    if any(ord('가') <= ord(char) <= ord('힣') for char in prompt):
        # Translate Korean to English
        translated = translator(prompt)[0]['translation_text']
        return translated
    return prompt
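
# Helper that returns a CLIP text embedding; defined here but not used in the generation flow below.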
def compute_clip_embedding(text=None):
    inputs = clip_processor(text=text, return_tensors="pt", padding=True).to(device)
    outputs = clip_model.get_text_features(**inputs)
    embedding = outputs.detach().cpu().numpy().flatten().tolist()
    return embedding
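
# Loaders for the LTX-Video checkpoint components: causal video VAE, 3D transformer, and rectified-flow scheduler.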
def load_vae(vae_dir):
    vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
    vae_config_path = vae_dir / "config.json"
    with open(vae_config_path, "r") as f:
        vae_config = json.load(f)
    vae = CausalVideoAutoencoder.from_config(vae_config)
    vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
    vae.load_state_dict(vae_state_dict)
    return vae.to(device).to(torch.bfloat16)

def load_unet(unet_dir):
    unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
    unet_config_path = unet_dir / "config.json"
    transformer_config = Transformer3DModel.load_config(unet_config_path)
    transformer = Transformer3DModel.from_config(transformer_config)
    unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
    transformer.load_state_dict(unet_state_dict, strict=True)
    return transformer.to(device).to(torch.bfloat16)

def load_scheduler(scheduler_dir):
    scheduler_config_path = scheduler_dir / "scheduler_config.json"
    scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
    return RectifiedFlowScheduler.from_config(scheduler_config)
# Preset options for resolution and frame configuration
preset_options = [
    {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
    {"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49},
    {"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57},
    {"label": "448x448, 100 frames", "width": 448, "height": 448, "num_frames": 100},
    {"label": "448x448, 200 frames", "width": 448, "height": 448, "num_frames": 200},
    {"label": "448x448, 300 frames", "width": 448, "height": 448, "num_frames": 300},
    {"label": "640x640, 80 frames", "width": 640, "height": 640, "num_frames": 80},
    {"label": "640x640, 120 frames", "width": 640, "height": 640, "num_frames": 120},
    {"label": "768x768, 64 frames", "width": 768, "height": 768, "num_frames": 64},
    {"label": "768x768, 90 frames", "width": 768, "height": 768, "num_frames": 90},
| {"label": "720x720, 64 frames", "width": 768, "height": 768, "num_frames": 64}, | |
| {"label": "720x720, 100 frames", "width": 768, "height": 768, "num_frames": 100}, | |
| {"label": "768x512, 97 frames", "width": 768, "height": 512, "num_frames": 97}, | |
| {"label": "512x512, 160 frames", "width": 512, "height": 512, "num_frames": 160}, | |
| {"label": "512x512, 200 frames", "width": 512, "height": 512, "num_frames": 200}, | |
| ] | |
def preset_changed(preset):
    # Return updates for the three hidden sliders (height, width, num_frames):
    # fill them from the selected preset and hide them, or reveal them for "Custom".
    if preset != "Custom":
        selected = next(item for item in preset_options if item["label"] == preset)
        return (
            gr.update(value=selected["height"], visible=False),
            gr.update(value=selected["width"], visible=False),
            gr.update(value=selected["num_frames"], visible=False),
        )
    else:
        return (
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=True),
        )
# Load models
vae = load_vae(vae_dir)
unet = load_unet(unet_dir)
scheduler = load_scheduler(scheduler_dir)
patchifier = SymmetricPatchifier(patch_size=1)
text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder").to(device)
tokenizer = T5Tokenizer.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer")
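
# Assemble the LTX-Video (Xora) text-to-video pipeline from the loaded components.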
pipeline = XoraVideoPipeline(
    transformer=unet,
    patchifier=patchifier,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    scheduler=scheduler,
    vae=vae,
).to(device)
def enhance_prompt_if_enabled(prompt, enhance_toggle):
    if not enhance_toggle:
        print("Enhance toggle is off, Prompt: ", prompt)
        return prompt

    messages = [
        {"role": "system", "content": system_prompt_t2v},
        {"role": "user", "content": prompt},
    ]

    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            max_tokens=200,
        )
        print("Enhanced Prompt: ", response.choices[0].message.content.strip())
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Error: {e}")
        return prompt
# Assumed ZeroGPU decorator: `import spaces` and the "_90" suffix suggest a 90-second GPU allocation was intended.
@spaces.GPU(duration=90)
def generate_video_from_text_90(
    prompt="",
    enhance_prompt_toggle=False,
    negative_prompt="",
    frame_rate=25,
    seed=None,  # chosen at call time below; a random default would be frozen at definition time
    num_inference_steps=30,
    guidance_scale=3.2,
    height=768,
    width=768,
    num_frames=60,
    progress=gr.Progress(),
):
    # Preprocess the prompts (translate Korean to English if needed)
    prompt = process_prompt(prompt)
    negative_prompt = process_prompt(negative_prompt)

    if len(prompt.strip()) < 50:
        raise gr.Error(
            "Prompt must be at least 50 characters long. Please provide more details for the best results.",
            duration=5,
        )

    prompt = enhance_prompt_if_enabled(prompt, enhance_prompt_toggle)

    if seed is None:
        seed = random.randint(0, MAX_SEED)

    sample = {
        "prompt": prompt,
        "prompt_attention_mask": None,
        "negative_prompt": negative_prompt,
        "negative_prompt_attention_mask": None,
        "media_items": None,
    }

    generator = torch.Generator(device=device).manual_seed(seed)

    def gradio_progress_callback(self, step, timestep, kwargs):
        progress((step + 1) / num_inference_steps)
    try:
        with torch.no_grad():
            images = pipeline(
                num_inference_steps=num_inference_steps,
                num_images_per_prompt=1,
                guidance_scale=guidance_scale,
                generator=generator,
                output_type="pt",
                height=height,
                width=width,
                num_frames=num_frames,
                frame_rate=frame_rate,
                **sample,
                is_video=True,
                vae_per_channel_normalize=True,
                conditioning_method=ConditioningMethod.UNCONDITIONAL,
                mixed_precision=True,
                callback_on_step_end=gradio_progress_callback,
            ).images
    except Exception as e:
        raise gr.Error(
            f"An error occurred while generating the video. Please try again. Error: {e}",
            duration=5,
        )
    finally:
        torch.cuda.empty_cache()
        gc.collect()
    # Write the frames to a temporary MP4 file
    output_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name

    video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
    video_np = (video_np * 255).astype(np.uint8)
    height, width = video_np.shape[1:3]

    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (width, height))
    for frame in video_np[..., ::-1]:  # flip channels: RGB -> BGR for OpenCV
        out.write(frame)
    out.release()

    del images
    del video_np
    torch.cuda.empty_cache()

    return output_path
def create_advanced_options():
    with gr.Accordion("Step 4: Advanced Options (Optional)", open=False):
        seed = gr.Slider(label="4.1 Seed", minimum=0, maximum=1000000, step=1, value=646373)
        inference_steps = gr.Slider(label="4.2 Inference Steps", minimum=5, maximum=150, step=5, value=40)
        guidance_scale = gr.Slider(label="4.3 Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=4.2)
        height_slider = gr.Slider(
            label="4.4 Height",
            minimum=256,
            maximum=1024,
            step=64,
            value=768,
            visible=False,
        )
        width_slider = gr.Slider(
            label="4.5 Width",
            minimum=256,
            maximum=1024,
            step=64,
            value=768,
            visible=False,
        )
        num_frames_slider = gr.Slider(
            label="4.6 Number of Frames",
            minimum=1,
            maximum=500,
            step=1,
            value=60,
            visible=False,
        )
        return [
            seed,
            inference_steps,
            guidance_scale,
            height_slider,
            width_slider,
            num_frames_slider,
        ]
css = """
footer {
    visibility: hidden;
}

/* Constrain the size of the video output container */
.video-output-container {
    max-width: 50%;
    margin-left: auto;
    margin-right: auto;
}

/* Constrain the size of the video player */
.video-player {
    width: 100%;
    max-height: 50vh;
    object-fit: contain;
}
"""
with gr.Blocks(theme="soft", css=css) as iface:
    with gr.Row():
        # Input section (left)
        with gr.Column(scale=1):
            txt2vid_prompt = gr.Textbox(
                label="Step 1: Enter Your Prompt (Korean or English)",
                placeholder="Describe the video you want to create (at least 50 characters)...",
                value="A sleek vintage classic car is driving along a Hawaiian coastal road, seen from a low-angle front bumper camera view, with the ocean waves and palm trees rolling by in the background.",
                lines=5,
            )
            txt2vid_enhance_toggle = Toggle(
                label="Enhance Prompt",
                value=False,
                interactive=True,
            )
            txt2vid_negative_prompt = gr.Textbox(
                label="Step 2: Enter Negative Prompt",
                placeholder="Describe the elements you do not want in the video...",
                value="low quality, worst quality, deformed, distorted, damaged, motion blur, motion artifacts, fused fingers, incorrect anatomy, strange hands, ugly",
                lines=2,
            )
            txt2vid_preset = gr.Dropdown(
                choices=[p["label"] for p in preset_options] + ["Custom"],
                value="512x512, 160 frames",
                label="Step 3.1: Choose Resolution Preset",
            )
            txt2vid_frame_rate = gr.Slider(
                label="Step 3.2: Frame Rate",
                minimum=6,
                maximum=60,
                step=1,
                value=20,
            )
            txt2vid_advanced = create_advanced_options()
            txt2vid_generate = gr.Button(
                "Step 5: Generate Video",
                variant="primary",
                size="lg",
            )

        # Output section (right)
        with gr.Column(scale=1):
            txt2vid_output = gr.Video(
                label="Generated Output",
                elem_classes=["video-output-container", "video-player"],
            )
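
    # Update the hidden height/width/frame sliders whenever the resolution preset changes.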
    txt2vid_preset.change(
        fn=preset_changed,
        inputs=[txt2vid_preset],
        outputs=txt2vid_advanced[3:],
    )

    txt2vid_generate.click(
        fn=generate_video_from_text_90,
        inputs=[
            txt2vid_prompt,
            txt2vid_enhance_toggle,
            txt2vid_negative_prompt,
            txt2vid_frame_rate,
            *txt2vid_advanced,
        ],
        outputs=txt2vid_output,
        concurrency_limit=1,
        concurrency_id="generate_video",
        queue=True,
    )

iface.queue(max_size=64, default_concurrency_limit=1, api_open=False).launch(share=True, show_api=False)