🎨 STARFlow
Scalable Transformer Auto-Regressive Flow
Generate high-quality images and videos from text prompts
""" Hugging Face Space for STARFlow Text-to-Image and Text-to-Video Generation This app allows you to run STARFlow inference on Hugging Face GPU infrastructure. """ # Fix OpenMP warning - MUST be set BEFORE importing torch import os os.environ['OMP_NUM_THREADS'] = '1' os.environ['MKL_NUM_THREADS'] = '1' os.environ['NUMEXPR_NUM_THREADS'] = '1' import warnings import gradio as gr import torch import subprocess import pathlib from pathlib import Path # Suppress harmless warnings warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.distributed.reduce_op.*") # Try to import huggingface_hub for downloading checkpoints try: from huggingface_hub import hf_hub_download HF_HUB_AVAILABLE = True except ImportError: HF_HUB_AVAILABLE = False print("ā ļø huggingface_hub not available. Install with: pip install huggingface_hub") # Check if running on Hugging Face Spaces HF_SPACE = os.environ.get("SPACE_ID") is not None # Default checkpoint paths (if uploaded to Space Files) DEFAULT_IMAGE_CHECKPOINT = "ckpts/starflow_3B_t2i_256x256.pth" DEFAULT_VIDEO_CHECKPOINT = "ckpts/starflow-v_7B_t2v_caus_480p_v3.pth" # Model Hub repositories (if using Hugging Face Model Hub) # Set these to your Model Hub repo IDs if you upload checkpoints there # Format: "username/repo-name" IMAGE_CHECKPOINT_REPO = "GlobalStudio/starflow-3b-checkpoint" # Update this after creating Model Hub repo VIDEO_CHECKPOINT_REPO = "GlobalStudio/starflow-v-7b-checkpoint" # Update this after creating Model Hub repo def get_checkpoint_path(checkpoint_file, default_local_path, repo_id=None, filename=None): """Get checkpoint path, downloading from Hub if needed.""" # If user uploaded a file, use it if checkpoint_file is not None: if hasattr(checkpoint_file, 'name'): return checkpoint_file.name return str(checkpoint_file) # Try local path first if os.path.exists(default_local_path): return default_local_path # Try downloading from Model Hub if configured if repo_id and filename and HF_HUB_AVAILABLE: try: # Use /workspace if available (persistent), otherwise /tmp cache_dir = "/workspace/checkpoints" if os.path.exists("/workspace") else "/tmp/checkpoints" os.makedirs(cache_dir, exist_ok=True) # Check if already downloaded possible_path = os.path.join(cache_dir, "models--" + repo_id.replace("/", "--"), "snapshots", "*", filename) import glob existing = glob.glob(possible_path) if existing: checkpoint_path = existing[0] print(f"ā Using cached checkpoint: {checkpoint_path}") return checkpoint_path # Download with progress tracking import time start_time = time.time() print(f"š„ Downloading checkpoint from {repo_id}...") print(f"File size: ~15.5 GB - This may take 10-30 minutes") checkpoint_path = hf_hub_download( repo_id=repo_id, filename=filename, cache_dir=cache_dir, local_files_only=False, resume_download=True, # Resume if interrupted ) elapsed = time.time() - start_time print(f"ā Download completed in {elapsed/60:.1f} minutes") print(f"ā Checkpoint at: {checkpoint_path}") return checkpoint_path except Exception as e: error_detail = str(e) if "404" in error_detail or "not found" in error_detail.lower(): return None, f"Checkpoint not found in Model Hub.\n\nPlease verify:\n1. Repo exists: https://huggingface.co/{repo_id}\n2. File exists: {filename}\n3. Repo is Public (not Private)\n\nError: {error_detail}" return None, f"Error downloading checkpoint: {error_detail}\n\nThis may take 10-30 minutes for a 14GB file. Please wait or check your internet connection." # No checkpoint found return None, f"Checkpoint not found. 


# Verify CUDA availability (will be True on HF Spaces with GPU hardware)
if torch.cuda.is_available():
    print(f"✅ CUDA available! Device: {torch.cuda.get_device_name(0)}")
    print(f"   CUDA Version: {torch.version.cuda}")
    print(f"   PyTorch Version: {torch.__version__}")
else:
    print("⚠️ CUDA not available. Make sure GPU hardware is selected in Space settings.")


def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
    """Generate image from text prompt."""
    # Get checkpoint path (from upload, local, or Model Hub)
    result = get_checkpoint_path(
        checkpoint_file,
        DEFAULT_IMAGE_CHECKPOINT,
        IMAGE_CHECKPOINT_REPO,
        "starflow_3B_t2i_256x256.pth"
    )
    if isinstance(result, tuple) and result[0] is None:
        return None, result[1]  # Error message
    checkpoint_path = result

    # Show status
    status_msg = f"Using checkpoint: {checkpoint_path}\n"

    if not os.path.exists(checkpoint_path):
        return None, f"Error: Checkpoint file not found at {checkpoint_path}.\n\nPlease verify:\n1. Model Hub repo exists: {IMAGE_CHECKPOINT_REPO}\n2. File name matches: starflow_3B_t2i_256x256.pth\n3. Repo is Public (not Private)"

    if not config_path or not os.path.exists(config_path):
        return None, "Error: Config file not found. Please ensure the config file exists."

    status_msg += "Starting image generation...\n"
    status_msg += "This may take 1-3 minutes for first run (model loading).\n"

    try:
        # Create output directory
        output_dir = Path("outputs")
        output_dir.mkdir(exist_ok=True)

        # Build sampling command
        # Set logdir to outputs directory for easier file finding
        cmd = [
            "python", "sample.py",
            "--model_config_path", config_path,
            "--checkpoint_path", checkpoint_path,
            "--caption", prompt,
            "--sample_batch_size", "1",
            "--cfg", str(cfg),
            "--aspect_ratio", aspect_ratio,
            "--seed", str(seed),
            "--save_folder", "1",
            "--finetuned_vae", "none",
            "--jacobi", "1",
            "--jacobi_th", "0.001",
            "--jacobi_block_size", "16",
            "--logdir", str(output_dir),  # Set logdir to outputs
        ]

        status_msg += "Running generation...\n"
        status_msg += "Note: First run includes checkpoint download (~10-20 min) and model loading (~2-5 min).\n"

        # Create log file for debugging
        log_file = output_dir / "generation.log"
        status_msg += f"📋 Logs will be saved to: {log_file}\n"

        # Run with timeout (45 minutes max - allows for download + generation)
        # Capture output and write to log file
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            cwd=os.getcwd(),
            timeout=2700
        )

        # Write to log file
        with open(log_file, 'w') as log:
            log.write("=== GENERATION LOG ===\n\n")
            log.write(f"Command: {' '.join(cmd)}\n\n")
            log.write("=== STDOUT ===\n")
            log.write(result.stdout)
            log.write("\n\n=== STDERR ===\n")
            log.write(result.stderr)
            log.write(f"\n\n=== RETURN CODE: {result.returncode} ===\n")

        # Read log file for detailed output
        log_content = ""
        if log_file.exists():
            with open(log_file, 'r') as f:
                log_content = f.read()
            status_msg += f"\n📋 Full logs available at: {log_file}\n"

        if result.returncode != 0:
            error_msg = f"Error during generation:\n{result.stderr}\n\nStdout:\n{result.stdout}"
            if log_content:
                error_msg += f"\n\n📋 Full log file ({log_file}):\n{log_content[-2000:]}"  # Last 2000 chars
            return None, error_msg

        status_msg += "Generation complete. Looking for output...\n"

        # Find the generated image
        # The sample.py script saves to logdir/model_name/...
        # We need to find the most recent output
        output_files = list(output_dir.glob("**/*.png")) + list(output_dir.glob("**/*.jpg"))
        if output_files:
            latest_file = max(output_files, key=lambda p: p.stat().st_mtime)
            return str(latest_file), status_msg + "✅ Success! Image generated."
        else:
            error_msg = status_msg + f"Error: Generated image not found in {output_dir}."
            if log_content:
                error_msg += f"\n\n📋 Check log file for details: {log_file}\nLast 1000 chars:\n{log_content[-1000:]}"
            else:
                error_msg += f"\n\nCheck stdout:\n{result.stdout}"
            return None, error_msg

    except Exception as e:
        return None, f"Error: {str(e)}"


def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
    """Generate video from text prompt."""
    # Get checkpoint path (from upload, local, or Model Hub)
    result = get_checkpoint_path(
        checkpoint_file,
        DEFAULT_VIDEO_CHECKPOINT,
        VIDEO_CHECKPOINT_REPO,
        "starflow-v_7B_t2v_caus_480p_v3.pth"
    )
    if isinstance(result, tuple) and result[0] is None:
        return None, result[1]  # Error message
    checkpoint_path = result

    if not os.path.exists(checkpoint_path):
        return None, f"Error: Checkpoint file not found at {checkpoint_path}."

    if not config_path or not os.path.exists(config_path):
        return None, "Error: Config file not found. Please ensure the config file exists."

    # Handle input image
    input_image_path = None
    if input_image is not None:
        if hasattr(input_image, 'name'):
            input_image_path = input_image.name
        else:
            input_image_path = str(input_image)

    try:
        # Create output directory
        output_dir = Path("outputs")
        output_dir.mkdir(exist_ok=True)

        # Build sampling command
        cmd = [
            "python", "sample.py",
            "--model_config_path", config_path,
            "--checkpoint_path", checkpoint_path,
            "--caption", prompt,
            "--sample_batch_size", "1",
            "--cfg", str(cfg),
            "--aspect_ratio", aspect_ratio,
            "--seed", str(seed),
            "--out_fps", "16",
            "--save_folder", "1",
            "--jacobi", "1",
            "--jacobi_th", "0.001",
            "--finetuned_vae", "none",
            "--disable_learnable_denoiser", "0",
            "--jacobi_block_size", "32",
            "--target_length", str(target_length),
        ]

        if input_image_path and os.path.exists(input_image_path):
            cmd.extend(["--input_image", input_image_path])
        else:
            cmd.extend(["--input_image", "none"])

        result = subprocess.run(cmd, capture_output=True, text=True, cwd=os.getcwd())

        if result.returncode != 0:
            return None, f"Error: {result.stderr}"

        # Find the generated video
        output_files = list(output_dir.glob("**/*.mp4")) + list(output_dir.glob("**/*.gif"))
        if output_files:
            latest_file = max(output_files, key=lambda p: p.stat().st_mtime)
            return str(latest_file), "Success! Video generated."
        else:
            return None, "Error: Generated video not found."
    except Exception as e:
        return None, f"Error: {str(e)}"


# Create Gradio interface with custom CSS
custom_css = """
.gradio-container {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
}
.main-header {
    text-align: center;
    padding: 2rem 0;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    border-radius: 10px;
    margin-bottom: 2rem;
}
.main-header h1 {
    margin: 0;
    font-size: 2.5rem;
    font-weight: 700;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
}
.main-header p {
    margin: 0.5rem 0 0 0;
    font-size: 1.1rem;
    opacity: 0.95;
}
.info-box {
    background: #f0f4ff;
    border-left: 4px solid #667eea;
    padding: 1rem;
    border-radius: 5px;
    margin-bottom: 1.5rem;
}
.input-section {
    background: #fafafa;
    padding: 1.5rem;
    border-radius: 10px;
    border: 1px solid #e0e0e0;
}
.output-section {
    background: white;
    padding: 1.5rem;
    border-radius: 10px;
    border: 1px solid #e0e0e0;
    box-shadow: 0 2px 8px rgba(0,0,0,0.05);
}
.generate-btn {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    color: white !important;
    font-weight: 600 !important;
    padding: 0.75rem 2rem !important;
    border-radius: 8px !important;
    border: none !important;
    font-size: 1rem !important;
    transition: transform 0.2s, box-shadow 0.2s !important;
}
.generate-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4) !important;
}
.status-box {
    font-family: 'Monaco', 'Menlo', monospace;
    font-size: 0.9rem;
    line-height: 1.6;
}
"""

with gr.Blocks(
    title="STARFlow - Text-to-Image & Video Generation",
    theme=gr.themes.Soft(primary_hue="purple"),
    css=custom_css
) as demo:
    # Header
    gr.HTML("""
        <div class="main-header">
            <h1>🎨 STARFlow</h1>
            <p>Scalable Transformer Auto-Regressive Flow</p>
            <p>Generate high-quality images and videos from text prompts</p>