Spaces:
Sleeping
Sleeping
| """ | |
| Hugging Face Space for STARFlow | |
| Text-to-Image and Text-to-Video Generation | |
| This app allows you to run STARFlow inference on Hugging Face GPU infrastructure. | |
| """ | |
| import os | |
| import gradio as gr | |
| import torch | |
| import subprocess | |
| import pathlib | |
| from pathlib import Path | |
# Pin OpenMP to a single thread; this was added to silence an OpenMP warning.
os.environ['OMP_NUM_THREADS'] = '1'

# huggingface_hub is optional: it is only needed to download checkpoints
# from the Model Hub. Record whether it could be imported.
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    HF_HUB_AVAILABLE = False
    print("⚠️ huggingface_hub not available. Install with: pip install huggingface_hub")
else:
    HF_HUB_AVAILABLE = True

# True when the SPACE_ID environment variable is set (running on HF Spaces).
HF_SPACE = "SPACE_ID" in os.environ

# Default checkpoint locations, used when the files were uploaded to the
# Space's own file tree.
DEFAULT_IMAGE_CHECKPOINT = "ckpts/starflow_3B_t2i_256x256.pth"
DEFAULT_VIDEO_CHECKPOINT = "ckpts/starflow-v_7B_t2v_caus_480p_v3.pth"

# Model Hub repositories holding the checkpoints, in "username/repo-name"
# form. Point these at your own repos if you host the checkpoints elsewhere.
IMAGE_CHECKPOINT_REPO = "GlobalStudio/starflow-3b-checkpoint"
VIDEO_CHECKPOINT_REPO = "GlobalStudio/starflow-v-7b-checkpoint"
def get_checkpoint_path(checkpoint_file, default_local_path, repo_id=None, filename=None):
    """Resolve a checkpoint path, downloading from the Model Hub if needed.

    Lookup order:
      1. ``checkpoint_file`` - a user upload (object with ``.name``) or a path.
      2. ``default_local_path`` - used if it exists on disk.
      3. ``repo_id`` / ``filename`` - a previously cached download, otherwise
         a fresh ``hf_hub_download`` (only when huggingface_hub is importable).

    Args:
        checkpoint_file: Uploaded file object, path-like value, or None.
        default_local_path: Path checked on the local filesystem.
        repo_id: Model Hub repository id ("username/repo-name"), optional.
        filename: Name of the checkpoint file inside ``repo_id``, optional.

    Returns:
        The checkpoint path as a ``str`` on success. On failure, a tuple
        ``(None, error_message)``; callers tell the two apart with an
        ``isinstance(result, tuple)`` check.
    """
    import glob
    import time

    # A user-supplied file always wins.
    if checkpoint_file is not None:
        if hasattr(checkpoint_file, 'name'):
            return checkpoint_file.name
        return str(checkpoint_file)

    # Checkpoint shipped alongside the app.
    if os.path.exists(default_local_path):
        return default_local_path

    # Fall back to the Model Hub when it is configured and importable.
    if repo_id and filename and HF_HUB_AVAILABLE:
        try:
            # Use /workspace if available (persistent), otherwise /tmp.
            cache_dir = "/workspace/checkpoints" if os.path.exists("/workspace") else "/tmp/checkpoints"
            os.makedirs(cache_dir, exist_ok=True)

            # Look for an earlier download in the hub cache layout
            # (models--<user>--<repo>/snapshots/<revision>/<filename>).
            pattern = os.path.join(
                cache_dir,
                "models--" + repo_id.replace("/", "--"),
                "snapshots",
                "*",
                filename,
            )
            # glob() also returns dangling symlinks, so keep only entries
            # that resolve to a real file.
            cached = [path for path in glob.glob(pattern) if os.path.exists(path)]
            if cached:
                # Glob order is arbitrary; with several revisions cached,
                # prefer the most recently written one.
                checkpoint_path = max(cached, key=os.path.getmtime)
                print(f"✅ Using cached checkpoint: {checkpoint_path}")
                return checkpoint_path

            start_time = time.time()
            print(f"📥 Downloading checkpoint from {repo_id}...")
            print("File size: ~15.5 GB - This may take 10-30 minutes")
            # `resume_download` is intentionally not passed: huggingface_hub
            # deprecated it and resumes interrupted downloads on its own.
            checkpoint_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                cache_dir=cache_dir,
                local_files_only=False,
            )
            elapsed = time.time() - start_time
            print(f"✅ Download completed in {elapsed/60:.1f} minutes")
            print(f"✅ Checkpoint at: {checkpoint_path}")
            return checkpoint_path
        except Exception as e:
            # UI boundary: turn any failure into a message for the user
            # instead of crashing the request handler.
            error_detail = str(e)
            if "404" in error_detail or "not found" in error_detail.lower():
                return None, f"Checkpoint not found in Model Hub.\n\nPlease verify:\n1. Repo exists: https://huggingface.co/{repo_id}\n2. File exists: {filename}\n3. Repo is Public (not Private)\n\nError: {error_detail}"
            return None, f"Error downloading checkpoint: {error_detail}\n\nThis may take 10-30 minutes for a 14GB file. Please wait or check your internet connection."

    # No checkpoint found anywhere.
    return None, "Checkpoint not found. Please upload a checkpoint file or configure Model Hub repository."
# Report the GPU environment once at startup (CUDA is expected to be
# available on HF Spaces when GPU hardware is selected).
if not torch.cuda.is_available():
    print("⚠️ CUDA not available. Make sure GPU hardware is selected in Space settings.")
else:
    print(
        f"✅ CUDA available! Device: {torch.cuda.get_device_name(0)}",
        f" CUDA Version: {torch.version.cuda}",
        f" PyTorch Version: {torch.__version__}",
        sep="\n",
    )
def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
    """Generate an image from a text prompt by running ``sample.py``.

    Args:
        prompt: Caption passed to ``sample.py`` via ``--caption``.
        aspect_ratio: Value for ``--aspect_ratio`` (e.g. "1:1").
        cfg: Classifier-free guidance scale, passed via ``--cfg``.
        seed: Random seed, passed via ``--seed``.
        checkpoint_file: Optional uploaded checkpoint; None falls back to the
            default local path and then the Model Hub.
        config_path: Path to the model config file.

    Returns:
        ``(image_path, status_message)``; ``image_path`` is None on failure.
    """
    import sys

    # Get checkpoint path (from upload, local, or Model Hub).
    result = get_checkpoint_path(
        checkpoint_file,
        DEFAULT_IMAGE_CHECKPOINT,
        IMAGE_CHECKPOINT_REPO,
        "starflow_3B_t2i_256x256.pth"
    )
    # A (None, message) tuple signals that no checkpoint could be resolved.
    if isinstance(result, tuple) and result[0] is None:
        return None, result[1]
    checkpoint_path = result

    status_msg = f"Using checkpoint: {checkpoint_path}\n"
    if not os.path.exists(checkpoint_path):
        return None, f"Error: Checkpoint file not found at {checkpoint_path}.\n\nPlease verify:\n1. Model Hub repo exists: {IMAGE_CHECKPOINT_REPO}\n2. File name matches: starflow_3B_t2i_256x256.pth\n3. Repo is Public (not Private)"
    if not config_path or not os.path.exists(config_path):
        return None, "Error: Config file not found. Please ensure config file exists."

    status_msg += "Starting image generation...\n"
    status_msg += "This may take 1-3 minutes for first run (model loading).\n"

    try:
        output_dir = Path("outputs")
        output_dir.mkdir(exist_ok=True)

        def _scan_outputs():
            # Map every image currently under output_dir to its mtime.
            return {
                path: path.stat().st_mtime
                for pattern in ("**/*.png", "**/*.jpg")
                for path in output_dir.glob(pattern)
            }

        # Snapshot existing files so that only images written (or rewritten)
        # by this run are considered afterwards.
        before = _scan_outputs()

        # Launch sample.py with the interpreter running this app, so it sees
        # the same installed packages; --logdir keeps outputs under output_dir.
        cmd = [
            sys.executable or "python", "sample.py",
            "--model_config_path", config_path,
            "--checkpoint_path", checkpoint_path,
            "--caption", prompt,
            "--sample_batch_size", "1",
            "--cfg", str(cfg),
            "--aspect_ratio", aspect_ratio,
            "--seed", str(seed),
            "--save_folder", "1",
            "--finetuned_vae", "none",
            "--jacobi", "1",
            "--jacobi_th", "0.001",
            "--jacobi_block_size", "16",
            "--logdir", str(output_dir),
        ]
        status_msg += "Running generation...\n"
        status_msg += "Note: First run includes checkpoint download (~10-20 min) and model loading (~2-5 min).\n"

        # Run with timeout (45 minutes max).
        result = subprocess.run(cmd, capture_output=True, text=True, cwd=os.getcwd(), timeout=2700)
        if result.returncode != 0:
            error_msg = f"Error during generation:\n{result.stderr}\n\nStdout:\n{result.stdout}"
            return None, error_msg

        status_msg += "Generation complete. Looking for output...\n"

        # Keep only files that are new or whose mtime changed during the run;
        # otherwise an image from an earlier request could be returned.
        after = _scan_outputs()
        fresh = [path for path, mtime in after.items() if before.get(path) != mtime]
        if fresh:
            latest_file = max(fresh, key=after.get)
            return str(latest_file), status_msg + "✅ Success! Image generated."
        return None, status_msg + f"Error: Generated image not found in {output_dir}. Check stdout:\n{result.stdout}"
    except subprocess.TimeoutExpired:
        return None, status_msg + "Error: Generation timed out after 45 minutes."
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return None, f"Error: {str(e)}"
def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
    """Generate a video from a text prompt by running ``sample.py``.

    Args:
        prompt: Caption passed to ``sample.py`` via ``--caption``.
        aspect_ratio: Value for ``--aspect_ratio`` (e.g. "16:9").
        cfg: Classifier-free guidance scale, passed via ``--cfg``.
        seed: Random seed, passed via ``--seed``.
        target_length: Number of frames, passed via ``--target_length``.
        checkpoint_file: Optional uploaded checkpoint; None falls back to the
            default local path and then the Model Hub.
        config_path: Path to the model config file.
        input_image: Optional conditioning image (upload object or path);
            when absent the literal "none" is passed to ``--input_image``.

    Returns:
        ``(video_path, status_message)``; ``video_path`` is None on failure.
    """
    import sys

    # Get checkpoint path (from upload, local, or Model Hub).
    result = get_checkpoint_path(
        checkpoint_file,
        DEFAULT_VIDEO_CHECKPOINT,
        VIDEO_CHECKPOINT_REPO,
        "starflow-v_7B_t2v_caus_480p_v3.pth"
    )
    # A (None, message) tuple signals that no checkpoint could be resolved.
    if isinstance(result, tuple) and result[0] is None:
        return None, result[1]
    checkpoint_path = result

    if not os.path.exists(checkpoint_path):
        return None, f"Error: Checkpoint file not found at {checkpoint_path}."
    if not config_path or not os.path.exists(config_path):
        return None, "Error: Config file not found. Please ensure config file exists."

    # Uploads may arrive as an object exposing .name or as a plain path.
    input_image_path = None
    if input_image is not None:
        if hasattr(input_image, 'name'):
            input_image_path = input_image.name
        else:
            input_image_path = str(input_image)

    try:
        output_dir = Path("outputs")
        output_dir.mkdir(exist_ok=True)

        def _scan_outputs():
            # Map every video currently under output_dir to its mtime.
            return {
                path: path.stat().st_mtime
                for pattern in ("**/*.mp4", "**/*.gif")
                for path in output_dir.glob(pattern)
            }

        # Snapshot existing files so that only videos written (or rewritten)
        # by this run are considered afterwards.
        before = _scan_outputs()

        # Launch sample.py with the interpreter running this app, so it sees
        # the same installed packages.
        cmd = [
            sys.executable or "python", "sample.py",
            "--model_config_path", config_path,
            "--checkpoint_path", checkpoint_path,
            "--caption", prompt,
            "--sample_batch_size", "1",
            "--cfg", str(cfg),
            "--aspect_ratio", aspect_ratio,
            "--seed", str(seed),
            "--out_fps", "16",
            "--save_folder", "1",
            "--jacobi", "1",
            "--jacobi_th", "0.001",
            "--finetuned_vae", "none",
            "--disable_learnable_denoiser", "0",
            "--jacobi_block_size", "32",
            "--target_length", str(target_length),
            # Results are searched for under output_dir below, so tell
            # sample.py to write there (as generate_image does).
            "--logdir", str(output_dir),
        ]
        if input_image_path and os.path.exists(input_image_path):
            cmd.extend(["--input_image", input_image_path])
        else:
            cmd.extend(["--input_image", "none"])

        result = subprocess.run(cmd, capture_output=True, text=True, cwd=os.getcwd())
        if result.returncode != 0:
            return None, f"Error: {result.stderr}"

        # Keep only files that are new or whose mtime changed during the run;
        # otherwise a video from an earlier request could be returned.
        after = _scan_outputs()
        fresh = [path for path, mtime in after.items() if before.get(path) != mtime]
        if fresh:
            latest_file = max(fresh, key=after.get)
            return str(latest_file), "Success! Video generated."
        return None, "Error: Generated video not found."
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return None, f"Error: {str(e)}"
# Create Gradio interface: one tab for text-to-image and one for
# text-to-video. Each tab has an input column (prompt and sampling options)
# and an output column (result and status text).
with gr.Blocks(title="STARFlow - Text-to-Image & Video Generation") as demo:
    gr.Markdown("""
    # STARFlow: Scalable Transformer Auto-Regressive Flow
    Generate high-quality images and videos from text prompts using STARFlow models.
    **Checkpoints are automatically downloaded from Model Hub on first use.**
    """)
    with gr.Tabs():
        with gr.Tab("Text-to-Image"):
            with gr.Row():
                # Left column: generation inputs.
                with gr.Column():
                    image_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="a film still of a cat playing piano",
                        lines=3
                    )
                    # Checkpoint upload hidden - using Model Hub instead.
                    # The component is still wired into the click handler
                    # below as the `checkpoint_file` argument.
                    image_checkpoint = gr.File(
                        label="Model Checkpoint (.pth file) - Optional if already uploaded to Space",
                        file_types=[".pth"],
                        visible=False  # Hidden - using Model Hub
                    )
                    image_config = gr.Textbox(
                        label="Config Path",
                        value="configs/starflow_3B_t2i_256x256.yaml",
                        placeholder="configs/starflow_3B_t2i_256x256.yaml"
                    )
                    image_aspect = gr.Dropdown(
                        choices=["1:1", "2:3", "3:2", "16:9", "9:16", "4:5", "5:4"],
                        value="1:1",
                        label="Aspect Ratio"
                    )
                    image_cfg = gr.Slider(1.0, 10.0, value=3.6, step=0.1, label="CFG Scale")
                    image_seed = gr.Number(value=999, label="Seed", precision=0)
                    image_btn = gr.Button("Generate Image", variant="primary")
                # Right column: generated image and status/error text.
                with gr.Column():
                    image_output = gr.Image(label="Generated Image")
                    image_status = gr.Textbox(label="Status", interactive=False)
            # Input order must match generate_image's parameter order;
            # the two outputs receive its (path, message) return value.
            image_btn.click(
                fn=generate_image,
                inputs=[image_prompt, image_aspect, image_cfg, image_seed, image_checkpoint, image_config],
                outputs=[image_output, image_status],
                show_progress=True
            )
        with gr.Tab("Text-to-Video"):
            with gr.Row():
                # Left column: generation inputs.
                with gr.Column():
                    video_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="a corgi dog looks at the camera",
                        lines=3
                    )
                    # Checkpoint upload hidden - using Model Hub instead.
                    # The component is still wired into the click handler
                    # below as the `checkpoint_file` argument.
                    video_checkpoint = gr.File(
                        label="Model Checkpoint (.pth file) - Optional if already uploaded to Space",
                        file_types=[".pth"],
                        visible=False  # Hidden - using Model Hub
                    )
                    video_config = gr.Textbox(
                        label="Config Path",
                        value="configs/starflow-v_7B_t2v_caus_480p.yaml",
                        placeholder="configs/starflow-v_7B_t2v_caus_480p.yaml"
                    )
                    video_aspect = gr.Dropdown(
                        choices=["16:9", "1:1", "4:3"],
                        value="16:9",
                        label="Aspect Ratio"
                    )
                    video_cfg = gr.Slider(1.0, 10.0, value=3.5, step=0.1, label="CFG Scale")
                    video_seed = gr.Number(value=99, label="Seed", precision=0)
                    # Selectable frame counts are 81, 161, ..., 481 (step of 80).
                    video_length = gr.Slider(81, 481, value=81, step=80, label="Target Length (frames)")
                    video_input_image = gr.File(
                        label="Input Image (optional, for image-to-video)",
                        file_types=["image"]
                    )
                    video_btn = gr.Button("Generate Video", variant="primary")
                # Right column: generated video and status/error text.
                with gr.Column():
                    video_output = gr.Video(label="Generated Video")
                    video_status = gr.Textbox(label="Status", interactive=False)
            # Input order must match generate_video's parameter order;
            # the two outputs receive its (path, message) return value.
            video_btn.click(
                fn=generate_video,
                inputs=[video_prompt, video_aspect, video_cfg, video_seed, video_length,
                        video_checkpoint, video_config, video_input_image],
                outputs=[video_output, video_status],
                show_progress=True
            )
if __name__ == "__main__":
    # Password protection - users don't need HF accounts!
    # Credentials are read from the STARFLOW_USERNAME / STARFLOW_PASSWORD
    # environment variables (for example Space secrets) so the real password
    # does not have to live in source control. The previous hard-coded values
    # remain as a fallback so existing deployments keep working.
    # For multiple users, pass a list instead:
    #   auth=[("user1", "pass1"), ("user2", "pass2")]
    auth_user = os.environ.get("STARFLOW_USERNAME") or "starflow"
    auth_password = os.environ.get("STARFLOW_PASSWORD") or "im30"
    if not os.environ.get("STARFLOW_PASSWORD"):
        print("⚠️ STARFLOW_PASSWORD is not set; falling back to the default password. Set it to secure this app.")

    demo.launch(
        server_name="0.0.0.0",  # Listen on all interfaces
        server_port=7860,
        auth=(auth_user, auth_password),
        share=False  # Set to True if you want public Gradio link
    )