"""
Hugging Face Space for STARFlow Text-to-Image and Text-to-Video Generation

This app allows you to run STARFlow inference on Hugging Face GPU infrastructure.
Generation itself is delegated to ``sample.py`` (run as a subprocess); this module
resolves the checkpoint, builds the sample.py command line and renders the Gradio UI.
"""

# Fix OpenMP warning - MUST be set BEFORE importing torch
import os
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'

import glob
import pathlib
import subprocess
import sys
import time
import warnings
from pathlib import Path

import gradio as gr
import torch

# Suppress harmless warnings
warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.distributed.reduce_op.*")

# Try to import huggingface_hub for downloading checkpoints
try:
    from huggingface_hub import hf_hub_download
    HF_HUB_AVAILABLE = True
except ImportError:
    HF_HUB_AVAILABLE = False
    print("⚠️ huggingface_hub not available. Install with: pip install huggingface_hub")

# Check if running on Hugging Face Spaces
HF_SPACE = os.environ.get("SPACE_ID") is not None

# Default checkpoint paths (if uploaded to Space Files)
DEFAULT_IMAGE_CHECKPOINT = "ckpts/starflow_3B_t2i_256x256.pth"
DEFAULT_VIDEO_CHECKPOINT = "ckpts/starflow-v_7B_t2v_caus_480p_v3.pth"

# Model Hub repositories (if using Hugging Face Model Hub)
# Set these to your Model Hub repo IDs if you upload checkpoints there
# Format: "username/repo-name"
IMAGE_CHECKPOINT_REPO = "GlobalStudio/starflow-3b-checkpoint"  # Update this after creating Model Hub repo
VIDEO_CHECKPOINT_REPO = "GlobalStudio/starflow-v-7b-checkpoint"  # Update this after creating Model Hub repo

# Checkpoint file names inside the Model Hub repos above.
IMAGE_CHECKPOINT_FILENAME = "starflow_3B_t2i_256x256.pth"
VIDEO_CHECKPOINT_FILENAME = "starflow-v_7B_t2v_caus_480p_v3.pth"

# Both modes pass this directory to sample.py as --logdir and search it for results.
OUTPUT_DIR = Path("outputs")

# Upper bound for a single sample.py run (45 minutes).
GENERATION_TIMEOUT_SECONDS = 2700


def get_checkpoint_path(checkpoint_file, default_local_path, repo_id=None, filename=None):
    """Get checkpoint path, downloading from Hub if needed.

    Candidates are tried in this order:

    1. ``checkpoint_file`` -- an uploaded file object (its ``.name`` is used)
       or a path string -- when it points at an existing file;
    2. ``default_local_path`` when it exists;
    3. ``filename`` from the Hub repo ``repo_id``, reusing a copy already
       present in the local cache directory.

    Args:
        checkpoint_file: Uploaded file object, path string, or ``None``.
        default_local_path: Local checkpoint path to fall back on.
        repo_id: Model Hub repo id ("username/repo-name"); ``None`` skips the Hub.
        filename: File name to fetch from ``repo_id``.

    Returns:
        The checkpoint path (``str``) on success, otherwise a
        ``(None, error_message)`` tuple.
    """
    if checkpoint_file is not None:
        candidate = checkpoint_file.name if hasattr(checkpoint_file, 'name') else str(checkpoint_file)
        if candidate and os.path.exists(candidate):
            return candidate
        # The hidden UI field always carries ``default_local_path``; when that
        # file is absent we must continue to the Hub download rather than hand
        # back a path that does not exist. Any other explicit path is returned
        # unchanged so the caller reports it missing instead of silently
        # switching to a different checkpoint.
        if candidate and candidate != default_local_path:
            return candidate

    # Try local path first
    if os.path.exists(default_local_path):
        return default_local_path

    if not (repo_id and filename and HF_HUB_AVAILABLE):
        # No checkpoint found
        return None, "Checkpoint not found. Please upload a checkpoint file or configure Model Hub repository."

    try:
        # Use /workspace if available (persistent), otherwise /tmp
        cache_dir = "/workspace/checkpoints" if os.path.exists("/workspace") else "/tmp/checkpoints"
        os.makedirs(cache_dir, exist_ok=True)

        # Check if already downloaded (cache layout: models--<owner>--<name>/snapshots/<rev>/<file>)
        cached_pattern = os.path.join(cache_dir, "models--" + repo_id.replace("/", "--"), "snapshots", "*", filename)
        existing = glob.glob(cached_pattern)
        if existing:
            checkpoint_path = existing[0]
            print(f"✅ Using cached checkpoint: {checkpoint_path}")
            return checkpoint_path

        start_time = time.time()
        print(f"📄 Downloading checkpoint from {repo_id}...")
        print("File size: ~15.5 GB - This may take 10-30 minutes")
        # ``resume_download`` is deliberately not passed: it is deprecated in
        # huggingface_hub, whose recent releases resume interrupted downloads
        # automatically.
        checkpoint_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=cache_dir,
            local_files_only=False,
        )
        elapsed = time.time() - start_time
        print(f"✅ Download completed in {elapsed/60:.1f} minutes")
        print(f"✅ Checkpoint at: {checkpoint_path}")
        return checkpoint_path
    except Exception as e:
        error_detail = str(e)
        if "404" in error_detail or "not found" in error_detail.lower():
            return None, f"Checkpoint not found in Model Hub.\n\nPlease verify:\n1. Repo exists: https://huggingface.co/{repo_id}\n2. File exists: {filename}\n3. Repo is Public (not Private)\n\nError: {error_detail}"
        return None, f"Error downloading checkpoint: {error_detail}\n\nThis may take 10-30 minutes for a 14GB file. Please wait or check your internet connection."


# Verify CUDA availability (will be True on HF Spaces with GPU hardware)
if torch.cuda.is_available():
    print(f"✅ CUDA available! Device: {torch.cuda.get_device_name(0)}")
    print(f"   CUDA Version: {torch.version.cuda}")
    print(f"   PyTorch Version: {torch.__version__}")
else:
    print("⚠️ CUDA not available. Make sure GPU hardware is selected in Space settings.")


def _base_sample_cmd(config_path, checkpoint_path, prompt, cfg, aspect_ratio, seed):
    """Return the sample.py arguments shared by image and video generation."""
    return [
        # Run sample.py with the same interpreter (and installed packages) as this app.
        sys.executable, "sample.py",
        "--model_config_path", config_path,
        "--checkpoint_path", checkpoint_path,
        "--caption", prompt,
        "--sample_batch_size", "1",
        "--cfg", str(cfg),
        "--aspect_ratio", aspect_ratio,
        # Gradio numeric widgets may deliver floats (e.g. 999.0); sample.py
        # presumably parses the seed as an int.
        "--seed", str(int(seed)),
    ]


def _run_sample(cmd, log_file):
    """Run ``cmd``, write its stdout/stderr/return code to ``log_file``.

    Returns:
        ``(completed_process, log_text)``.

    Raises:
        subprocess.TimeoutExpired: if the run exceeds GENERATION_TIMEOUT_SECONDS.
    """
    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        cwd=os.getcwd(),
        timeout=GENERATION_TIMEOUT_SECONDS,
    )
    log_text = (
        "=== GENERATION LOG ===\n\n"
        f"Command: {' '.join(cmd)}\n\n"
        "=== STDOUT ===\n"
        f"{result.stdout}"
        "\n\n=== STDERR ===\n"
        f"{result.stderr}"
        f"\n\n=== RETURN CODE: {result.returncode} ===\n"
    )
    log_file.write_text(log_text, encoding="utf-8")
    return result, log_text


def _newest_output(output_dir, patterns, since):
    """Return the newest file under ``output_dir`` (recursively) matching any of
    the glob ``patterns`` and modified at or after ``since`` (a ``time.time()``
    value), or ``None`` if there is none.

    Filtering on ``since`` prevents a file left over from an earlier run from
    being reported as the result of the current one.
    """
    fresh = [
        path
        for pattern in patterns
        for path in output_dir.glob(f"**/{pattern}")
        if path.stat().st_mtime >= since
    ]
    return max(fresh, key=lambda path: path.stat().st_mtime, default=None)


def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
    """Generate image from text prompt.

    Returns:
        ``(image_path, status_message)``; ``image_path`` is ``None`` on failure.
    """
    # Get checkpoint path (from upload, local, or Model Hub)
    result = get_checkpoint_path(
        checkpoint_file, DEFAULT_IMAGE_CHECKPOINT, IMAGE_CHECKPOINT_REPO, IMAGE_CHECKPOINT_FILENAME
    )
    if isinstance(result, tuple) and result[0] is None:
        return None, result[1]  # Error message
    checkpoint_path = result

    # Show status
    status_msg = f"Using checkpoint: {checkpoint_path}\n"
    if not os.path.exists(checkpoint_path):
        return None, f"Error: Checkpoint file not found at {checkpoint_path}.\n\nPlease verify:\n1. Model Hub repo exists: {IMAGE_CHECKPOINT_REPO}\n2. File name matches: {IMAGE_CHECKPOINT_FILENAME}\n3. Repo is Public (not Private)"
    if not config_path or not os.path.exists(config_path):
        return None, "Error: Config file not found. Please ensure config file exists."

    status_msg += "Starting image generation...\n"
    status_msg += "This may take 1-3 minutes for first run (model loading).\n"

    try:
        output_dir = OUTPUT_DIR
        output_dir.mkdir(exist_ok=True)

        cmd = _base_sample_cmd(config_path, checkpoint_path, prompt, cfg, aspect_ratio, seed) + [
            "--save_folder", "1",
            "--finetuned_vae", "none",
            "--jacobi", "1",
            "--jacobi_th", "0.001",
            "--jacobi_block_size", "16",
            "--logdir", str(output_dir),  # Set logdir to outputs so the result can be found below
        ]

        status_msg += "Running generation...\n"
        status_msg += "Note: First run includes checkpoint download (~10-20 min) and model loading (~2-5 min).\n"

        # Create log file for debugging
        log_file = output_dir / "generation.log"
        status_msg += f"📋 Logs will be saved to: {log_file}\n"

        started_at = time.time()
        result, log_content = _run_sample(cmd, log_file)
        status_msg += f"\n📋 Full logs available at: {log_file}\n"

        if result.returncode != 0:
            error_msg = f"Error during generation:\n{result.stderr}\n\nStdout:\n{result.stdout}"
            error_msg += f"\n\n📋 Full log file ({log_file}):\n{log_content[-2000:]}"  # Last 2000 chars
            return None, error_msg

        status_msg += "Generation complete. Looking for output...\n"

        # sample.py saves under logdir/<model_name>/...; take the newest image from this run.
        latest_file = _newest_output(output_dir, ("*.png", "*.jpg"), started_at)
        if latest_file is not None:
            return str(latest_file), status_msg + "✅ Success! Image generated."

        error_msg = status_msg + f"Error: Generated image not found in {output_dir}."
        error_msg += f"\n\n📋 Check log file for details: {log_file}\nLast 1000 chars:\n{log_content[-1000:]}"
        return None, error_msg
    except subprocess.TimeoutExpired:
        return None, f"Error: Generation timed out after {GENERATION_TIMEOUT_SECONDS // 60} minutes."
    except Exception as e:
        return None, f"Error: {str(e)}"


def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
    """Generate video from text prompt.

    ``input_image`` is an optional conditioning image (path string or uploaded
    file object); when absent, sample.py receives ``--input_image none``.

    Returns:
        ``(video_path, status_message)``; ``video_path`` is ``None`` on failure.
    """
    # Get checkpoint path (from upload, local, or Model Hub)
    result = get_checkpoint_path(
        checkpoint_file, DEFAULT_VIDEO_CHECKPOINT, VIDEO_CHECKPOINT_REPO, VIDEO_CHECKPOINT_FILENAME
    )
    if isinstance(result, tuple) and result[0] is None:
        return None, result[1]  # Error message
    checkpoint_path = result

    if not os.path.exists(checkpoint_path):
        return None, f"Error: Checkpoint file not found at {checkpoint_path}."
    if not config_path or not os.path.exists(config_path):
        return None, "Error: Config file not found. Please ensure config file exists."

    # Handle input image
    input_image_path = None
    if input_image is not None:
        input_image_path = input_image.name if hasattr(input_image, 'name') else str(input_image)

    try:
        output_dir = OUTPUT_DIR
        output_dir.mkdir(exist_ok=True)

        cmd = _base_sample_cmd(config_path, checkpoint_path, prompt, cfg, aspect_ratio, seed) + [
            "--out_fps", "16",
            "--save_folder", "1",
            "--jacobi", "1",
            "--jacobi_th", "0.001",
            "--finetuned_vae", "none",
            "--disable_learnable_denoiser", "0",
            "--jacobi_block_size", "32",
            "--target_length", str(int(target_length)),
            # Same as image generation: write results under outputs/ so the search below finds them.
            "--logdir", str(output_dir),
        ]
        if input_image_path and os.path.exists(input_image_path):
            cmd.extend(["--input_image", input_image_path])
        else:
            cmd.extend(["--input_image", "none"])

        log_file = output_dir / "generation_video.log"
        started_at = time.time()
        result, _ = _run_sample(cmd, log_file)

        if result.returncode != 0:
            return None, f"Error: {result.stderr}\n\n📋 Full logs available at: {log_file}"

        # Find the generated video produced by this run
        latest_file = _newest_output(output_dir, ("*.mp4", "*.gif"), started_at)
        if latest_file is not None:
            return str(latest_file), "Success! Video generated."
        return None, f"Error: Generated video not found.\n\n📋 Check log file for details: {log_file}"
    except subprocess.TimeoutExpired:
        return None, f"Error: Generation timed out after {GENERATION_TIMEOUT_SECONDS // 60} minutes."
    except Exception as e:
        return None, f"Error: {str(e)}"


# Create Gradio interface with custom CSS
custom_css = """
.gradio-container {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
}
.main-header {
    text-align: center;
    padding: 2rem 0;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    border-radius: 10px;
    margin-bottom: 2rem;
}
.main-header h1 {
    margin: 0;
    font-size: 2.5rem;
    font-weight: 700;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
}
.main-header p {
    margin: 0.5rem 0 0 0;
    font-size: 1.1rem;
    opacity: 0.95;
}
.info-box {
    background: #f0f4ff;
    border-left: 4px solid #667eea;
    padding: 1rem;
    border-radius: 5px;
    margin-bottom: 1.5rem;
}
.input-section {
    background: #fafafa;
    padding: 1.5rem;
    border-radius: 10px;
    border: 1px solid #e0e0e0;
}
.output-section {
    background: white;
    padding: 1.5rem;
    border-radius: 10px;
    border: 1px solid #e0e0e0;
    box-shadow: 0 2px 8px rgba(0,0,0,0.05);
}
.generate-btn {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    color: white !important;
    font-weight: 600 !important;
    padding: 0.75rem 2rem !important;
    border-radius: 8px !important;
    border: none !important;
    font-size: 1rem !important;
    transition: transform 0.2s, box-shadow 0.2s !important;
}
.generate-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4) !important;
}
.status-box {
    font-family: 'Monaco', 'Menlo', monospace;
    font-size: 0.9rem;
    line-height: 1.6;
}
"""

with gr.Blocks(
    title="STARFlow - Text-to-Image & Video Generation",
    theme=gr.themes.Soft(primary_hue="purple"),
    css=custom_css
) as demo:
    # Header (uses the .main-header styles from custom_css)
    gr.HTML("""
        <div class="main-header">
            <h1>🎨 STARFlow</h1>
            <p>Scalable Transformer Auto-Regressive Flow</p>
            <p>Generate high-quality images and videos from text prompts</p>
        </div>
    """)

    # Info box (single-line HTML so Markdown does not treat indentation as a code block)
    gr.Markdown(
        '<div class="info-box">ℹ️ <strong>Note:</strong> Checkpoints are automatically downloaded '
        'from Model Hub on first use. First generation may take 10-20 minutes for download and model loading.</div>'
    )

    with gr.Tabs() as tabs:
        with gr.Tab("🖼️ Text-to-Image", id="image_tab"):
            with gr.Row():
                with gr.Column(scale=1, min_width=400):
                    gr.Markdown("### ⚙️ Generation Settings")

                    with gr.Group():
                        image_prompt = gr.Textbox(
                            label="📝 Prompt",
                            placeholder="a film still of a cat playing piano",
                            lines=4,
                            info="Describe the image you want to generate"
                        )
                        image_config = gr.Textbox(
                            label="⚙️ Config Path",
                            value="configs/starflow_3B_t2i_256x256.yaml",
                            interactive=False,
                            info="Model configuration file"
                        )

                    with gr.Group():
                        gr.Markdown("#### 🎨 Image Settings")
                        image_aspect = gr.Dropdown(
                            label="Aspect Ratio",
                            choices=["1:1", "2:3", "3:2", "16:9", "9:16", "4:5", "5:4"],
                            value="1:1",
                            info="Image dimensions ratio"
                        )
                        image_cfg = gr.Slider(
                            label="CFG Scale",
                            minimum=1.0,
                            maximum=10.0,
                            value=3.6,
                            step=0.1,
                            info="Classifier-free guidance scale (higher = more prompt adherence)"
                        )
                        image_seed = gr.Number(
                            value=999,
                            label="🎲 Seed",
                            precision=0,
                            info="Random seed for reproducibility"
                        )

                    # Hidden checkpoint field
                    image_checkpoint = gr.Textbox(
                        label="Model Checkpoint Path (auto-downloaded)",
                        value=DEFAULT_IMAGE_CHECKPOINT,
                        visible=False
                    )

                    image_btn = gr.Button(
                        "✨ Generate Image",
                        variant="primary",
                        size="lg",
                        elem_classes="generate-btn"
                    )

                with gr.Column(scale=1, min_width=500):
                    gr.Markdown("### 🎨 Generated Image")
                    image_output = gr.Image(
                        label="",
                        type="filepath",
                        height=500,
                        show_label=False
                    )
                    image_status = gr.Textbox(
                        label="📊 Status",
                        lines=12,
                        max_lines=20,
                        interactive=False,
                        elem_classes="status-box",
                        placeholder="Status messages will appear here..."
                    )

            image_btn.click(
                fn=generate_image,
                inputs=[image_prompt, image_aspect, image_cfg, image_seed, image_checkpoint, image_config],
                outputs=[image_output, image_status],
                show_progress=True,
                queue=True  # Use queue to handle long-running operations
            )

        with gr.Tab("🎬 Text-to-Video", id="video_tab"):
            with gr.Row():
                with gr.Column(scale=1, min_width=400):
                    gr.Markdown("### ⚙️ Generation Settings")

                    with gr.Group():
                        video_prompt = gr.Textbox(
                            label="📝 Prompt",
                            placeholder="a corgi dog looks at the camera",
                            lines=4,
                            info="Describe the video you want to generate"
                        )
                        video_config = gr.Textbox(
                            label="⚙️ Config Path",
                            value="configs/starflow-v_7B_t2v_caus_480p.yaml",
                            interactive=False,
                            info="Model configuration file"
                        )

                    with gr.Group():
                        gr.Markdown("#### 🎬 Video Settings")
                        video_aspect = gr.Dropdown(
                            label="Aspect Ratio",
                            choices=["16:9", "1:1", "4:3"],
                            value="16:9",
                            info="Video dimensions ratio"
                        )
                        video_cfg = gr.Slider(
                            label="CFG Scale",
                            minimum=1.0,
                            maximum=10.0,
                            value=3.5,
                            step=0.1,
                            info="Classifier-free guidance scale"
                        )
                        video_seed = gr.Number(
                            value=99,
                            label="🎲 Seed",
                            precision=0,
                            info="Random seed for reproducibility"
                        )
                        video_length = gr.Slider(
                            label="Target Length (frames)",
                            minimum=81,
                            maximum=481,
                            value=81,
                            step=80,
                            info="Number of frames in the generated video"
                        )

                    with gr.Group():
                        gr.Markdown("#### 🖼️ Optional Input")
                        video_input_image = gr.Image(
                            label="Input Image (optional)",
                            type="filepath",
                            info="Provide an initial image for video generation"
                        )

                    # Hidden checkpoint field
                    video_checkpoint = gr.Textbox(
                        label="Model Checkpoint Path (auto-downloaded)",
                        value=DEFAULT_VIDEO_CHECKPOINT,
                        visible=False
                    )

                    video_btn = gr.Button(
                        "✨ Generate Video",
                        variant="primary",
                        size="lg",
                        elem_classes="generate-btn"
                    )

                with gr.Column(scale=1, min_width=500):
                    gr.Markdown("### 🎬 Generated Video")
                    video_output = gr.Video(
                        label="",
                        height=500,
                        show_label=False
                    )
                    video_status = gr.Textbox(
                        label="📊 Status",
                        lines=12,
                        max_lines=20,
                        interactive=False,
                        elem_classes="status-box",
                        placeholder="Status messages will appear here..."
                    )

            video_btn.click(
                fn=generate_video,
                inputs=[video_prompt, video_aspect, video_cfg, video_seed, video_length, video_checkpoint, video_config, video_input_image],
                outputs=[video_output, video_status],
                show_progress=True,
                queue=True  # Use queue to handle long-running operations
            )


if __name__ == "__main__":
    # Enable queue for long-running operations
    demo.queue(default_concurrency_limit=1)

    # Password protection - users don't need HF accounts!
    # Set STARFLOW_USERNAME / STARFLOW_PASSWORD (e.g. as Space secrets) to change
    # the credentials without editing source; the defaults below are the
    # previous hard-coded values.
    # For multiple users, use: auth=[("user1", "pass1"), ("user2", "pass2")]
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        auth=(
            os.environ.get("STARFLOW_USERNAME", "starflow"),
            os.environ.get("STARFLOW_PASSWORD", "im30"),
        ),
        share=False,  # Set to True if you want public Gradio link
        max_threads=1  # Limit threads for stability
    )