# Spaces: Running on Zero
"""
Hugging Face Space for STARFlow
Text-to-Image and Text-to-Video Generation
This app allows you to run STARFlow inference on Hugging Face GPU infrastructure.
"""
import os
import pathlib
import subprocess
import sys
from pathlib import Path

import gradio as gr
import torch
# Force a single OpenMP thread; this is what silences the OpenMP warning.
os.environ['OMP_NUM_THREADS'] = '1'

# huggingface_hub is optional: it is only used to download checkpoints.
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    HF_HUB_AVAILABLE = False
    print("⚠️ huggingface_hub not available. Install with: pip install huggingface_hub")
else:
    HF_HUB_AVAILABLE = True

# True when the SPACE_ID environment variable is set (used to detect
# that the app is running on Hugging Face Spaces).
HF_SPACE = "SPACE_ID" in os.environ

# Checkpoint locations used when the files were uploaded to the Space itself.
DEFAULT_IMAGE_CHECKPOINT = "ckpts/starflow_3B_t2i_256x256.pth"
DEFAULT_VIDEO_CHECKPOINT = "ckpts/starflow-v_7B_t2v_caus_480p_v3.pth"

# Hugging Face Model Hub repositories holding the checkpoints, written as
# "username/repo-name". Update these after creating your own Model Hub repos.
IMAGE_CHECKPOINT_REPO = "GlobalStudio/starflow-3b-checkpoint"
VIDEO_CHECKPOINT_REPO = "GlobalStudio/starflow-v-7b-checkpoint"
def get_checkpoint_path(checkpoint_file, default_local_path, repo_id=None, filename=None):
    """Resolve the checkpoint to use, downloading it from the Hub if needed.

    Sources are tried in this order:

    1. ``checkpoint_file`` (a user upload), if it is not ``None``.
    2. ``default_local_path``, if it is set and exists on disk.
    3. ``hf_hub_download(repo_id, filename)``, if both are given and
       ``huggingface_hub`` could be imported.

    Args:
        checkpoint_file: Uploaded file. Either an object exposing a ``name``
            attribute (its value is returned as-is) or anything convertible
            with ``str()``. ``None`` means "nothing uploaded".
        default_local_path: Local path to fall back to. ``None`` or an empty
            string skips this step.
        repo_id: Model Hub repository id ("username/repo-name"), optional.
        filename: File name inside ``repo_id``, optional.

    Returns:
        The checkpoint path on success. On failure, a ``(None, message)``
        tuple; callers distinguish the two cases with ``isinstance(..., tuple)``.
    """
    # A user upload always wins over every other source.
    if checkpoint_file is not None:
        if hasattr(checkpoint_file, 'name'):
            return checkpoint_file.name
        return str(checkpoint_file)

    # Guard against None/"" so os.path.exists() cannot raise TypeError.
    if default_local_path and os.path.exists(default_local_path):
        return default_local_path

    if repo_id and filename and HF_HUB_AVAILABLE:
        try:
            print(f"📥 Downloading checkpoint from {repo_id}...")
            # Prefer /workspace when it exists, otherwise fall back to /tmp.
            if os.path.exists("/workspace"):
                cache_dir = "/workspace/checkpoints"
            else:
                cache_dir = "/tmp/checkpoints"
            os.makedirs(cache_dir, exist_ok=True)
            checkpoint_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                cache_dir=cache_dir,
                local_files_only=False,
            )
            print(f"✅ Checkpoint downloaded to: {checkpoint_path}")
            return checkpoint_path
        except Exception as e:
            # Deliberately broad: any download problem is reported to the UI
            # as a message instead of crashing the request.
            return None, f"Error downloading checkpoint: {str(e)}"

    # Every source was exhausted.
    return None, "Checkpoint not found. Please upload a checkpoint file or configure Model Hub repository."
# Report GPU status once at import time. CUDA is expected to be available on
# HF Spaces when GPU hardware is selected.
if not torch.cuda.is_available():
    print("⚠️ CUDA not available. Make sure GPU hardware is selected in Space settings.")
else:
    device_name = torch.cuda.get_device_name(0)
    print(f"✅ CUDA available! Device: {device_name}")
    print(f" CUDA Version: {torch.version.cuda}")
    print(f" PyTorch Version: {torch.__version__}")
def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
    """Generate an image from a text prompt by running ``sample.py``.

    Args:
        prompt: Text passed to ``sample.py`` as ``--caption``.
        aspect_ratio: Value passed as ``--aspect_ratio`` (e.g. "1:1").
        cfg: Guidance scale, passed as ``--cfg``.
        seed: Random seed, passed as ``--seed``.
        checkpoint_file: Optional uploaded checkpoint; see
            ``get_checkpoint_path`` for the fallback order.
        config_path: Path to the model config file; must exist.

    Returns:
        ``(image_path, status_message)``. ``image_path`` is ``None`` whenever
        anything fails, and the message then starts with "Error".
    """
    # Resolve the checkpoint (upload, local file, or Model Hub download).
    resolved = get_checkpoint_path(
        checkpoint_file,
        DEFAULT_IMAGE_CHECKPOINT,
        IMAGE_CHECKPOINT_REPO,
        "starflow_3B_t2i_256x256.pth",
    )
    # get_checkpoint_path reports failure as a (None, message) tuple.
    if isinstance(resolved, tuple) and resolved[0] is None:
        return None, resolved[1]
    checkpoint_path = resolved

    if not os.path.exists(checkpoint_path):
        return None, f"Error: Checkpoint file not found at {checkpoint_path}."
    if not config_path or not os.path.exists(config_path):
        return None, "Error: Config file not found. Please ensure config file exists."

    patterns = ("**/*.png", "**/*.jpg")
    try:
        output_dir = Path("outputs")
        output_dir.mkdir(exist_ok=True)

        # Remember what already exists so a leftover image from an earlier
        # run is never reported as the result of this one.
        seen = {
            path: path.stat().st_mtime_ns
            for pattern in patterns
            for path in output_dir.glob(pattern)
        }

        cmd = [
            # Run sample.py with the interpreter that is running this app.
            sys.executable or "python", "sample.py",
            "--model_config_path", config_path,
            "--checkpoint_path", checkpoint_path,
            "--caption", prompt,
            "--sample_batch_size", "1",
            "--cfg", str(cfg),
            "--aspect_ratio", aspect_ratio,
            "--seed", str(seed),
            "--save_folder", "1",
            "--finetuned_vae", "none",
            "--jacobi", "1",
            "--jacobi_th", "0.001",
            "--jacobi_block_size", "16",
        ]
        completed = subprocess.run(cmd, capture_output=True, text=True)
        if completed.returncode != 0:
            # stderr can be empty; fall back so the user still sees a reason.
            details = (
                completed.stderr.strip()
                or completed.stdout.strip()
                or f"sample.py exited with code {completed.returncode}"
            )
            return None, f"Error: {details}"

        # NOTE(review): the search is rooted at ./outputs while sample.py is
        # said to save under logdir/model_name/... -- confirm they coincide.
        fresh = [
            path
            for pattern in patterns
            for path in output_dir.glob(pattern)
            if seen.get(path) != path.stat().st_mtime_ns
        ]
        if not fresh:
            return None, "Error: Generated image not found."
        newest = max(fresh, key=lambda p: p.stat().st_mtime_ns)
        return str(newest), "Success! Image generated."
    except Exception as e:
        # UI boundary: surface any failure as a status message.
        return None, f"Error: {str(e)}"
def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
    """Generate a video from a text prompt by running ``sample.py``.

    Args:
        prompt: Text passed to ``sample.py`` as ``--caption``.
        aspect_ratio: Value passed as ``--aspect_ratio`` (e.g. "16:9").
        cfg: Guidance scale, passed as ``--cfg``.
        seed: Random seed, passed as ``--seed``.
        target_length: Passed as ``--target_length``.
        checkpoint_file: Optional uploaded checkpoint; see
            ``get_checkpoint_path`` for the fallback order.
        config_path: Path to the model config file; must exist.
        input_image: Optional conditioning image. Either an object with a
            ``name`` attribute or a path-like value. When it is missing or
            not on disk, ``--input_image none`` is passed instead.

    Returns:
        ``(video_path, status_message)``. ``video_path`` is ``None`` whenever
        anything fails, and the message then starts with "Error".
    """
    # Resolve the checkpoint (upload, local file, or Model Hub download).
    resolved = get_checkpoint_path(
        checkpoint_file,
        DEFAULT_VIDEO_CHECKPOINT,
        VIDEO_CHECKPOINT_REPO,
        "starflow-v_7B_t2v_caus_480p_v3.pth",
    )
    # get_checkpoint_path reports failure as a (None, message) tuple.
    if isinstance(resolved, tuple) and resolved[0] is None:
        return None, resolved[1]
    checkpoint_path = resolved

    if not os.path.exists(checkpoint_path):
        return None, f"Error: Checkpoint file not found at {checkpoint_path}."
    if not config_path or not os.path.exists(config_path):
        return None, "Error: Config file not found. Please ensure config file exists."

    # Normalise the optional input image to a plain path string.
    input_image_path = None
    if input_image is not None:
        if hasattr(input_image, 'name'):
            input_image_path = input_image.name
        else:
            input_image_path = str(input_image)

    patterns = ("**/*.mp4", "**/*.gif")
    try:
        output_dir = Path("outputs")
        output_dir.mkdir(exist_ok=True)

        # Remember what already exists so a leftover video from an earlier
        # run is never reported as the result of this one.
        seen = {
            path: path.stat().st_mtime_ns
            for pattern in patterns
            for path in output_dir.glob(pattern)
        }

        cmd = [
            # Run sample.py with the interpreter that is running this app.
            sys.executable or "python", "sample.py",
            "--model_config_path", config_path,
            "--checkpoint_path", checkpoint_path,
            "--caption", prompt,
            "--sample_batch_size", "1",
            "--cfg", str(cfg),
            "--aspect_ratio", aspect_ratio,
            "--seed", str(seed),
            "--out_fps", "16",
            "--save_folder", "1",
            "--jacobi", "1",
            "--jacobi_th", "0.001",
            "--finetuned_vae", "none",
            "--disable_learnable_denoiser", "0",
            "--jacobi_block_size", "32",
            "--target_length", str(target_length),
        ]
        if input_image_path and os.path.exists(input_image_path):
            cmd.extend(["--input_image", input_image_path])
        else:
            cmd.extend(["--input_image", "none"])

        completed = subprocess.run(cmd, capture_output=True, text=True)
        if completed.returncode != 0:
            # stderr can be empty; fall back so the user still sees a reason.
            details = (
                completed.stderr.strip()
                or completed.stdout.strip()
                or f"sample.py exited with code {completed.returncode}"
            )
            return None, f"Error: {details}"

        # NOTE(review): the search is rooted at ./outputs -- confirm that
        # sample.py actually writes its results there.
        fresh = [
            path
            for pattern in patterns
            for path in output_dir.glob(pattern)
            if seen.get(path) != path.stat().st_mtime_ns
        ]
        if not fresh:
            return None, "Error: Generated video not found."
        newest = max(fresh, key=lambda p: p.stat().st_mtime_ns)
        return str(newest), "Success! Video generated."
    except Exception as e:
        # UI boundary: surface any failure as a status message.
        return None, f"Error: {str(e)}"
# Build the Gradio UI: one tab per generation mode. Each tab has a column of
# inputs on the left and the result plus a status box on the right.
with gr.Blocks(title="STARFlow - Text-to-Image & Video Generation") as demo:
    gr.Markdown("""
    # STARFlow: Scalable Transformer Auto-Regressive Flow
    Generate high-quality images and videos from text prompts using STARFlow models.
    **Note**: You'll need to upload model checkpoints. Check the README for model download links.
    """)

    with gr.Tabs():
        # ---- Text-to-Image -------------------------------------------------
        with gr.Tab("Text-to-Image"):
            with gr.Row():
                with gr.Column():
                    image_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="a film still of a cat playing piano",
                        lines=3,
                    )
                    image_checkpoint = gr.File(
                        label="Model Checkpoint (.pth file) - Optional if already uploaded to Space",
                        file_types=[".pth"],
                    )
                    image_config = gr.Textbox(
                        label="Config Path",
                        value="configs/starflow_3B_t2i_256x256.yaml",
                        placeholder="configs/starflow_3B_t2i_256x256.yaml",
                    )
                    image_aspect = gr.Dropdown(
                        label="Aspect Ratio",
                        choices=["1:1", "2:3", "3:2", "16:9", "9:16", "4:5", "5:4"],
                        value="1:1",
                    )
                    image_cfg = gr.Slider(1.0, 10.0, value=3.6, step=0.1, label="CFG Scale")
                    image_seed = gr.Number(value=999, label="Seed", precision=0)
                    image_btn = gr.Button("Generate Image", variant="primary")
                with gr.Column():
                    image_output = gr.Image(label="Generated Image")
                    image_status = gr.Textbox(label="Status", interactive=False)

            # Input order must match generate_image's positional parameters.
            image_inputs = [
                image_prompt,
                image_aspect,
                image_cfg,
                image_seed,
                image_checkpoint,
                image_config,
            ]
            image_btn.click(
                fn=generate_image,
                inputs=image_inputs,
                outputs=[image_output, image_status],
                show_progress=True,
            )

        # ---- Text-to-Video -------------------------------------------------
        with gr.Tab("Text-to-Video"):
            with gr.Row():
                with gr.Column():
                    video_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="a corgi dog looks at the camera",
                        lines=3,
                    )
                    video_checkpoint = gr.File(
                        label="Model Checkpoint (.pth file) - Optional if already uploaded to Space",
                        file_types=[".pth"],
                    )
                    video_config = gr.Textbox(
                        label="Config Path",
                        value="configs/starflow-v_7B_t2v_caus_480p.yaml",
                        placeholder="configs/starflow-v_7B_t2v_caus_480p.yaml",
                    )
                    video_aspect = gr.Dropdown(
                        label="Aspect Ratio",
                        choices=["16:9", "1:1", "4:3"],
                        value="16:9",
                    )
                    video_cfg = gr.Slider(1.0, 10.0, value=3.5, step=0.1, label="CFG Scale")
                    video_seed = gr.Number(value=99, label="Seed", precision=0)
                    video_length = gr.Slider(81, 481, value=81, step=80, label="Target Length (frames)")
                    video_input_image = gr.File(
                        label="Input Image (optional, for image-to-video)",
                        file_types=["image"],
                    )
                    video_btn = gr.Button("Generate Video", variant="primary")
                with gr.Column():
                    video_output = gr.Video(label="Generated Video")
                    video_status = gr.Textbox(label="Status", interactive=False)

            # Input order must match generate_video's positional parameters.
            video_inputs = [
                video_prompt,
                video_aspect,
                video_cfg,
                video_seed,
                video_length,
                video_checkpoint,
                video_config,
                video_input_image,
            ]
            video_btn.click(
                fn=generate_video,
                inputs=video_inputs,
                outputs=[video_output, video_status],
                show_progress=True,
            )
if __name__ == "__main__":
    # Password protection - users don't need HF accounts!
    # Credentials are read from the environment (for example Space secrets)
    # so that they do not have to live in source control:
    #   STARFLOW_AUTH_USERNAME  (defaults to "starflow")
    #   STARFLOW_AUTH_PASSWORD  (falls back to the legacy built-in password)
    # For multiple users, pass a list instead:
    #   auth=[("user1", "pass1"), ("user2", "pass2")]
    auth_username = os.environ.get("STARFLOW_AUTH_USERNAME") or "starflow"
    auth_password = os.environ.get("STARFLOW_AUTH_PASSWORD")
    if not auth_password:
        # Keep existing deployments working, but make the weak default loud.
        print("⚠️ STARFLOW_AUTH_PASSWORD is not set; using the built-in default password. "
              "Set STARFLOW_AUTH_PASSWORD to change it.")
        auth_password = "im30"

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        auth=(auth_username, auth_password),
        share=False,  # Set to True if you want public Gradio link
    )